import torch
import torch.nn as nn
import torch.nn.functional as F

from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
from climategan.utils import find_target_size


def create_depth_decoder(opts, no_init=False, verbose=0):
    if opts.gen.d.architecture == "base":
        decoder = BaseDepthDecoder(opts)
        if "s" in opts.tasks:
            assert opts.gen.s.use_dada is False
        if "m" in opts.tasks:
            assert opts.gen.m.use_dada is False
    else:
        decoder = DADADepthDecoder(opts)

    if verbose > 0:
        print(f" - Add {decoder.__class__.__name__}")

    return decoder
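

# Editor's usage sketch (not part of the original module): `create_depth_decoder`
# expects the project's nested `opts` config (opts.tasks, opts.gen.d.*,
# opts.gen.m.*, opts.gen.s.*, opts.gen.encoder.*, opts.gen.deeplabv3.*, plus
# whatever find_target_size reads). Loading that config is elided here; the
# hypothetical helper below only illustrates the call: "base" selects
# BaseDepthDecoder, any other value selects the DADA head.
def _example_build_depth_decoder(opts, device="cpu"):
    decoder = create_depth_decoder(opts, verbose=1)  # verbose=1 prints the chosen class
    return decoder.to(device).eval()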


class DADADepthDecoder(nn.Module):
    """
    Depth decoder based on depth auxiliary task in DADA paper
    """

    def __init__(self, opts):
        super().__init__()
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            res_dim = 320
        else:
            res_dim = 2048

        mid_dim = 512

        self.do_feat_fusion = False
        if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
            self.do_feat_fusion = True
            self.dec4 = Conv2dBlock(
                128,
                res_dim,
                1,
                stride=1,
                padding=0,
                bias=True,
                activation="lrelu",
                norm="none",
            )

        self.relu = nn.ReLU(inplace=True)
        self.enc4_1 = Conv2dBlock(
            res_dim,
            mid_dim,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_2 = Conv2dBlock(
            mid_dim,
            mid_dim,
            3,
            stride=1,
            padding=1,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_3 = Conv2dBlock(
            mid_dim,
            128,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )

        self.upsample = None
        if opts.gen.d.upsample_featuremaps:
            self.upsample = nn.Sequential(
                *[
                    InterpolateNearest2d(),
                    Conv2dBlock(
                        128,
                        32,
                        3,
                        stride=1,
                        padding=1,
                        bias=False,
                        activation="lrelu",
                        pad_type="reflect",
                        norm="batch",
                    ),
                    nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
                ]
            )

        self._target_size = find_target_size(opts, "d")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z):
        if isinstance(z, (list, tuple)):
            z = z[0]
        z4_enc = self.enc4_1(z)
        z4_enc = self.enc4_2(z4_enc)
        z4_enc = self.enc4_3(z4_enc)

        z_depth = None
        if self.do_feat_fusion:
            z_depth = self.dec4(z4_enc)

        if self.upsample is not None:
            z4_enc = self.upsample(z4_enc)

        depth = torch.mean(z4_enc, dim=1, keepdim=True)  # DADA paper decoder
        if depth.shape[-1] != self._target_size:
            depth = F.interpolate(
                depth,
                size=(384, 384),  # size used in MiDaS inference
                mode="bicubic",  # what MiDaS uses
                align_corners=False,
            )
            depth = F.interpolate(
                depth, (self._target_size, self._target_size), mode="nearest"
            )  # what we used in the transforms to resize input
        return depth, z_depth

    def __str__(self):
        return "DADA Depth Decoder"


class BaseDepthDecoder(BaseDecoder):
    def __init__(self, opts):
        low_level_feats_dim = -1
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.d.use_low_level_feats

        if use_v3 and use_mobile_net:
            input_dim = 320
            if use_low:
                low_level_feats_dim = 24
        elif use_v3:
            input_dim = 2048
            if use_low:
                low_level_feats_dim = 256
        else:
            input_dim = 2048

        n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
        output_dim = (
            1
            if not opts.gen.d.classify.enable
            else opts.gen.d.classify.linspace.buckets
        )

        self._target_size = find_target_size(opts, "d")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        super().__init__(
            n_upsample=n_upsample,
            n_res=opts.gen.d.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.d.proj_dim,
            output_dim=output_dim,
            norm=opts.gen.d.norm,
            activ=opts.gen.d.activ,
            pad_type=opts.gen.d.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, cond=None):
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size() "
            error += "to interpolate depth to the target depth map's size"
            raise ValueError(error)

        d = super().forward(z)

        preds = F.interpolate(
            d, size=self._target_size, mode="bilinear", align_corners=True
        )

        return preds, None
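

# Editor's sketch (not part of the original module): BaseDepthDecoder.forward
# refuses to run while _target_size is None, and set_target_size stores an
# (h, w) tuple that the final bilinear interpolation consumes directly. The
# feature shape below assumes a non-mobilenet encoder (input_dim == 2048)
# without low-level feature fusion; all sizes are arbitrary illustrations.
def _example_base_forward(decoder):
    decoder.set_target_size((480, 640))  # an int such as 640 would be stored as (640, 640)
    z = torch.randn(1, 2048, 20, 20)  # hypothetical encoder features
    with torch.no_grad():
        preds, _ = decoder(z)  # the second return value is always None here
    assert preds.shape[-2:] == (480, 640)
    return preds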