Spaces:
Runtime error
Runtime error
| """Normalization layers used in blocks | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class AdaptiveInstanceNorm2d(nn.Module):
    """Adaptive instance normalisation (AdaIN) with externally assigned affine params.

    ``weight`` and ``bias`` are ``None`` after construction and must be set
    before ``forward`` is called. The input of shape ``(b, c, ...)`` is folded
    into a single batch element with ``b * c`` channels, so both tensors must
    have ``b * c`` elements (one per sample/channel pair).

    ``running_mean`` / ``running_var`` are registered only as placeholder
    buffers: ``forward`` hands ``F.batch_norm`` repeated *copies* of them, so the
    buffers themselves are never updated.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Affine parameters are injected from outside before each use.
        self.weight = None
        self.bias = None
        # Placeholder statistics; see the class docstring.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        """Normalise each (sample, channel) slice of ``x`` and apply weight/bias."""
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.size(0), x.size(1)
        spatial = x.size()[2:]
        # Folding the batch into the channel axis makes batch_norm (in training
        # mode) use per-instance statistics, i.e. instance normalisation.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        mean_copy = self.running_mean.repeat(batch)
        var_copy = self.running_var.repeat(batch)
        normed = F.batch_norm(
            folded,
            mean_copy,
            var_copy,
            self.weight,
            self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
        return normed.view(batch, channels, *spatial)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.num_features)
class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch dimensions.

    Each sample ``x[i]`` is shifted by its mean and divided by its (unbiased)
    standard deviation plus ``eps``. Note that ``eps`` is added to the std, not
    to the variance. With ``affine=True`` a per-channel scale ``gamma``
    (initialised uniformly in [0, 1)) and shift ``beta`` (initialised to 0) are
    applied along dim 1.

    Args:
        num_features: size of dim 1 of the input (number of channels).
        eps: value added to the standard deviation for numerical stability.
        affine: whether to learn ``gamma`` / ``beta``.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.empty(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Broadcast shape for per-sample statistics: (N, 1, 1, ...).
        stat_shape = [-1] + [1] * (x.dim() - 1)
        # reshape (not view) so non-contiguous inputs, e.g. the result of
        # transpose/permute, are accepted instead of raising RuntimeError.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(*stat_shape)
        std = flat.std(1).view(*stat_shape)
        x = (x - mean) / (std + self.eps)
        if self.affine:
            # Per-channel affine along dim 1: (1, C, 1, 1, ...).
            affine_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*affine_shape) + self.beta.view(*affine_shape)
        return x
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its L2 norm.

    ``eps`` is added to the norm so that an all-zero input yields zeros instead
    of NaNs.
    """
    denominator = v.norm() + eps
    return v / denominator
class SpectralNorm(nn.Module):
    """Spectral-normalisation wrapper around one weight of another module.

    Follows "Spectral Normalization for Generative Adversarial Networks" by
    Takeru Miyato, Toshiki Kataoka, Masanori Koyama and Yuichi Yoshida, adapted
    from the PyTorch port at
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Unless the wrapped module already carries them, construction removes the
    parameter ``name`` from ``module`` and registers three replacements:

    * ``<name>_u``: power-iteration vector of length ``dim0``, frozen.
    * ``<name>_v``: power-iteration vector of length ``numel // dim0``, frozen.
    * ``<name>_bar``: the raw, trainable weight.

    Every forward call runs ``power_iterations`` power-iteration steps. It then
    sets ``module.<name> = <name>_bar / sigma`` as a plain attribute, where
    ``sigma = u . (W v)`` and ``W`` is the weight viewed as ``(dim0, -1)``.
    Note that ``u``/``v`` are refreshed on every call, with no training/eval
    distinction. The wrapper calls ``module.forward`` directly, so hooks on the
    wrapped module are bypassed and only positional arguments are forwarded.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh u/v by power iteration and install the normalised weight."""
        wrapped = self.module
        left = getattr(wrapped, self.name + "_u")
        right = getattr(wrapped, self.name + "_v")
        w_bar = getattr(wrapped, self.name + "_bar")
        rows = w_bar.data.shape[0]
        for _ in range(self.power_iterations):
            # Work on detached data: the vectors are estimates, not trained.
            mat = w_bar.view(rows, -1).data
            right.data = l2normalize(torch.mv(torch.t(mat), left.data))
            left.data = l2normalize(torch.mv(mat, right.data))
        # sigma keeps the autograd link to w_bar so gradients flow through it.
        sigma = left.dot(w_bar.view(rows, -1).mv(right))
        setattr(wrapped, self.name, w_bar / sigma.expand_as(w_bar))

    def _made_params(self):
        """Return True if the u/v/bar parameters already exist on the module."""
        try:
            for suffix in ("_u", "_v", "_bar"):
                getattr(self.module, self.name + suffix)
        except AttributeError:
            return False
        return True

    def _make_params(self):
        """Replace the original weight parameter with u, v and the raw weight."""
        weight = getattr(self.module, self.name)
        rows = weight.data.shape[0]
        cols = weight.view(rows, -1).data.shape[1]
        # u must be drawn before v so the RNG stream is consumed in a fixed order.
        u = nn.Parameter(weight.data.new(rows).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(weight.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(weight.data)
        del self.module._parameters[self.name]
        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        self._update_u_v()
        wrapped = self.module
        return wrapped.forward(*args)
class SPADE(nn.Module):
    """Spatially-adaptive normalisation conditioned on a semantic map.

    ``x`` is first normalised without learned affine parameters, using either
    instance or batch norm as selected by ``param_free_norm_type``. A small
    conv network then predicts per-pixel ``gamma`` and ``beta`` maps from
    ``segmap``, after resizing it to ``x``'s spatial size with nearest-neighbour
    interpolation. The output is ``normalized * (1 + gamma) + beta``.

    Args:
        param_free_norm_type: ``"instance"`` or ``"batch"``. Anything else
            raises ``ValueError``; a synchronized batch-norm option is not
            supported here.
        kernel_size: kernel size of all three convolutions. Padding is
            ``kernel_size // 2``, which preserves spatial size for odd sizes.
        norm_nc: number of channels of ``x`` (and of gamma/beta).
        cond_nc: number of channels of ``segmap``.
    """

    def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc):
        super().__init__()
        if param_free_norm_type == "instance":
            norm_cls = nn.InstanceNorm2d
        elif param_free_norm_type == "batch":
            norm_cls = nn.BatchNorm2d
        else:
            raise ValueError(
                "%s is not a recognized param-free norm type in SPADE"
                % param_free_norm_type
            )
        self.param_free_norm = norm_cls(norm_nc, affine=False)

        # Width of the shared hidden embedding; fixed rather than configurable.
        hidden_nc = 128
        padding = kernel_size // 2
        # Construction order (shared, gamma, beta) fixes the weight-init RNG order.
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(cond_nc, hidden_nc, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(
            hidden_nc, norm_nc, kernel_size=kernel_size, padding=padding
        )
        self.mlp_beta = nn.Conv2d(
            hidden_nc, norm_nc, kernel_size=kernel_size, padding=padding
        )

    def forward(self, x, segmap):
        # Step 1: normalise with no learned affine transform.
        normalized = self.param_free_norm(x)
        # Step 2: resize the condition map to x's spatial size and predict the
        # per-pixel modulation parameters from it.
        cond = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        features = self.mlp_shared(cond)
        gamma = self.mlp_gamma(features)
        beta = self.mlp_beta(features)
        # Step 3: modulate; "1 + gamma" means gamma == 0 leaves the scale as is.
        return normalized * (1 + gamma) + beta