Spaces:
Runtime error
Runtime error
| """ | |
| Main component: the trainer handles everything: | |
| * initializations | |
| * training | |
| * saving | |
| """ | |
| import inspect | |
| import warnings | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from time import time | |
| import numpy as np | |
| from comet_ml import ExistingExperiment, Experiment | |
| warnings.simplefilter("ignore", UserWarning) | |
| import torch | |
| import torch.nn as nn | |
| from addict import Dict | |
| from torch import autograd, sigmoid, softmax | |
| from torch.cuda.amp import GradScaler, autocast | |
| from tqdm import tqdm | |
| from climategan.data import get_all_loaders, decode_segmap_merged_labels | |
| from climategan.discriminator import OmniDiscriminator, create_discriminator | |
| from climategan.eval_metrics import accuracy, mIOU | |
| from climategan.fid import compute_val_fid | |
| from climategan.fire import add_fire | |
| from climategan.generator import OmniGenerator, create_generator | |
| from climategan.logger import Logger | |
| from climategan.losses import get_losses | |
| from climategan.optim import get_optimizer | |
| from climategan.transforms import DiffTransforms | |
| from climategan.tutils import ( | |
| divide_pred, | |
| get_num_params, | |
| get_WGAN_gradient, | |
| lrgb2srgb, | |
| normalize, | |
| print_num_parameters, | |
| shuffle_batch_tuple, | |
| srgb2lrgb, | |
| tensor_to_uint8_numpy_image, | |
| vgg_preprocess, | |
| zero_grad, | |
| ) | |
| from climategan.utils import ( | |
| comet_kwargs, | |
| div_dict, | |
| find_target_size, | |
| flatten_opts, | |
| get_display_indices, | |
| get_existing_comet_id, | |
| get_latest_opts, | |
| merge, | |
| resolve, | |
| sum_dict, | |
| Timer, | |
| ) | |
| try: | |
| import torch_xla.core.xla_model as xm # type: ignore | |
| except ImportError: | |
| pass | |
| class Trainer: | |
| """Main trainer class""" | |
def __init__(self, opts, comet_exp=None, verbose=0, device=None):
    """Trainer class to gather various model training procedures
    such as training, evaluating, saving and logging.

    The constructor only stores the configuration and initializes
    placeholders (models, losses, loaders are None until `setup()` runs).

    Args:
        opts (addict.Dict): options to configure the trainer, the data, the models
        comet_exp (comet_ml.Experiment | comet_ml.ExistingExperiment, optional):
            comet.ml experiment to log to. Any other value (including None)
            is ignored and `self.exp` stays None. Defaults to None.
        verbose (int, optional): printing level to debug. Defaults to 0.
        device (torch.device, optional): device to use. Defaults to "cuda:0"
            when CUDA is available, "cpu" otherwise.

    Raises:
        ValueError: if `opts.train.amp` is set while either optimizer is
            ExtraAdam
    """
    super().__init__()

    self.opts = opts
    self.verbose = verbose
    self.logger = Logger(self)

    self.losses = None
    self.G = self.D = None
    self.real_val_fid_stats = None
    self.use_pl4m = False
    self.is_setup = False
    self.loaders = self.all_loaders = None
    self.exp = None

    self.current_mode = "train"
    self.diff_transforms = None
    self.kitti_pretrain = self.opts.train.kitti.pretrain
    self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)

    self.lr_names = {}
    self.base_display_images = {}
    self.kitty_display_images = {}
    self.domain_labels = {"s": 0, "r": 1}

    self.device = device or torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    # `resume_from_path` may hand over an ExistingExperiment: accept it
    # explicitly rather than relying on it being an Experiment subclass,
    # otherwise a resumed run would silently stop logging to comet.
    if isinstance(comet_exp, (Experiment, ExistingExperiment)):
        self.exp = comet_exp

    if self.opts.train.amp:
        optimizers = [
            self.opts.gen.opt.optimizer.lower(),
            self.opts.dis.opt.optimizer.lower(),
        ]
        if "extraadam" in optimizers:
            raise ValueError(
                "AMP does not work with ExtraAdam ({})".format(optimizers)
            )
        # one gradient scaler per optimizer (see update_G / update_D)
        self.grad_scaler_d = GradScaler()
        self.grad_scaler_g = GradScaler()

    # -------------------------------
    # -----  Legacy Overwrites  -----
    # -------------------------------
    # either depth fusion option implies `use_dada`
    if (
        self.opts.gen.s.depth_feat_fusion is True
        or self.opts.gen.s.depth_dada_fusion is True
    ):
        self.opts.gen.s.use_dada = True
@torch.no_grad()
def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
    """
    Paints a batch of images (or a single image with a batch dim of 1). If
    masks are not provided, they are inferred from the masker.
    Resolution can either be the train-time resolution or the closest
    multiple of 2 ** spade_n_up

    Operations performed without gradient (enforced by `torch.no_grad`).
    If the trainer is in "train" mode it is temporarily switched to eval
    mode; the mode and the painter's latent shape are restored on exit,
    including when an exception is raised.

    If resolution == "approx" then the output image has the shape:
        (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
        eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
    If resolution == "exact" then the output image has the same shape:
        we first process in "approx" mode then upsample bilinear
    If resolution == "basic" image output shape is the train-time's
        (typically 640x640)
    If resolution == "upsample" image is inferred as "basic" and
        then upsampled to original size

    Args:
        image_batch (torch.Tensor): 4D batch of images to flood
        mask_batch (torch.Tensor, optional): Masks for the images.
            Defaults to None (infer with Masker).
        resolution (str, optional): one of "approx", "exact", "basic",
            "upsample". Defaults to "approx".

    Returns:
        torch.Tensor: N x C x H x W where H and W depend on `resolution`
    """
    assert resolution in {"approx", "exact", "basic", "upsample"}

    previous_mode = self.current_mode
    if previous_mode == "train":
        self.eval_mode()

    try:
        if mask_batch is None:
            mask_batch = self.G.mask(x=image_batch)
        else:
            assert len(image_batch) == len(mask_batch)
            assert image_batch.shape[-2:] == mask_batch.shape[-2:]

        if resolution in {"basic", "upsample"}:
            painted = self.G.paint(mask_batch, image_batch)
        else:
            painter = self.G.painter
            # save latent shape
            zh, zw = painter.z_h, painter.z_w
            # adapt latent shape to approximately keep the resolution
            scale = 2**self.opts.gen.p.spade_n_up
            painter.z_h = image_batch.shape[-2] // scale
            painter.z_w = image_batch.shape[-1] // scale
            try:
                painted = self.G.paint(mask_batch, image_batch)
            finally:
                # never leave the painter with a per-call latent shape
                painter.z_h, painter.z_w = zh, zw

        # "exact" and "upsample" both resize back to the input's H x W
        if resolution in {"exact", "upsample"}:
            painted = nn.functional.interpolate(
                painted, size=image_batch.shape[-2:], mode="bilinear"
            )
    finally:
        if previous_mode == "train":
            self.train_mode()

    return painted
| def _p(self, *args, **kwargs): | |
| """ | |
| verbose-dependant print util | |
| """ | |
| if self.verbose > 0: | |
| print(*args, **kwargs) | |
@torch.no_grad()
def infer_all(
    self,
    x,
    numpy=True,
    stores=None,
    bin_value=-1,
    half=False,
    xla=False,
    cloudy=True,
    auto_resize_640=False,
    ignore_event=None,
    return_intermediates=False,
):
    """
    Create a dictionary of events from a numpy or tensor,
    single or batch image data.

    Inference-only: runs under `torch.no_grad` and never modifies the
    caller's `x` in place.

    Args:
        x (np.ndarray | torch.Tensor): 3D (single image) or 4D (batch) data,
            channels first or channels last (3 channels expected)
        numpy (bool, optional): convert outputs with
            `tensor_to_uint8_numpy_image`. If False, tensors are returned
            as produced by the models. Defaults to True.
        stores (dict, optional): dictionary of lists handed to the Timer
            class, looked up by stage name ("all events", "encode", "depth",
            "segmentation", "mask", "wildfire", "smog", "flood", "numpy").
            Missing keys get a throw-away list. Defaults to None (no timing
            kept).
        bin_value (float, optional): forwarded to `compute_flood` to
            binarize (or not) flood masks, and used as the threshold of the
            returned "mask" intermediate. Defaults to -1.
        half (bool, optional): cast the input to half precision.
        xla (bool, optional): call `xm.mark_step()` after each stage.
        cloudy (bool, optional): forwarded to `compute_flood`.
        auto_resize_640 (bool, optional): bilinearly resize the input to
            640 x 640 when it has another size.
        ignore_event (set, optional): events to skip, among "wildfire",
            "smog", "flood". Defaults to None (compute all).
        return_intermediates (bool, optional): also return "mask", "depth"
            and "segmentation".

    Returns:
        dict: event name => data (plus intermediates if requested).
            All values have 4 dimensions: BxHxWxC if numpy else BxCxHxW
    """
    assert self.is_setup
    assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"

    stores = {} if stores is None else stores
    ignore_event = set() if ignore_event is None else ignore_event

    # convert numpy to tensor
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, device=self.device)

    # add batch dimension (out of place: `x` may be the caller's tensor)
    if len(x.shape) == 3:
        x = x.unsqueeze(0)

    # permute channels as second dimension
    if x.shape[1] != 3:
        assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
        x = x.permute(0, 3, 1, 2)

    # send to device
    if x.device != self.device:
        x = x.to(self.device)

    # interpolate to standard input size
    if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
        x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")

    if half:
        x = x.half()

    # adjust painter's latent vector
    self.G.painter.set_latent_shape(x.shape, True)

    events = {}
    with Timer(store=stores.get("all events", [])):
        # encode
        with Timer(store=stores.get("encode", [])):
            z = self.G.encode(x)
            if xla:
                xm.mark_step()

        # predict from masker
        with Timer(store=stores.get("depth", [])):
            depth, z_depth = self.G.decoders["d"](z)
            if xla:
                xm.mark_step()

        with Timer(store=stores.get("segmentation", [])):
            segmentation = self.G.decoders["s"](z, z_depth)
            if xla:
                xm.mark_step()

        with Timer(store=stores.get("mask", [])):
            cond = self.G.make_m_cond(depth, segmentation, x)
            mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
            if xla:
                xm.mark_step()

        # apply events
        if "wildfire" not in ignore_event:
            with Timer(store=stores.get("wildfire", [])):
                events["wildfire"] = self.compute_fire(x, seg_preds=segmentation)

        if "smog" not in ignore_event:
            with Timer(store=stores.get("smog", [])):
                events["smog"] = self.compute_smog(x, d=depth, s=segmentation)

        if "flood" not in ignore_event:
            with Timer(store=stores.get("flood", [])):
                events["flood"] = self.compute_flood(
                    x,
                    m=mask,
                    s=segmentation,
                    cloudy=cloudy,
                    bin_value=bin_value,
                )

    if xla:
        xm.mark_step()

    output_data = {}
    event_order = ("flood", "wildfire", "smog")

    if numpy:
        with Timer(store=stores.get("numpy", [])):
            for name in event_order:
                if name in events:
                    output_data[name] = tensor_to_uint8_numpy_image(events[name])
    else:
        # tensors were previously computed then dropped: return them as-is
        for name in event_order:
            if name in events:
                output_data[name] = events[name]

    if return_intermediates:
        if numpy:
            output_data["mask"] = (
                ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
            )
            output_data["depth"] = tensor_to_uint8_numpy_image(depth)
            output_data["segmentation"] = (
                decode_segmap_merged_labels(segmentation, "r", False)
                .cpu()
                .permute(0, 2, 3, 1)
                .numpy()
                .astype(np.uint8)
            )
        else:
            output_data["mask"] = mask
            output_data["depth"] = depth
            output_data["segmentation"] = segmentation

    return output_data
@classmethod
def resume_from_path(
    cls,
    path,
    overrides=None,
    setup=True,
    inference=False,
    new_exp=False,
    device=None,
    verbose=1,
):
    """
    Resume and optionally setup a trainer from a specific path,
    using the latest opts and checkpoint. Requires path to contain opts.yaml
    (or increased), url.txt (or increased) and checkpoints/

    Args:
        path (str | pathlib.Path): Trainer to resume
        overrides (dict, optional): Override loaded opts with those.
            Defaults to None (no override).
        setup (bool, optional): Whether or not to setup the trainer before
            returning it. Defaults to True.
        inference (bool, optional): Setup should be done in inference mode or not.
            Defaults to False.
        new_exp (bool | None, optional): None: no comet experiment at all.
            True: create a new experiment. Anything else: re-use the existing
            comet experiment found in path. Defaults to False.
        device (torch.device, optional): Device to use
        verbose (int, optional): printing level of the trainer. Defaults to 1.

    Returns:
        climategan.Trainer: Loaded and resumed trainer
    """
    overrides = {} if overrides is None else overrides

    p = resolve(path)
    assert p.exists()

    c = p / "checkpoints"
    assert c.exists() and c.is_dir()

    opts = get_latest_opts(p)
    opts = Dict(merge(overrides, opts))
    opts.train.resume = True

    if new_exp is None:
        exp = None
    elif new_exp is True:
        exp = Experiment(project_name="climategan", **comet_kwargs)
        exp.log_asset_folder(
            str(resolve(Path(__file__)).parent),
            recursive=True,
            log_file_name=True,
        )
        exp.log_parameters(flatten_opts(opts))
    else:
        comet_id = get_existing_comet_id(p)
        exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)

    trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)

    if setup:
        trainer.setup(inference=inference)
    return trainer
def save(self):
    """Write the current training state to `<output_path>/checkpoints`.

    The state always holds the epoch, the global step, G and its optimizer;
    D and its optimizer are added when D exists and has parameters.
    `latest_ckpt.pth` is always (over)written; an additional
    `epoch_<n>_ckpt.pth` is written first when the epoch is at least
    `opts.train.min_save_epoch` and a multiple of `opts.train.save_n_epochs`.
    """
    ckpt_dir = Path(self.opts.output_path) / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)

    # Save at least G
    state = {
        "epoch": self.logger.epoch,
        "G": self.G.state_dict(),
        "g_opt": self.g_opt.state_dict(),
        "step": self.logger.global_step,
    }

    has_trainable_d = self.D is not None and get_num_params(self.D) > 0
    if has_trainable_d:
        state["D"] = self.D.state_dict()
        state["d_opt"] = self.d_opt.state_dict()

    epoch = self.logger.epoch
    keep_epoch_copy = (
        epoch >= self.opts.train.min_save_epoch
        and epoch % self.opts.train.save_n_epochs == 0
    )
    if keep_epoch_copy:
        torch.save(state, ckpt_dir / f"epoch_{epoch}_ckpt.pth")

    torch.save(state, ckpt_dir / "latest_ckpt.pth")
def resume(self, inference=False):
    """Load a checkpoint and restore the trainer's state from it.

    The checkpoint location is resolved from `opts.load_paths` (m, p, pm)
    and `opts.output_path`; a path whose string is "none" counts as not
    provided and a directory is resolved to
    `<dir>/checkpoints/latest_ckpt.pth`.

    * G is always restored (non-strictly in inference mode)
    * when not in inference mode: G's optimizer, the schedulers, D and its
      optimizer (when D has parameters) and the logger's epoch / step are
      restored too

    Args:
        inference (bool, optional): only restore G, tolerating missing or
            unexpected keys. Defaults to False.

    Raises:
        ValueError: if the provided load_paths are inconsistent with
            `opts.tasks`
    """
    tpu = "xla" in str(self.device)
    if tpu:
        print("Resuming on TPU:", self.device)

    m_path = Path(self.opts.load_paths.m)
    p_path = Path(self.opts.load_paths.p)
    pm_path = Path(self.opts.load_paths.pm)
    output_path = Path(self.opts.output_path)

    map_loc = self.device if not tpu else "cpu"

    if "m" in self.opts.tasks and "p" in self.opts.tasks:
        # ----------------------------------------
        # -----  Masker and Painter Loading  -----
        # ----------------------------------------

        # want to resume a pm model but no path was provided:
        # resume a single pm model from output_path
        if all(str(p) == "none" for p in (m_path, p_path, pm_path)):
            checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
            print("Resuming P+M model from", str(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)

        # want to resume a pm model with a pm_path provided:
        # resume a single pm model from load_paths.pm
        # depending on whether a dir or a file is specified
        elif str(pm_path) != "none":
            assert pm_path.exists()

            if pm_path.is_dir():
                checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
            else:
                assert pm_path.suffix == ".pth"
                checkpoint_path = pm_path

            print("Resuming P+M model from", str(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)

        # want to resume a pm model, pm_path not provided:
        # m_path and p_path must be provided as dirs or pth files
        elif m_path != p_path:
            assert m_path.exists()
            assert p_path.exists()

            if m_path.is_dir():
                m_path = m_path / "checkpoints/latest_ckpt.pth"

            if p_path.is_dir():
                p_path = p_path / "checkpoints/latest_ckpt.pth"

            assert m_path.suffix == ".pth"
            assert p_path.suffix == ".pth"

            print(f"Resuming P+M model from \n  -{p_path} \nand \n  -{m_path}")
            m_checkpoint = torch.load(m_path, map_location=map_loc)
            p_checkpoint = torch.load(p_path, map_location=map_loc)
            checkpoint = merge(m_checkpoint, p_checkpoint)

        else:
            raise ValueError(
                "Cannot resume a P+M model with provided load_paths:\n{}".format(
                    self.opts.load_paths
                )
            )

    else:
        # ----------------------------------
        # -----  Single Model Loading  -----
        # ----------------------------------

        # cannot specify both paths
        if str(m_path) != "none" and str(p_path) != "none":
            raise ValueError(
                "Opts tasks are {} but received 2 values for the load_paths".format(
                    self.opts.tasks
                )
            )

        # specified m
        elif str(m_path) != "none":
            assert m_path.exists()
            assert "m" in self.opts.tasks
            model = "M"
            if m_path.is_dir():
                m_path = m_path / "checkpoints/latest_ckpt.pth"
            checkpoint_path = m_path

        # specified p
        elif str(p_path) != "none":
            assert p_path.exists()
            assert "p" in self.opts.tasks
            model = "P"
            if p_path.is_dir():
                p_path = p_path / "checkpoints/latest_ckpt.pth"
            checkpoint_path = p_path

        # specified neither p nor m: resume from output_path
        else:
            model = "P" if "p" in self.opts.tasks else "M"
            checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"

        print(f"Resuming {model} model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=map_loc)

    # On TPUs must send the data to the xla device as it cannot be mapped
    # there directly from torch.load
    if tpu:
        checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)

    # -----------------------
    # -----  Restore G  -----
    # -----------------------
    if inference:
        # only G is needed to infer: load leniently, report mismatches, stop
        incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
        if incompatible_keys.missing_keys:
            print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
            print(incompatible_keys.missing_keys)
        if incompatible_keys.unexpected_keys:
            print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
            print(incompatible_keys.unexpected_keys)
        print("Done loading checkpoints.")
        return

    self.G.load_state_dict(checkpoint["G"])
    self.g_opt.load_state_dict(checkpoint["g_opt"])

    # ------------------------------
    # -----  Resume scheduler  -----
    # ------------------------------

    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
    # NOTE(review): this loop reads self.logger.epoch *before* it is restored
    # from the checkpoint below -- confirm this is the intended epoch count.
    for _ in range(self.logger.epoch + 1):
        self.update_learning_rates()

    # -----------------------
    # -----  Restore D  -----
    # -----------------------
    if self.D is not None and get_num_params(self.D) > 0:
        self.D.load_state_dict(checkpoint["D"])
        self.d_opt.load_state_dict(checkpoint["d_opt"])

    # ----------------------------
    # -----  Restore logger  -----
    # ----------------------------
    self.logger.epoch = checkpoint["epoch"]
    self.logger.global_step = checkpoint["step"]
    # the comet experiment is optional (see __init__): only log when present
    if self.exp is not None:
        self.exp.log_text(
            "Resuming from epoch {} & step {}".format(
                checkpoint["epoch"], checkpoint["step"]
            )
        )
    # Round step to even number for extraGradient
    if self.logger.global_step % 2 != 0:
        self.logger.global_step += 1
| def eval_mode(self): | |
| """ | |
| Set trainer's models in eval mode | |
| """ | |
| if self.G is not None: | |
| self.G.eval() | |
| if self.D is not None: | |
| self.D.eval() | |
| self.current_mode = "eval" | |
| def train_mode(self): | |
| """ | |
| Set trainer's models in train mode | |
| """ | |
| if self.G is not None: | |
| self.G.train() | |
| if self.D is not None: | |
| self.D.train() | |
| self.current_mode = "train" | |
def assert_z_matches_x(self, x, z):
    """Assert that x and z have the same size along their first dimension.

    When z is a list or a tuple, its first element is compared instead.
    """
    z_ref = z[0] if isinstance(z, (list, tuple)) else z
    assert x.shape[0] == z_ref.shape[0], "x-> {}, z->{}".format(
        x.shape, z_ref.shape
    )
| def batch_to_device(self, b): | |
| """sends the data in b to self.device | |
| Args: | |
| b (dict): the batch dictionnay | |
| Returns: | |
| dict: the batch dictionnary with its "data" field sent to self.device | |
| """ | |
| for task, tensor in b["data"].items(): | |
| b["data"][task] = tensor.to(self.device) | |
| return b | |
| def sample_painter_z(self, batch_size): | |
| return self.G.sample_painter_z(batch_size, self.device) | |
| def train_loaders(self): | |
| """Get a zip of all training loaders | |
| Returns: | |
| generator: zip generator yielding tuples: | |
| (batch_rf, batch_rn, batch_sf, batch_sn) | |
| """ | |
| return zip(*list(self.loaders["train"].values())) | |
| def val_loaders(self): | |
| """Get a zip of all validation loaders | |
| Returns: | |
| generator: zip generator yielding tuples: | |
| (batch_rf, batch_rn, batch_sf, batch_sn) | |
| """ | |
| return zip(*list(self.loaders["val"].values())) | |
| def compute_latent_shape(self): | |
| """Compute the latent shape, i.e. the Encoder's output shape, | |
| from a batch. | |
| Raises: | |
| ValueError: If no loader, the latent_shape cannot be inferred | |
| Returns: | |
| tuple: (c, h, w) | |
| """ | |
| x = None | |
| for mode in self.all_loaders: | |
| for domain in self.all_loaders.loaders[mode]: | |
| x = ( | |
| self.all_loaders[mode][domain] | |
| .dataset[0]["data"]["x"] | |
| .to(self.device) | |
| ) | |
| break | |
| if x is not None: | |
| break | |
| if x is None: | |
| raise ValueError("No batch found to compute_latent_shape") | |
| x = x.unsqueeze(0) | |
| z = self.G.encode(x) | |
| return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:] | |
| def g_opt_step(self): | |
| """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation | |
| step every other step | |
| """ | |
| if "extra" in self.opts.gen.opt.optimizer.lower() and ( | |
| self.logger.global_step % 2 == 0 | |
| ): | |
| self.g_opt.extrapolation() | |
| else: | |
| self.g_opt.step() | |
| def d_opt_step(self): | |
| """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation | |
| step every other step | |
| """ | |
| if "extra" in self.opts.dis.opt.optimizer.lower() and ( | |
| self.logger.global_step % 2 == 0 | |
| ): | |
| self.d_opt.extrapolation() | |
| else: | |
| self.d_opt.step() | |
| def update_learning_rates(self): | |
| if self.g_scheduler is not None: | |
| self.g_scheduler.step() | |
| if self.d_scheduler is not None: | |
| self.d_scheduler.step() | |
def setup(self, inference=False):
    """Prepare the trainer before it can be used:

    * resets the logger's global step and start time
    * creates the data loaders (training mode only)
    * creates G; in inference mode, optionally resumes it, switches to
      eval mode and stops there
    * creates D, the optimizers / schedulers and the losses
    * creates the display images, logs the architectures and selects the
      data source ("kitti" when pre-training on it, "base" otherwise)
    * resumes from a checkpoint if `opts.train.resume` is set

    Args:
        inference (bool, optional): only set up what is needed to infer
            (no loaders, no discriminator, no optimizers). Defaults to False.
    """
    self.logger.global_step = 0
    self.logger.time.start_time = time()
    verbosity = self.verbose

    if not inference:
        self.all_loaders = get_all_loaders(self.opts)

    # -----------------------
    # -----  Generator  -----
    # -----------------------
    generator_t0 = time()
    print("Creating generator...")

    self.G: OmniGenerator = create_generator(
        self.opts, device=self.device, no_init=inference, verbose=verbosity
    )

    self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()

    if self.has_painter:
        self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)

    print(f"Generator OK in {time() - generator_t0:.1f}s.")

    # Inference mode: no more than a Generator needed
    if inference:
        print("Inference mode: no Discriminator, no optimizers")
        print_num_parameters(self)
        # NB: switch_data prints its caller's name, keep the call in here
        self.switch_data(to="base")
        if self.opts.train.resume:
            self.resume(True)
        self.eval_mode()
        print("Trainer is in evaluation mode.")
        print("Setup done.")
        self.is_setup = True
        return

    # ---------------------------
    # -----  Discriminator  -----
    # ---------------------------
    self.D: OmniDiscriminator = create_discriminator(
        self.opts, self.device, verbose=verbosity
    )
    print("Discriminator OK.")

    print_num_parameters(self)

    # --------------------------
    # -----  Optimization  -----
    # --------------------------
    # Get different optimizers for each task (different learning rates)
    self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
        self.G, self.opts.gen.opt, self.opts.tasks
    )

    if get_num_params(self.D) > 0:
        self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
            self.D, self.opts.dis.opt, self.opts.tasks, True
        )
    else:
        # nothing to optimize in D
        self.d_opt, self.d_scheduler = None, None

    self.losses = get_losses(self.opts, verbosity, device=self.device)

    if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
        self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)

    if verbosity > 0:
        for mode, mode_dict in self.all_loaders.items():
            for domain, domain_loader in mode_dict.items():
                print(f"Loader {mode} {domain} : {len(domain_loader.dataset)}")

    # ----------------------------
    # -----  Display images  -----
    # ----------------------------
    self.set_display_images()

    # -------------------------------
    # -----  Log Architectures  -----
    # -------------------------------
    self.logger.log_architecture()

    # -----------------------------
    # -----  Set data source  -----
    # -----------------------------
    self.switch_data(to="kitti" if self.kitti_pretrain else "base")

    # -------------------------
    # -----  Setup Done.  -----
    # -------------------------
    # blank out the progress line left by set_display_images
    print(" " * 50, end="\r")
    print("Done creating display images")

    if not self.opts.train.resume:
        print("Not resuming: starting a new model")
    else:
        print("Resuming Model (inference: False)")
        self.resume(False)

    print("Setup done.")
    self.is_setup = True
| def switch_data(self, to="kitti"): | |
| caller = inspect.stack()[1].function | |
| print(f"[{caller}] Switching data source to", to) | |
| self.data_source = to | |
| if to == "kitti": | |
| self.display_images = self.kitty_display_images | |
| if self.all_loaders is not None: | |
| self.loaders = { | |
| mode: {"s": self.all_loaders[mode]["kitti"]} | |
| for mode in self.all_loaders | |
| } | |
| else: | |
| self.display_images = self.base_display_images | |
| if self.all_loaders is not None: | |
| self.loaders = { | |
| mode: { | |
| domain: self.all_loaders[mode][domain] | |
| for domain in self.all_loaders[mode] | |
| if domain != "kitti" | |
| } | |
| for mode in self.all_loaders | |
| } | |
| if ( | |
| self.logger.global_step % 2 != 0 | |
| and "extra" in self.opts.dis.opt.optimizer.lower() | |
| ): | |
| print( | |
| "Warning: artificially bumping step to run an extrapolation step first." | |
| ) | |
| self.logger.global_step += 1 | |
def set_display_images(self, use_all=False):
    """Select and store the samples used as display images.

    For every mode and domain of `self.all_loaders`, samples are read from
    the loader's dataset and stored as addict Dicts in
    `self.base_display_images[mode][domain]`, or in
    `self.kitty_display_images[mode][domain]` for the "kitti" domain when
    pre-training on kitti. The "kitti" domain is skipped otherwise.
    When a comet experiment is set, each sample's "paths" is logged as a
    parameter.

    Args:
        use_all (bool, optional): use every sample of the dataset instead of
            the indices given by `get_display_indices`. Defaults to False.
    """
    for mode, mode_dict in self.all_loaders.items():

        if self.kitti_pretrain:
            self.kitty_display_images[mode] = {}
        self.base_display_images[mode] = {}

        for domain in mode_dict:
            is_kitti = domain == "kitti"
            if is_kitti and not self.kitti_pretrain:
                continue
            target_dict = (
                self.kitty_display_images if is_kitti else self.base_display_images
            )

            dataset = self.all_loaders[mode][domain].dataset
            if use_all:
                display_indices = list(range(len(dataset)))
            else:
                display_indices = get_display_indices(
                    self.opts, domain, len(dataset)
                )

            ldis = len(display_indices)
            print(
                f"       Creating {ldis} {mode} {domain} display images...",
                end="\r",
                flush=True,
            )

            samples = []
            for i in display_indices:
                # progress is printed for every index, even out-of-range ones
                print(f"({i})", end="\r")
                if i < len(dataset):
                    samples.append(Dict(dataset[i]))
            target_dict[mode][domain] = samples

            if self.exp is not None:
                for im_id, d in enumerate(target_dict[mode][domain]):
                    self.exp.log_parameter(
                        "display_image_{}_{}_{}".format(mode, domain, im_id),
                        d["paths"],
                    )
def train(self):
    """Train for `opts.train.epochs` epochs, starting at the logger's
    current epoch. For each epoch:

    * enable pl4m when its starting epoch is reached
    * train
    * eval
    * save
    * leave kitti pre-training / pseudo training when their last epoch
      is done
    """
    assert self.is_setup

    first_epoch = self.logger.epoch
    last_epoch = first_epoch + self.opts.train.epochs

    for epoch in range(first_epoch, last_epoch):
        # the logger's epoch is the loop counter
        self.logger.epoch = epoch

        # backprop painter's disc loss to masker
        enable_pl4m = (
            self.logger.epoch == self.opts.gen.p.pl4m_epoch
            and get_num_params(self.G.painter) > 0
            and "p" in self.opts.tasks
            and self.opts.gen.m.use_pl4m
        )
        if enable_pl4m:
            print(
                "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
            )
            self.use_pl4m = True

        self.run_epoch()
        self.run_evaluation(verbose=1)
        self.save()

        # end vkitti2 pre-training
        # NB: switch_data prints its caller's name, keep the call in here
        kitti_is_over = self.logger.epoch == self.opts.train.kitti.epochs - 1
        if kitti_is_over:
            self.switch_data(to="base")
            self.kitti_pretrain = False

        # end pseudo training
        pseudo_is_over = self.logger.epoch == self.opts.train.pseudo.epochs - 1
        if pseudo_is_over:
            self.pseudo_training_tasks = set()
def run_epoch(self):
    """Runs an epoch:

    * checks trainer is setup and puts it in train mode
    * gets a tuple of batches per domain
    * sends batches to device
    * updates sequentially G (with D frozen), then D (unfrozen) unless
      pre-training on kitti or there is no D optimizer
    * steps the learning rates (not during kitti pre-training) and logs
      them along with the epoch's time
    """
    assert self.is_setup
    self.train_mode()

    if self.exp is not None:
        self.exp.log_parameter("epoch", self.logger.epoch)

    # the zip of loaders stops with the shortest loader
    n_batches = min(len(loader) for loader in self.loaders["train"].values())
    description = "Epoch {}".format(self.logger.epoch)
    self.logger.time.epoch_start = time()

    for multi_batch_tuple in tqdm(
        self.train_loaders,
        desc=description,
        total=n_batches,
        mininterval=0.5,
        unit="batch",
    ):
        self.logger.time.step_start = time()
        shuffled_batches = shuffle_batch_tuple(multi_batch_tuple)

        # The `[0]` is because the domain is contained in a list
        multi_domain_batch = {}
        for batch in shuffled_batches:
            multi_domain_batch[batch["domain"][0]] = self.batch_to_device(batch)

        # ------------------------------
        # -----  Update Generator  -----
        # ------------------------------

        # freeze params of the discriminator
        if self.d_opt is not None:
            for d_param in self.D.parameters():
                d_param.requires_grad = False

        self.update_G(multi_domain_batch)

        # ----------------------------------
        # -----  Update Discriminator  -----
        # ----------------------------------

        # unfreeze params of the discriminator
        if not self.kitti_pretrain and self.d_opt is not None:
            for d_param in self.D.parameters():
                d_param.requires_grad = True

            self.update_D(multi_domain_batch)

        # -------------------------
        # -----  Log Metrics  -----
        # -------------------------
        self.logger.global_step += 1
        self.logger.log_step_time(time())

    if not self.kitti_pretrain:
        self.update_learning_rates()

    self.logger.log_learning_rates()
    self.logger.log_epoch_time(time())
def update_G(self, multi_domain_batch, verbose=0):
    """Perform an update on G from multi_domain_batch, a dictionary
    domain => batch.

    * zeroes G's gradients
    * computes the loss with self.get_G_loss (under autocast when
      self.opts.train.amp is set)
    * backward pass and optimizer step: through self.grad_scaler_g with
      AMP, through g_opt_step() (step or extrapolation depending on
      self.logger.global_step) otherwise
    * logs losses with self.logger.log_losses(model_to_update="G")

    Args:
        multi_domain_batch (dict): dictionary of domain batches
        verbose (int, optional): forwarded to get_G_loss. Defaults to 0.
    """
    zero_grad(self.G)

    if not self.opts.train.amp:
        g_loss = self.get_G_loss(multi_domain_batch, verbose)
        g_loss.backward()
        self.g_opt_step()
    else:
        with autocast():
            g_loss = self.get_G_loss(multi_domain_batch, verbose)
        scaler = self.grad_scaler_g
        scaler.scale(g_loss).backward()
        scaler.step(self.g_opt)
        scaler.update()

    self.logger.log_losses(model_to_update="G", mode="train")
def update_D(self, multi_domain_batch, verbose=0):
    """Perform an update on D from multi_domain_batch, a dictionary
    domain => batch.

    * zeroes D's gradients
    * computes the loss with self.get_D_loss (under autocast when
      self.opts.train.amp is set)
    * backward pass and optimizer step: through self.grad_scaler_d with
      AMP, through d_opt_step() otherwise
    * stores the total loss in the logger and logs the losses

    Args:
        multi_domain_batch (dict): dictionary of domain batches
        verbose (int, optional): forwarded to get_D_loss. Defaults to 0.
    """
    zero_grad(self.D)

    if not self.opts.train.amp:
        d_loss = self.get_D_loss(multi_domain_batch, verbose)
        d_loss.backward()
        self.d_opt_step()
    else:
        with autocast():
            d_loss = self.get_D_loss(multi_domain_batch, verbose)
        scaler = self.grad_scaler_d
        scaler.scale(d_loss).backward()
        scaler.step(self.d_opt)
        scaler.update()

    self.logger.losses.disc.total_loss = d_loss.item()
    self.logger.log_losses(model_to_update="D", mode="train")
def get_D_loss(self, multi_domain_batch, verbose=0):
    """Compute the discriminators' losses on a multi-domain batch.

    For each (domain, batch) pair:
        * "rf" domain, when the trainer has a painter:
            * paint a fake image from the batch's mask and image without
              tracking gradients, optionally applying self.diff_transforms
              to both the fake and the real image
            * compute the painter discriminator loss on real and fake
              data, either with the global + local discriminators or with
              a single discriminator fed with [mask, image] concatenations
        * any other domain:
            * encode the image
            * if the batch has "s" data: AdvEnt discriminator loss of the
              segmentation decoder (self.masker_s_loss with for_="D")
            * if the batch has "m" data: AdvEnt discriminator loss of the
              mask decoder (self.masker_m_loss with for_="D"), using the
              depth prediction / SPADE conditioning when enabled in opts

    Per-discriminator values are stored in self.logger.losses.disc.

    Args:
        multi_domain_batch (dict): maps a domain name to its batch; each
            batch holds its tensors under batch["data"]
        verbose (int, optional): not used by this method. Defaults to 0.

    Returns:
        The sum of every entry of the per-discriminator loss dictionary
        (the int 0 if no loss was computed at all).
    """
    # Accumulators, one sub-dict per discriminator family
    disc_loss = {
        "m": {"Advent": 0},
        "s": {"Advent": 0},
    }
    if self.opts.dis.p.use_local_discriminator:
        disc_loss["p"] = {"global": 0, "local": 0}
    else:
        disc_loss["p"] = {"gan": 0}

    for domain, batch in multi_domain_batch.items():
        x = batch["data"]["x"]

        # ---------------------
        # -----  Painter  -----
        # ---------------------
        if domain == "rf" and self.has_painter:
            m = batch["data"]["m"]
            # sample vector
            with torch.no_grad():
                # see spade compute_discriminator_loss
                fake = self.G.paint(m, x)
                if self.opts.gen.p.diff_aug.use:
                    # x is rebound: the real image is augmented as well
                    fake = self.diff_transforms(fake)
                    x = self.diff_transforms(x)
                fake = fake.detach()
                fake.requires_grad_()

            if self.opts.dis.p.use_local_discriminator:
                # global discriminator: whole images
                fake_d_global = self.D["p"]["global"](fake)
                real_d_global = self.D["p"]["global"](x)

                # local discriminator: images multiplied by the mask
                fake_d_local = self.D["p"]["local"](fake * m)
                real_d_local = self.D["p"]["local"](x * m)

                global_loss = self.losses["D"]["p"](fake_d_global, False, True)
                global_loss += self.losses["D"]["p"](real_d_global, True, True)

                local_loss = self.losses["D"]["p"](fake_d_local, False, True)
                local_loss += self.losses["D"]["p"](real_d_local, True, True)

                disc_loss["p"]["global"] += global_loss
                disc_loss["p"]["local"] += local_loss
            else:
                # single discriminator: real and fake [mask, image] pairs
                # are stacked along the batch dimension, then split back
                real_cat = torch.cat([m, x], axis=1)
                fake_cat = torch.cat([m, fake], axis=1)
                real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
                real_fake_d = self.D["p"](real_fake_cat)
                real_d, fake_d = divide_pred(real_fake_d)
                # NOTE(review): plain assignment (not +=), unlike the
                # global/local accumulators above
                disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
                disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)

        # --------------------
        # -----  Masker  -----
        # --------------------
        else:
            z = self.G.encode(x)
            s_pred = d_pred = cond = z_depth = None

            if "s" in batch["data"]:
                if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
                    d_pred, z_depth = self.G.decoders["d"](z)

                step_loss, s_pred = self.masker_s_loss(
                    x, z, d_pred, z_depth, None, domain, for_="D"
                )
                # NOTE(review): masker_s_loss(for_="D") already multiplies
                # its AdvEnt loss by lambdas.advent.adv_main, so it looks
                # like the weight is applied twice here -- confirm
                step_loss *= self.opts.train.lambdas.advent.adv_main
                disc_loss["s"]["Advent"] += step_loss

            if "m" in batch["data"]:
                if "d" in self.opts.tasks:
                    # depth is predicted only once and shared with the
                    # segmentation step above when already available
                    if self.opts.gen.m.use_spade:
                        if d_pred is None:
                            d_pred, z_depth = self.G.decoders["d"](z)
                        cond = self.G.make_m_cond(d_pred, s_pred, x)
                    elif self.opts.gen.m.use_dada:
                        if d_pred is None:
                            d_pred, z_depth = self.G.decoders["d"](z)

                step_loss, _ = self.masker_m_loss(
                    x,
                    z,
                    None,
                    domain,
                    for_="D",
                    cond=cond,
                    z_depth=z_depth,
                    depth_preds=d_pred,
                )
                # NOTE(review): same double adv_main weighting as above
                step_loss *= self.opts.train.lambdas.advent.adv_main
                disc_loss["m"]["Advent"] += step_loss

    # Log python floats: tensors are converted, untouched 0s are kept as is
    self.logger.losses.disc.update(
        {
            dom: {
                k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in d.items()
            }
            for dom, d in disc_loss.items()
        }
    )

    loss = sum(v for d in disc_loss.values() for k, v in d.items())
    return loss
def get_G_loss(self, multi_domain_batch, verbose=0):
    """Compute the generator's total loss: masker loss + painter loss.

    * the masker loss is computed when any of the "m", "s" or "d" tasks
      is in self.opts.tasks
    * the painter loss is computed when "p" is in self.opts.tasks and the
      trainer is not in kitti pre-training mode

    Each part and the total are stored in self.logger.losses.gen.

    Args:
        multi_domain_batch (dict): dictionary mapping domains to batches
        verbose (int, optional): not used by this method. Defaults to 0.

    Returns:
        torch.Tensor: the total generator loss
    """
    tasks = self.opts.tasks
    total = 0

    # For now, always compute "representation loss"
    if "m" in tasks or "s" in tasks or "d" in tasks:
        masker_part = self.get_masker_loss(multi_domain_batch)
        self.logger.losses.gen.masker = masker_part.item()
        total += masker_part

    if "p" in self.opts.tasks and not self.kitti_pretrain:
        painter_part = self.get_painter_loss(multi_domain_batch)
        self.logger.losses.gen.painter = painter_part.item()
        total += painter_part

    # At least one of the two parts must have contributed
    assert total != 0 and not isinstance(total, int), "No update in get_G_loss!"

    self.logger.losses.gen.total_loss = total.item()
    return total
def get_masker_loss(self, multi_domain_batch):
    """Sum the masker's task losses (depth, segmentation, mask) over all
    domains except the flooded one ("rf").

    For each remaining domain the image is encoded once, then the tasks
    present in the batch's data are processed in the fixed order
    "d" -> "s" -> "m" so that:
        * the segmentation loss receives the depth prediction and depth
          features
        * the mask loss receives the depth prediction / features and,
          when SPADE is used, a conditioning tensor built from the depth
          and segmentation predictions

    Every task loss is also stored in
    self.logger.losses.gen.task[task][domain].

    Args:
        multi_domain_batch (dict): dictionary mapping domain names to
            batches from the trainer's loaders

    Returns:
        torch.Tensor: sum of the task losses (the int 0 if nothing was
            computed)
    """
    total = 0
    for domain, batch in multi_domain_batch.items():
        # We don't care about the flooded domain here
        if domain == "rf":
            continue

        data = batch["data"]
        x = data["x"]
        z = self.G.encode(x)

        # --------------------------------------
        # -----  task-specific losses (2)  -----
        # --------------------------------------
        d_pred = s_pred = z_depth = None

        if "d" in data:
            depth_target = data["d"]
            loss, d_pred, z_depth = self.masker_d_loss(
                x, z, depth_target, domain, "G"
            )
            total += loss
            self.logger.losses.gen.task["d"][domain] = loss.item()

        if "s" in data:
            seg_target = data["s"]
            loss, s_pred = self.masker_s_loss(
                x, z, d_pred, z_depth, seg_target, domain, "G"
            )
            total += loss
            self.logger.losses.gen.task["s"][domain] = loss.item()

        if "m" in data:
            mask_target = data["m"]
            cond = None
            if self.opts.gen.m.use_spade:
                if not self.opts.gen.m.detach:
                    d_pred = d_pred.clone()
                    s_pred = s_pred.clone()
                cond = self.G.make_m_cond(d_pred, s_pred, x)

            loss, _ = self.masker_m_loss(
                x,
                z,
                mask_target,
                domain,
                "G",
                cond=cond,
                z_depth=z_depth,
                depth_preds=d_pred,
            )
            total += loss
            self.logger.losses.gen.task["m"][domain] = loss.item()

    return total
def get_painter_loss(self, multi_domain_batch):
    """Compute the painter's loss on the flooded ("rf") domain batch.

    The painter paints a fake flooded image from the batch's mask and
    image, then the following terms are accumulated, each being skipped
    when its lambda is 0:
        * VGG loss between the masked fake and masked real images
        * TV loss on the masked fake image
        * context loss and reconstruction loss on (fake, real, mask)
    followed by the GAN loss (global + local discriminators, or a single
    discriminator) and, when the discriminator returns intermediate
    features, a feature matching loss.

    Every term is also stored in self.logger.losses.gen.p.

    Args:
        multi_domain_batch (dict): dictionary mapping domain names to
            batches from the trainer's loaders; must contain "rf"

    Returns:
        torch.Tensor: sum of the painter loss terms
    """
    step_loss = 0
    # self.g_opt.zero_grad()
    lambdas = self.opts.train.lambdas
    batch_domain = "rf"
    batch = multi_domain_batch[batch_domain]

    x = batch["data"]["x"]
    # ! different mask: hides water to be reconstructed
    # ! 1 for water, 0 otherwise
    m = batch["data"]["m"]
    fake_flooded = self.G.paint(m, x)

    # ----------------------
    # -----  VGG Loss  -----
    # ----------------------
    if lambdas.G.p.vgg != 0:
        loss = self.losses["G"]["p"]["vgg"](
            vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
        )
        loss *= lambdas.G.p.vgg
        self.logger.losses.gen.p.vgg = loss.item()
        step_loss += loss

    # ---------------------
    # -----  TV Loss  -----
    # ---------------------
    if lambdas.G.p.tv != 0:
        loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
        loss *= lambdas.G.p.tv
        self.logger.losses.gen.p.tv = loss.item()
        step_loss += loss

    # --------------------------
    # -----  Context Loss  -----
    # --------------------------
    if lambdas.G.p.context != 0:
        loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
        loss *= lambdas.G.p.context
        self.logger.losses.gen.p.context = loss.item()
        step_loss += loss

    # ---------------------------------
    # -----  Reconstruction Loss  -----
    # ---------------------------------
    if lambdas.G.p.reconstruction != 0:
        loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
        loss *= lambdas.G.p.reconstruction
        self.logger.losses.gen.p.reconstruction = loss.item()
        step_loss += loss

    # -------------------------------------
    # -----  Local & Global GAN Loss  -----
    # -------------------------------------
    # From here on, the discriminators see the augmented versions of both
    # the fake and the real image (x and fake_flooded are rebound)
    if self.opts.gen.p.diff_aug.use:
        fake_flooded = self.diff_transforms(fake_flooded)
        x = self.diff_transforms(x)

    if self.opts.dis.p.use_local_discriminator:
        fake_d_global = self.D["p"]["global"](fake_flooded)
        fake_d_local = self.D["p"]["local"](fake_flooded * m)

        real_d_global = self.D["p"]["global"](x)

        # Note: discriminator returns [out_1,...,out_num_D] outputs
        # Each out_i is a list [feat1, feat2, ..., pred_i]

        self.logger.losses.gen.p.gan = 0

        loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
        loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
        loss *= lambdas.G["p"]["gan"]

        self.logger.losses.gen.p.gan = loss.item()

        step_loss += loss

        # -----------------------------------
        # -----  Feature Matching Loss  -----
        # -----------------------------------
        # (only on global discriminator)
        # Order must be real, fake
        # NOTE(review): unlike the single-discriminator branch below, this
        # is not guarded by `lambdas.G.p.featmatch != 0` -- confirm
        if self.opts.dis.p.get_intermediate_features:
            loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
            loss *= lambdas.G["p"]["featmatch"]

            # the loss may come back as a python float rather than a tensor
            if isinstance(loss, float):
                self.logger.losses.gen.p.featmatch = loss
            else:
                self.logger.losses.gen.p.featmatch = loss.item()

            step_loss += loss

    # -------------------------------------------
    # -----  Single Discriminator GAN Loss  -----
    # -------------------------------------------
    else:
        # real and fake [mask, image] pairs are stacked along the batch
        # dimension, then the discriminator's output is split back
        real_cat = torch.cat([m, x], axis=1)
        fake_cat = torch.cat([m, fake_flooded], axis=1)
        real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

        real_fake_d = self.D["p"](real_fake_cat)
        real_d, fake_d = divide_pred(real_fake_d)

        # NOTE(review): unlike the local-discriminator branch above, this
        # GAN loss is not scaled by lambdas.G.p.gan -- confirm it is
        # intended
        loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
        self.logger.losses.gen.p.gan = loss.item()
        step_loss += loss

        # -----------------------------------
        # -----  Feature Matching Loss  -----
        # -----------------------------------
        if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
            loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
            loss *= lambdas.G.p.featmatch

            # the loss may come back as a python float rather than a tensor
            if isinstance(loss, float):
                self.logger.losses.gen.p.featmatch = loss
            else:
                self.logger.losses.gen.p.featmatch = loss.item()

            step_loss += loss

    return step_loss
def masker_d_loss(self, x, z, target, domain, for_="G"):
    """Predict depth from the latent z and compute the weighted depth loss.

    The loss is always computed, but a zero tensor is returned in its
    place when the depth lambda is 0, or when the domain is "r" and "d"
    is not one of self.pseudo_training_tasks.

    Args:
        x (torch.Tensor): input images, used for sanity checks only
        z: encoder output for x
        target (torch.Tensor): depth target, squeezed IN PLACE on dim 1
            when depth is handled as a classification task
        domain (str): domain of the batch
        for_ (str, optional): "G" or "D". Defaults to "G".

    Returns:
        tuple: (loss, depth prediction, depth features z_depth)
    """
    assert for_ in {"G", "D"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0]

    null_loss = torch.tensor(0.0, device=self.device)
    lambda_d = self.opts.train.lambdas.G.d.main

    decoded = self.G.decoders["d"](z)
    prediction, z_depth = decoded

    if self.opts.gen.d.classify.enable:
        target.squeeze_(1)

    weighted_loss = self.losses["G"]["tasks"]["d"](prediction, target)
    weighted_loss *= lambda_d

    ignore_loss = lambda_d == 0 or (
        domain == "r" and "d" not in self.pseudo_training_tasks
    )
    loss = null_loss if ignore_loss else weighted_loss
    return loss, prediction, z_depth
def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
    """Compute the segmentation decoder's losses, for the generator or for
    its AdvEnt discriminator.

    When for_ == "G":
        * cross-entropy on the "s" domain, or pseudo-label cross-entropy
          on the "r" domain when "s" is in self.pseudo_training_tasks
        * entropy minimization on the "r" domain (opts.gen.s.use_minent)
        * AdvEnt loss on the "r" domain (opts.gen.s.use_advent)
    When for_ == "D":
        * AdvEnt discriminator loss on detached predictions, followed by
          the regularisation selected by opts.dis.s.gan_type: nothing for
          "GAN" / "WGAN_norm", weight clipping for "WGAN", gradient
          penalty for "WGAN_gp"

    Args:
        x (torch.Tensor): input images, used for sanity checks only
        z: encoder output for x
        depth_preds (torch.Tensor | None): depth predictions, only given
            to the AdvEnt loss (detached) when opts.gen.s.use_dada is set
        z_depth: depth features forwarded to the segmentation decoder
        target (torch.Tensor | None): segmentation labels, None for D
        domain (str): "r" or "s"
        for_ (str, optional): "G" or "D". Defaults to "G".

    Raises:
        NotImplementedError: for_ == "D" with an unknown
            opts.dis.s.gan_type

    Returns:
        tuple: (loss, segmentation prediction or None if the decoder was
            not run)
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)
    softmax_preds = None
    # --------------------------
    # -----  Segmentation  -----
    # --------------------------
    pred = None
    if for_ == "G" or self.opts.gen.s.use_advent:
        pred = self.G.decoders["s"](z, z_depth)

    # Supervised segmentation loss: crossent for sim domain,
    # crossent_pseudo for real ; loss is crossent in any case
    if for_ == "G":
        if domain == "s" or "s" in self.pseudo_training_tasks:
            if domain == "s":
                logger = self.logger.losses.gen.task["s"]["crossent"]
                weight = self.opts.train.lambdas.G["s"]["crossent"]
            else:
                logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
                weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]

            if weight != 0:
                # Cross-Entropy loss
                loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
                loss = loss_func(pred, target.squeeze(1))
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

        if domain == "r":
            weight = self.opts.train.lambdas.G["s"]["minent"]
            if self.opts.gen.s.use_minent and weight != 0:
                softmax_preds = softmax(pred, dim=1)
                # Entropy minimization loss
                loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
                loss *= weight
                full_loss += loss

                self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()

    # Fool ADVENT discriminator
    if self.opts.gen.s.use_advent:
        if self.opts.gen.s.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
        else:
            depth_preds = None

        if for_ == "D":
            # D is trained to recognise the true domain, on detached preds
            domain_label = domain
            logger = {}
            loss_func = self.losses["D"]["advent"]
            pred = pred.detach()
            weight = self.opts.train.lambdas.advent.adv_main
        else:
            # G is trained to make real predictions look like sim ones
            domain_label = "s"
            logger = self.logger.losses.gen.task["s"]["advent"]
            loss_func = self.losses["G"]["tasks"]["s"]["advent"]
            weight = self.opts.train.lambdas.G["s"]["advent"]

        if (for_ == "D" or domain == "r") and weight != 0:
            if softmax_preds is None:
                softmax_preds = softmax(pred, dim=1)
            loss = loss_func(
                softmax_preds,
                self.domain_labels[domain_label],
                self.D["s"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

            if for_ == "D":
                # WGAN: clipping or GP
                # The previous test, `gan_type == "GAN" or "WGAN_norm"`,
                # was always truthy (a non-empty string is truthy), which
                # made the WGAN branches below unreachable.
                gan_type = self.opts.dis.s.gan_type
                if gan_type == "WGAN":
                    for p in self.D["s"]["Advent"].parameters():
                        p.data.clamp_(
                            self.opts.dis.s.wgan_clamp_lower,
                            self.opts.dis.s.wgan_clamp_upper,
                        )
                elif gan_type == "WGAN_gp":
                    prob_need_grad = autograd.Variable(pred, requires_grad=True)
                    d_out = self.D["s"]["Advent"](prob_need_grad)
                    gp = get_WGAN_gradient(prob_need_grad, d_out)
                    gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
                    full_loss += gp_loss
                elif gan_type not in ("GAN", "WGAN_norm"):
                    raise NotImplementedError

    return full_loss, pred
def masker_m_loss(
    self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
):
    """Compute the mask decoder's losses, for the generator or for its
    AdvEnt discriminator.

    The mask logits are decoded from z, turned into probabilities with a
    sigmoid and stacked with their complement into a 2-channel `prob`.

    When for_ == "G":
        * TV loss on the mask probabilities (any domain)
        * BCE loss against the target on the "s" domain
        * on the "r" domain: ground intersection loss, painter loss for
          the masker (pl4m), entropy minimization and AdvEnt loss,
          each depending on its option / lambda
    When for_ == "D":
        * AdvEnt discriminator loss on the detached `prob`, followed by
          the regularisation selected by opts.dis.m.gan_type: nothing for
          "GAN" / "WGAN_norm", weight clipping for "WGAN", gradient
          penalty for "WGAN_gp"

    Args:
        x (torch.Tensor): input images
        z: encoder output for x
        target (torch.Tensor | None): mask target, None for D
        domain (str): "r" or "s"
        for_ (str, optional): "G" or "D". Defaults to "G".
        cond (optional): conditioning forwarded to the mask decoder
        z_depth (optional): depth features forwarded to the mask decoder
        depth_preds (torch.Tensor | None): depth predictions, only given
            to the AdvEnt loss (detached and resized to x's spatial size)
            when opts.gen.m.use_dada is set

    Raises:
        NotImplementedError: for_ == "D" with an unknown
            opts.dis.m.gan_type

    Returns:
        tuple: (loss, 2-channel mask probabilities; detached when
            for_ == "D" and AdvEnt is used)
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)

    pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
    pred_prob = sigmoid(pred_logits)
    pred_prob_complementary = 1 - pred_prob
    prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)

    if for_ == "G":
        # TV loss
        weight = self.opts.train.lambdas.G.m.tv
        if weight != 0:
            loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
            loss *= weight
            full_loss += loss

            self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()

        weight = self.opts.train.lambdas.G.m.bce
        if domain == "s" and weight != 0:
            # CrossEnt Loss
            loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
            loss *= weight
            full_loss += loss
            self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()

        if domain == "r":
            weight = self.opts.train.lambdas.G["m"]["gi"]
            if self.opts.gen.m.use_ground_intersection and weight != 0:
                # GroundIntersection loss
                loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()

            weight = self.opts.train.lambdas.G.m.pl4m
            if self.use_pl4m and weight != 0:
                # Painter loss
                pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
                pl4m_loss *= weight
                full_loss += pl4m_loss
                self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()

            weight = self.opts.train.lambdas.advent.ent_main
            if self.opts.gen.m.use_minent and weight != 0:
                # MinEnt loss
                loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()

    if self.opts.gen.m.use_advent:
        # AdvEnt loss
        if self.opts.gen.m.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
            depth_preds = torch.nn.functional.interpolate(
                depth_preds, size=x.shape[-2:], mode="nearest"
            )
        else:
            depth_preds = None

        if for_ == "D":
            # D is trained to recognise the true domain, on detached probs
            domain_label = domain
            logger = {}
            loss_func = self.losses["D"]["advent"]
            prob = prob.detach()
            weight = self.opts.train.lambdas.advent.adv_main
        else:
            # G is trained to make real predictions look like sim ones
            domain_label = "s"
            logger = self.logger.losses.gen.task["m"]["advent"]
            loss_func = self.losses["G"]["tasks"]["m"]["advent"]
            weight = self.opts.train.lambdas.advent.adv_main

        if (for_ == "D" or domain == "r") and weight != 0:
            loss = loss_func(
                prob.to(self.device),
                self.domain_labels[domain_label],
                self.D["m"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

            if for_ == "D":
                # WGAN: clipping or GP
                # Two fixes compared to the previous version:
                #  * `gan_type == "GAN" or "WGAN_norm"` was always truthy,
                #    making the branches below unreachable
                #  * those branches targeted self.D["s"]["Advent"] while
                #    the loss above is computed with the mask
                #    discriminator self.D["m"]["Advent"]
                gan_type = self.opts.dis.m.gan_type
                if gan_type == "WGAN":
                    for p in self.D["m"]["Advent"].parameters():
                        p.data.clamp_(
                            self.opts.dis.m.wgan_clamp_lower,
                            self.opts.dis.m.wgan_clamp_upper,
                        )
                elif gan_type == "WGAN_gp":
                    prob_need_grad = autograd.Variable(prob, requires_grad=True)
                    d_out = self.D["m"]["Advent"](prob_need_grad)
                    gp = get_WGAN_gradient(prob_need_grad, d_out)
                    gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
                    full_loss += gp_loss
                elif gan_type not in ("GAN", "WGAN_norm"):
                    raise NotImplementedError

    return full_loss, prob
def painter_loss_for_masker(self, x, m):
    """Painter GAN loss used to train the masker (pl4m).

    The painter's parameters are frozen, the image is painted with the
    given mask and the painter's generator GAN loss is computed on the
    result. The painter's parameters are made trainable again only when
    "p" is one of self.opts.tasks.

    Args:
        x (torch.Tensor): input image
        m (torch.Tensor): mask given to the painter

    Returns:
        The GAN loss of the painted image
    """
    # painter should not be updated
    for painter_param in self.G.painter.parameters():
        painter_param.requires_grad = False
    # TODO for param in self.D.painter.parameters():
    #    param.requires_grad = False

    painted = self.G.paint(m, x)

    if not self.opts.dis.p.use_local_discriminator:
        # Single discriminator on stacked real / fake [mask, image] pairs
        real_pair = torch.cat([m, x], axis=1)
        fake_pair = torch.cat([m, painted], axis=1)
        stacked = torch.cat([real_pair, fake_pair], dim=0)
        stacked_d = self.D["p"](stacked)
        _, fake_d = divide_pred(stacked_d)
        pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
    else:
        fake_d_global = self.D["p"]["global"](painted)
        fake_d_local = self.D["p"]["local"](painted * m)

        # Note: discriminator returns [out_1,...,out_num_D] outputs
        # Each out_i is a list [feat1, feat2, ..., pred_i]
        pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
        pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)

    if "p" in self.opts.tasks:
        for painter_param in self.G.painter.parameters():
            painter_param.requires_grad = True

    return pl4m_loss
def run_evaluation(self, verbose=0):
    """Evaluate the model on the validation loaders.

    * switch to eval mode
    * compute the generator loss on every validation batch and average
      the logged losses over the batches
    * log validation losses and train / val images
    * compute segmentation / mask metrics (self.eval_images) and the
      painter's validation FID when the matching tasks are enabled
    * switch back to train mode

    Args:
        verbose (int, optional): forwarded to self.get_G_loss().
            Defaults to 0.
    """
    print("******************* Running Evaluation ***********************")
    eval_start = time()
    self.eval_mode()

    summed_losses = None
    batch_count = None
    for batch_count, multi_batch_tuple in enumerate(self.val_loaders, start=1):
        # create a dictionary (domain => batch) from tuple
        # (batch_domain_0, ..., batch_domain_i)
        # and send it to self.device
        multi_domain_batch = {}
        for batch in multi_batch_tuple:
            batch_domain = batch["domain"][0]
            multi_domain_batch[batch_domain] = self.batch_to_device(batch)

        self.get_G_loss(multi_domain_batch, verbose)

        # NOTE(review): losses are read from / written to
        # `logger.losses.generator` whereas the loss methods of this class
        # write to `logger.losses.gen` -- confirm which key the Logger
        # actually reads for model_to_update="G"
        if summed_losses is not None:
            summed_losses = sum_dict(summed_losses, self.logger.losses.generator)
        else:
            summed_losses = deepcopy(self.logger.losses.generator)

    averaged_losses = div_dict(summed_losses, batch_count)
    self.logger.losses.generator = averaged_losses
    self.logger.log_losses(model_to_update="G", mode="val")

    for domain in self.opts.domains:
        for mode in ("train", "val"):
            self.logger.log_comet_images(mode, domain)

    if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
        for mode in ("train", "val"):
            self.logger.log_comet_combined_images(mode, "r")

    if self.exp is not None:
        print()

    if "m" in self.opts.tasks or "s" in self.opts.tasks:
        for domain in ("r", "s"):
            self.eval_images("val", domain)

    if "p" in self.opts.tasks and not self.kitti_pretrain:
        val_fid = compute_val_fid(self)
        if self.exp is None:
            print("Validation FID Score", val_fid)
        else:
            self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)

    self.train_mode()
    timing = int(time() - eval_start)
    print("****************** Done in {}s *********************".format(timing))
def eval_images(self, mode, domain):
    """Compute accuracy and mIOU of the masker's predictions on the
    trainer's display images and log (or print) their averages.

    Metrics are averaged over self.display_images[mode][domain] for the
    mask ("m"), the segmentation ("s", if in the tasks) and the depth
    ("d", only for the "s" domain when depth is a classification task).
    Tasks without any score are reported as -1.

    Args:
        mode (str): key of self.display_images, e.g. "val"
        domain (str): domain to evaluate; "s" is replaced by "kitti" in
            kitti pre-training mode and "rf" is ignored

    Returns:
        int | None: 0 when metrics were computed, None when the domain
            was skipped
    """
    if domain == "s" and self.kitti_pretrain:
        domain = "kitti"
    if domain == "rf" or domain not in self.display_images[mode]:
        return

    metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
    # task => metric name => list of per-image scores
    metric_avg_scores = {"m": {}}
    if "s" in self.opts.tasks:
        metric_avg_scores["s"] = {}
    if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
        metric_avg_scores["d"] = {}

    for key in metric_funcs:
        for task in metric_avg_scores:
            metric_avg_scores[task][key] = []

    for im_set in self.display_images[mode][domain]:
        # display images are stored unbatched: add a batch dimension
        x = im_set["data"]["x"].unsqueeze(0).to(self.device)
        z = self.G.encode(x)

        s_pred = d_pred = z_depth = None

        if "d" in metric_avg_scores:
            d_pred, z_depth = self.G.decoders["d"](z)
            d_pred = d_pred.detach().cpu()

            if domain == "s":
                d = im_set["data"]["d"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    metric_score = metric_funcs[metric](d_pred, d)
                    metric_avg_scores["d"][metric].append(metric_score)

        if "s" in metric_avg_scores:
            if z_depth is None:
                if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
                    _, z_depth = self.G.decoders["d"](z)
            s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
            s = im_set["data"]["s"].unsqueeze(0).detach()

            for metric in metric_funcs:
                metric_score = metric_funcs[metric](s_pred, s)
                metric_avg_scores["s"][metric].append(metric_score)

        # NOTE(review): this tests the keys of self.opts itself, whereas
        # the other checks of this method look into self.opts.tasks --
        # confirm whether `"m" in self.opts.tasks` was intended
        if "m" in self.opts:
            cond = None
            if s_pred is not None and d_pred is not None:
                # NOTE(review): d_pred and s_pred were moved to the CPU
                # above while x is on self.device -- confirm make_m_cond
                # handles that
                cond = self.G.make_m_cond(d_pred, s_pred, x)
            if z_depth is None:
                if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
                    _, z_depth = self.G.decoders["d"](z)

            pred_mask = (
                (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
            )
            # binarize the mask, then build its 2-channel version
            pred_mask = (pred_mask > 0.5).to(torch.float32)
            pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)

            m = im_set["data"]["m"].unsqueeze(0).detach()

            for metric in metric_funcs:
                # mIOU is given the 2-channel prediction, other metrics
                # the binary mask
                if metric != "mIOU":
                    metric_score = metric_funcs[metric](pred_mask, m)
                else:
                    metric_score = metric_funcs[metric](pred_prob, m)

                metric_avg_scores["m"][metric].append(metric_score)

    # Average the scores; tasks without any score get NaN...
    metric_avg_scores = {
        task: {
            metric: np.mean(values) if values else float("nan")
            for metric, values in met_dict.items()
        }
        for task, met_dict in metric_avg_scores.items()
    }
    # ... which is then replaced by -1
    metric_avg_scores = {
        task: {
            metric: value if not np.isnan(value) else -1
            for metric, value in met_dict.items()
        }
        for task, met_dict in metric_avg_scores.items()
    }
    if self.exp is not None:
        self.exp.log_metrics(
            flatten_opts(metric_avg_scores),
            prefix=f"metrics_{mode}_{domain}",
            step=self.logger.global_step,
        )
    else:
        print(f"metrics_{mode}_{domain}")
        print(flatten_opts(metric_avg_scores))

    return 0
def functional_test_mode(self):
    """Switch the trainer to functional-test mode.

    The output path is redirected to ~/climategan/functional_tests, a
    marker file is written in it, the experiment (if any) is flagged as a
    functional test and self.del_output_path is registered to run when
    the interpreter exits.
    """
    import atexit

    test_dir = Path("~").expanduser() / "climategan" / "functional_tests"
    self.opts.output_path = test_dir
    Path(self.opts.output_path).mkdir(parents=True, exist_ok=True)

    # Marker file checked by del_output_path before deleting the directory
    marker = Path(self.opts.output_path) / "is_functional.test"
    with marker.open("w") as marker_file:
        marker_file.write("trainer functional test - delete this dir")

    if self.exp is not None:
        self.exp.log_parameter("is_functional_test", True)

    atexit.register(self.del_output_path)
def del_output_path(self, force=False):
    """Delete the trainer's output directory.

    Nothing happens if the directory does not exist. Otherwise it is
    removed only if it holds the "is_functional.test" marker file, or if
    `force` is truthy.

    Args:
        force (bool, optional): delete even without the marker file.
            Defaults to False.
    """
    import shutil

    output_dir = Path(self.opts.output_path)
    if output_dir.exists():
        marker = Path(self.opts.output_path) / "is_functional.test"
        if marker.exists() or force:
            shutil.rmtree(self.opts.output_path)
def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
    """
    Transforms input tensor given wildfires event

    Args:
        x (torch.Tensor): Input tensor
        seg_preds (torch.Tensor, optional): Semantic segmentation
            predictions for input tensor. Predicted from z when None.
        z (torch.Tensor, optional): Latent vector of encoded "x".
            Can be None: it is then computed from x if seg_preds is
            missing too.
        z_depth (optional): forwarded to the segmentation decoder when
            seg_preds has to be predicted.

    Returns:
        torch.Tensor: Wildfire version of input tensor
    """
    if seg_preds is not None:
        return add_fire(x, seg_preds, self.opts.events.fire)

    latent = self.G.encode(x) if z is None else z
    predicted_seg = self.G.decoders["s"](latent, z_depth)
    return add_fire(x, predicted_seg, self.opts.events.fire)
def compute_flood(
    self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
):
    """
    Applies a flood (mask + paint) to an input image, with optionally
    pre-computed masker z or mask

    Args:
        x (torch.Tensor): B x C x H x W -1:1 input image
        z (torch.Tensor, optional): B x C x H x W Masker latent vector.
            Defaults to None.
        z_depth (optional): depth features given to the masker; computed
            by the depth decoder when missing and needed.
        m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None.
        s (optional): third argument of the painter's paint_cloudy();
            required when `cloudy` is truthy.
        cloudy (optional): paint with paint_cloudy() instead of paint().
        bin_value (float, optional): Mask binarization value.
            Set to -1 to use smooth masks (no binarization)

    Returns:
        torch.Tensor: B x 3 x H x W -1:1 flooded image
    """
    mask = m
    if mask is None:
        latent = self.G.encode(x) if z is None else z
        if "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None:
            _, z_depth = self.G.decoders["d"](latent)
        mask = self.G.mask(x=x, z=latent, z_depth=z_depth)

    if bin_value >= 0:
        # hard threshold, keeping the mask's original dtype
        mask = (mask > bin_value).to(mask.dtype)

    if not cloudy:
        return self.G.paint(mask, x)

    assert s is not None
    return self.G.paint_cloudy(mask, x, s)
def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
    """Add smog to an input image.

    Implementation from the paper:
    HazeRD: An outdoor scene dataset and benchmark for single image dehazing

    The image is converted with srgb2lrgb, attenuated with a
    depth-dependent transmission map and blended with the airlight,
    converted back with lrgb2srgb and finally mixed with a yellow filter.
    Parameters are read from self.opts.events.smog.

    Args:
        x (torch.Tensor): input image
        z (optional): latent of x; encoded from x when needed and None
        d (optional): depth; predicted by the depth decoder when None
        s (optional): segmentation; predicted by the segmentation decoder
            when None and use_sky_seg is truthy (currently not used
            further, see TODOs)
        use_sky_seg (bool, optional): Defaults to False.

    Raises:
        ValueError: use_sky_seg is truthy, s is None and "s" is not in
            the trainer's tasks

    Returns:
        torch.Tensor: smogged image
    """
    sky_mask = None

    depth_missing = d is None
    seg_missing = use_sky_seg and s is None
    if depth_missing or seg_missing:
        if z is None:
            z = self.G.encode(x)
        if depth_missing:
            d, _ = self.G.decoders["d"](z)
        if seg_missing:
            if "s" not in self.opts.tasks:
                raise ValueError(
                    "Cannot have "
                    + "(use_sky_seg is True and s is None and 's' not in tasks)"
                )
            s = self.G.decoders["s"](z)
            # TODO: s to sky mask
            # TODO: interpolate to d's size

    smog_opts = self.opts.events.smog

    # per-channel airlight, shaped to broadcast over B x 3 x H x W
    airlight = smog_opts.airlight * torch.ones(3)
    airlight = airlight.view(1, -1, 1, 1).to(self.device)

    irradiance = srgb2lrgb(x)

    # per-channel attenuation coefficient
    beta = torch.tensor([smog_opts.beta / smog_opts.vr] * 3)
    beta = beta.view(1, -1, 1, 1).to(self.device)

    # normalize, invert, then normalize again
    d = normalize(d, mini=0.3, maxi=1.0)
    d = 1.0 / d
    d = normalize(d, mini=0.1, maxi=1)

    if sky_mask is not None:
        d[sky_mask] = 1

    # match the image's spatial size and number of channels
    d = torch.nn.functional.interpolate(
        d, size=x.shape[-2:], mode="bilinear", align_corners=True
    )
    d = d.repeat(1, 3, 1, 1)

    transmission = torch.exp(d * -beta)

    hazy = transmission * irradiance + (1 - transmission) * airlight
    hazy = lrgb2srgb(hazy)

    # add yellow filter
    alpha = smog_opts.alpha / 255
    yellow_mask = torch.Tensor([smog_opts.yellow_color]) / 255
    height, width = hazy.shape[-2], hazy.shape[-1]
    yellow_filter = yellow_mask.unsqueeze(2).unsqueeze(2)
    yellow_filter = yellow_filter.repeat(1, 1, height, width).to(self.device)

    return hazy * (1 - alpha) + yellow_filter * alpha