| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
| from monai.networks.blocks import Warp |
| from monai.networks.nets import resnet18 |
| from monai.networks.nets.regunet import AffineHead |
|
|
|
|
class RegResNet(nn.Module):
    """Affine registration network: feature extractor -> affine head -> warp.

    ``forward`` runs the feature extractor on the whole input, feeds its output
    to an ``AffineHead`` that produces a displacement field (``ddf``), and warps
    the first channel of the input with ``Warp`` using that field.

    Args:
        image_size: spatial size passed to ``AffineHead`` (and to its forward
            call); must contain exactly ``spatial_dims`` entries.
        spatial_dims: number of spatial dimensions of the images.
        mod: optional feature-extractor module used instead of the default
            ``resnet18(n_input_channels=2, spatial_dims=spatial_dims)``.
        mode: interpolation mode forwarded to ``Warp``.
        padding_mode: padding mode forwarded to ``Warp``.
        features: ``in_channels`` given to ``AffineHead``; it must match the
            size of the feature extractor's output. The default of 400
            presumably matches ``resnet18``'s default output size. If you
            supply ``mod``, set this to that module's output size (TODO confirm).

    Raises:
        ValueError: if ``len(image_size) != spatial_dims``.
    """

    def __init__(
        self,
        image_size=(64, 64),
        spatial_dims=2,
        mod=None,
        mode="bilinear",
        padding_mode="border",
        features=400,
    ):
        super().__init__()
        # Fail fast on an inconsistent configuration. Otherwise the mismatch
        # only surfaces later, presumably as an opaque shape error inside the
        # affine head on the first forward pass.
        if len(image_size) != spatial_dims:
            raise ValueError(
                f"image_size must have spatial_dims={spatial_dims} entries, "
                f"got {len(image_size)}: {tuple(image_size)}"
            )
        # Only build the default backbone when no custom one is supplied.
        # Its 2 input channels presumably hold the moving and fixed images
        # stacked together -- TODO confirm against the data pipeline.
        self.features = resnet18(n_input_channels=2, spatial_dims=spatial_dims) if mod is None else mod
        self.affine_head = AffineHead(
            spatial_dims=spatial_dims,
            image_size=image_size,
            decode_size=[1] * spatial_dims,
            in_channels=features,
        )
        self.warp = Warp(mode=mode, padding_mode=padding_mode)
        self.image_size = image_size

    def forward(self, x):
        """Predict a displacement field from ``x`` and warp its first channel.

        Args:
            x: input batch. Presumably ``(batch, channels, *image_size)`` with
                the image to be warped in channel 0 -- TODO confirm.

        Returns:
            Channel 0 of ``x`` (channel dimension kept) warped by the
            displacement field predicted by the affine head.
        """
        # NOTE(review): moving submodules on every call keeps their parameters
        # on the input's device even when the caller never ran
        # ``model.to(device)``. It is a per-call side effect, and optimizer
        # state is not moved with it. Kept for backward compatibility; prefer
        # moving the model once before training/inference.
        self.features.to(device=x.device)
        self.affine_head.to(device=x.device)
        out = self.features(x)
        # The affine head is called with a one-element list of feature tensors
        # plus the configured image size.
        ddf = self.affine_head([out], self.image_size)
        # ``:1`` (rather than ``0``) keeps the channel dimension.
        f = self.warp(x[:, :1], ddf)
        return f
|
|