Spaces:
Runtime error
Runtime error
| """Complete Generator architecture: | |
| * OmniGenerator | |
| * Encoder | |
| * Decoders | |
| """ | |
| from pathlib import Path | |
| import traceback | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import yaml | |
| from addict import Dict | |
| from torch import softmax | |
| import climategan.strings as strings | |
| from climategan.deeplab import create_encoder, create_segmentation_decoder | |
| from climategan.depth import create_depth_decoder | |
| from climategan.masker import create_mask_decoder | |
| from climategan.painter import create_painter | |
| from climategan.tutils import init_weights, mix_noise, normalize | |
def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
    """Build an OmniGenerator, optionally initialize its weights, send it to device.

    Args:
        opts (addict.Dict): configuration; ``opts.gen[<task>].init_type`` and
            ``opts.gen[<task>].init_gain`` are read for every initialized decoder,
            ``opts.gen.encoder.*`` for the encoder.
        device (str or torch.device, optional): target device. Defaults to "cpu".
        latent_shape (optional): forwarded to OmniGenerator. Defaults to None.
        no_init (bool, optional): when True, skip every ``init_weights`` call
            and return the generator right after moving it. Defaults to False.
        verbose (int, optional): verbosity forwarded to the sub-modules and to
            ``init_weights``. Defaults to 0.

    Returns:
        OmniGenerator: the generator, moved to ``device``.
    """
    G = OmniGenerator(opts, latent_shape, verbose, no_init)
    if no_init:
        print("Sending to", device)
        return G.to(device)

    for model in G.decoders:
        net = G.decoders[model]
        if model == "s":
            # The segmentation decoder is never re-initialized here
            # (presumably handled by its own factory -- TODO confirm).
            continue
        if isinstance(net, nn.ModuleDict):
            # One sub-decoder per domain: initialize each of them.
            # ``items()`` yields (key, module) pairs, so the module itself must be
            # passed on; indexing ``net`` with the module is not a valid lookup.
            for domain, domain_model in net.items():
                init_weights(
                    domain_model,
                    init_type=opts.gen[model].init_type,
                    init_gain=opts.gen[model].init_gain,
                    verbose=verbose,
                    caller=f"create_generator decoder {model} {domain}",
                )
        else:
            init_weights(
                net,
                init_type=opts.gen[model].init_type,
                init_gain=opts.gen[model].init_gain,
                verbose=verbose,
                caller=f"create_generator decoder {model}",
            )

    # Only the "base" encoder architecture is initialized here.
    if G.encoder is not None and opts.gen.encoder.architecture == "base":
        init_weights(
            G.encoder,
            init_type=opts.gen.encoder.init_type,
            init_gain=opts.gen.encoder.init_gain,
            verbose=verbose,
            caller="create_generator encoder",
        )

    print("Sending to", device)
    return G.to(device)
class OmniGenerator(nn.Module):
    """Multi-task generator: optional shared encoder, task decoders and a painter.

    Attributes set in ``__init__``:
        opts: the configuration object the generator was built from.
        verbose (int): verbosity level.
        encoder: encoder module, or None when none of the "m", "s", "d" tasks
            is requested.
        decoders (nn.ModuleDict): decoders keyed by task letter ("d", "s", "m").
        painter (nn.Module): painter module, or an empty ``nn.Module`` placeholder
            when "p" is not in ``opts.tasks``.
    """

    def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
        """Creates the generator. All decoders listed in opts.gen will be added
        to the Generator.decoders ModuleDict if opts.gen.DecoderInitial is not True.
        Then can be accessed as G.decoders.T or G.decoders["T"] for instance,
        for the image Translation decoder

        Args:
            opts (addict.Dict): configuration dict
            latent_shape (optional): accepted for interface compatibility; it is
                not used by this constructor. Defaults to None.
            verbose (int, optional): verbosity level. Defaults to 0.
            no_init (bool, optional): forwarded to the ``create_*`` factories.
                Defaults to False.
        """
        super().__init__()
        self.opts = opts
        self.verbose = verbose
        self.encoder = None
        # The encoder is shared by the masker, segmentation and depth tasks.
        if any(t in opts.tasks for t in "msd"):
            self.encoder = create_encoder(opts, no_init, verbose)

        self.decoders = {}
        # Placeholder so that ``self.painter`` always exists.
        self.painter = nn.Module()

        if "d" in opts.tasks:
            self.decoders["d"] = create_depth_decoder(opts, no_init, verbose)

            if self.verbose > 0:
                print(f" - Add {self.decoders['d'].__class__.__name__}")

        if "s" in opts.tasks:
            self.decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)

        if "m" in opts.tasks:
            self.decoders["m"] = create_mask_decoder(opts, no_init, verbose)

        # Register the decoders as sub-modules.
        self.decoders = nn.ModuleDict(self.decoders)

        if "p" in self.opts.tasks:
            self.painter = create_painter(opts, no_init, verbose)
        else:
            if self.verbose > 0:
                print(" - Add Empty Painter")

    def device(self):
        """Return the device of the generator's first parameter.

        This is a regular method (not a property): call it as ``G.device()``.
        """
        return next(self.parameters()).device

    def __str__(self):
        return strings.generator(self)

    def encode(self, x):
        """
        Forward x through the encoder

        Args:
            x (torch.Tensor): B3HW input tensor

        Returns:
            list: High and Low level features from the encoder
        """
        assert self.encoder is not None
        return self.encoder.forward(x)

    def decode(self, x=None, z=None, return_z=False, return_z_depth=False):
        """
        Computes the predictions of all available decoders from either x or z.
        If using spade for the masker with 15 channels, x *must* be provided,
        whether z is too or not.

        Args:
            x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None.
            z (list, optional): List of high and low-level features as BCHW.
                Defaults to None.
            return_z (bool, optional): whether or not to return z in the dict.
                Defaults to False.
            return_z_depth (bool, optional): whether or not to return z_depth
                in the dict. Defaults to False.

        Raises:
            ValueError: If using spade for the masker with 15 channels but x is None

        Returns:
            dict: {task: prediction_tensor} (may include z and z_depth
                depending on args)
        """
        assert x is not None or z is not None
        # The spade options live under opts.gen.m (same path as in make_m_cond).
        if self.opts.gen.m.use_spade and self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )

        z_depth = cond = d = s = None
        out = {}

        if z is None:
            z = self.encode(x)

        if return_z:
            out["z"] = z

        if "d" in self.decoders:
            d, z_depth = self.decoders["d"](z)
            out["d"] = d

        if return_z_depth:
            # z_depth stays None when there is no depth decoder.
            out["z_depth"] = z_depth

        if "s" in self.decoders:
            s = self.decoders["s"](z, z_depth)
            out["s"] = s

        if "m" in self.decoders:
            # The conditioning tensor is only built here when both the depth and
            # segmentation predictions are available; otherwise mask() gets None.
            if s is not None and d is not None:
                cond = self.make_m_cond(d, s, x)
            m = self.mask(z=z, cond=cond)
            out["m"] = m

        return out

    def sample_painter_z(self, batch_size, device, force_half=False):
        """Sample the painter's latent noise tensor.

        Args:
            batch_size (int): first dimension of the sampled tensor.
            device (torch.device or str): device to allocate the tensor on.
            force_half (bool, optional): cast the sample to half precision.
                Defaults to False.

        Returns:
            torch.Tensor or None: None when ``opts.gen.p.no_z`` is truthy, otherwise
                a standard-normal tensor of shape
                (batch_size, opts.gen.p.latent_dim, painter.z_h, painter.z_w).
        """
        if self.opts.gen.p.no_z:
            return None

        z = torch.empty(
            batch_size,
            self.opts.gen.p.latent_dim,
            self.painter.z_h,
            self.painter.z_w,
            device=device,
        ).normal_(mean=0, std=1.0)

        if force_half:
            z = z.half()

        return z

    def make_m_cond(self, d, s, x=None):
        """
        Create the masker's conditioning input when using spade from the
        d and s predictions and from the input x when cond_nc == 15.

        d and s are assumed to have the same spatial resolution.
        if cond_nc == 15 then x is interpolated to match that dimension.

        Args:
            d (torch.Tensor): Raw depth prediction (B1HW)
            s (torch.Tensor): Raw segmentation prediction (BCHW)
            x (torch.Tensor, optional): Input tensor (B3HW). Mandatory
                when opts.gen.m.spade.cond_nc == 15

        Raises:
            ValueError: opts.gen.m.spade.cond_nc == 15 but x is None

        Returns:
            torch.Tensor: B x cond_nc x H x W conditioning tensor.
        """
        if self.opts.gen.m.spade.detach:
            # Cut the graph so the masker does not back-propagate into d and s.
            d = d.detach()
            s = s.detach()
        cats = [normalize(d), softmax(s, dim=1)]
        if self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )
            cats += [
                F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True)
            ]

        return torch.cat(cats, dim=1)

    def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True):
        """
        Create a mask from either an input x or a latent vector z.
        Optionally if the Masker has a spade architecture the conditioning tensor
        may be provided (cond). Default behavior applies an element-wise
        sigmoid, but can be deactivated (sigmoid=False).

        At least one of x or z must be provided (i.e. not None).
        If the Masker has a spade architecture and cond_nc == 15 then x cannot
        be None.

        Args:
            x (torch.Tensor, optional): Input tensor B3HW. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Will be computed if None. Defaults to None.
            cond (torch.Tensor, optional): Conditioning tensor for a spade
                masker. When None and opts.gen.m.use_spade is set, it is
                computed (without gradients) from the depth and segmentation
                decoders. Defaults to None.
            z_depth (optional): Depth decoder features. When None and
                opts.gen.m.use_dada is set, they are computed (without
                gradients) from the depth decoder. Defaults to None.
            sigmoid (bool, optional): Apply an element-wise sigmoid to the
                masker's logits. Defaults to True.

        Returns:
            torch.Tensor: B1HW mask tensor
        """
        assert x is not None or z is not None
        if z is None:
            z = self.encode(x)

        if cond is None and self.opts.gen.m.use_spade:
            assert "s" in self.opts.tasks and "d" in self.opts.tasks
            with torch.no_grad():
                d_pred, z_d = self.decoders["d"](z)
                s_pred = self.decoders["s"](z, z_d)
                cond = self.make_m_cond(d_pred, s_pred, x)

        if z_depth is None and self.opts.gen.m.use_dada:
            assert "d" in self.opts.tasks
            with torch.no_grad():
                _, z_depth = self.decoders["d"](z)

        if cond is not None:
            # Make sure the conditioning lives on the same device as the features.
            device = z[0].device if isinstance(z, (tuple, list)) else z.device
            cond = cond.to(device)

        logits = self.decoders["m"](z, cond, z_depth)

        if not sigmoid:
            return logits

        return torch.sigmoid(logits)

    def paint(self, m, x, no_paste=False):
        """
        Paints given a mask and an image
        calls painter(z, x * (1.0 - m))
        Mask has 1s where water should be painted

        Args:
            m (torch.Tensor): Mask
            x (torch.Tensor): Image to paint
            no_paste (bool, optional): When True, return the painter's raw output
                even if opts.gen.p.paste_original_content is set.
                Defaults to False.

        Returns:
            torch.Tensor: painted image
        """
        z_paint = self.sample_painter_z(x.shape[0], x.device)
        m = m.to(x.dtype)
        fake = self.painter(z_paint, x * (1.0 - m))
        if self.opts.gen.p.paste_original_content and not no_paste:
            # Keep the original pixels outside of the mask.
            return x * (1.0 - m) + fake * m
        return fake

    def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8):
        """
        Paints x with water in m through an intermediary cloudy image
        where the sky has been replaced with perlin noise to imitate clouds.

        The intermediary cloudy image is only used to control the painter's
        painting mode, probing it with a cloudy input.

        Args:
            m (torch.Tensor): water mask
            x (torch.Tensor): input tensor
            s (torch.Tensor): segmentation prediction (BCHW)
            sky_idx (int, optional): Index of the sky class along s's C dimension.
                Defaults to 9.
            res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8).
            weight (float, optional): Intermediate image's cloud proportion
                (w * cloud + (1-w) * original_sky). Defaults to 0.8.

        Returns:
            torch.Tensor: painted image with original content pasted.
        """
        # Binary sky mask at x's resolution (1 where the predicted class is sky).
        sky_mask = (
            torch.argmax(
                F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True
            )
            == sky_idx
        ).to(x.dtype)
        noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype)
        fake = self.paint(m, noised_x, no_paste=True)
        # Paste on the *original* x, not on the noised intermediate image.
        return x * (1.0 - m) + fake * m

    def depth(self, x=None, z=None, return_z_depth=False):
        """
        Compute the depth head's output

        Exactly one of x or z must be provided.

        Args:
            x (torch.Tensor, optional): Input B3HW tensor. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Defaults to None.
            return_z_depth (bool, optional): Also return the depth decoder's
                second output. Defaults to False.

        Returns:
            torch.Tensor: depth predictions. When the decoder outputs more than
                one channel, the result is the per-pixel argmax divided by its
                maximum, and the channel dimension is dropped (argmax is taken
                without keepdim). With return_z_depth, a (depth, z_depth) tuple.
        """
        assert x is not None or z is not None
        assert not (x is not None and z is not None)
        if z is None:
            z = self.encode(x)
        depth, z_depth = self.decoders["d"](z)

        if depth.shape[1] > 1:
            depth = torch.argmax(depth, dim=1)
            depth = depth / depth.max()

        if return_z_depth:
            return depth, z_depth

        return depth

    def load_val_painter(self):
        """
        Loads a validation painter if available in opts.val.val_painter

        Any failure (missing option, missing files, incompatible checkpoint...)
        is printed and reported through the return value instead of raising.

        Returns:
            bool: operation success status
        """
        try:
            # key exists in opts
            assert self.opts.val.val_painter
            # path exists
            ckpt_path = Path(self.opts.val.val_painter).resolve()
            assert ckpt_path.exists()
            # path is a checkpoint path
            assert ckpt_path.is_file()
            # opts are available in that path
            opts_path = ckpt_path.parent.parent / "opts.yaml"
            assert opts_path.exists()
            # load opts
            with opts_path.open("r") as f:
                val_painter_opts = Dict(yaml.safe_load(f))
            # resolve the current device: ``self.device`` is a method, so passing
            # it un-called as map_location would hand torch.load a callable with
            # the wrong signature
            device = next(self.parameters()).device
            # load checkpoint
            state_dict = torch.load(ckpt_path, map_location=device)
            # create dummy painter from loaded opts
            painter = create_painter(val_painter_opts)
            # load state-dict in the dummy painter
            painter.load_state_dict(
                {k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
            )
            # send to current device in evaluation mode
            self.painter = painter.eval().to(device)
            # disable gradients
            for p in self.painter.parameters():
                p.requires_grad = False
            # success
            print(" - Loaded validation-only painter")
            return True
        except Exception as e:
            # something happened, aborting gracefully
            print(traceback.format_exc())
            print(e)
            print(">>> WARNING: error (^) in load_val_painter, aborting.")
            return False