import streamlit as st import torch import torch.nn as nn import torchvision.transforms as tr from PIL import Image, ImageOps import numpy as np class UNetDownBlock(nn.Module): def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0): super(UNetDownBlock, self).__init__() layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)] if normalize: layers.append(nn.InstanceNorm2d(out_channels)) layers.append(nn.LeakyReLU(0.2, inplace=True)) if dropout > 0.0: layers.append(nn.Dropout(dropout)) self.down = nn.Sequential(*layers) def forward(self, x): return self.down(x) class UNetUpBlock(nn.Module): def __init__(self, in_channels, skip_channels, out_channels, dropout=0.0): super(UNetUpBlock, self).__init__() self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False) self.conv = nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm = nn.InstanceNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None def forward(self, x, skip_input): x = self.up(x) x = torch.cat((x, skip_input), dim=1) x = self.conv(x) x = self.norm(x) x = self.relu(x) if self.dropout: x = self.dropout(x) return x class UnetGenerator(nn.Module): def __init__(self, in_channels=3, out_channels=3, num_downs=7, base_filters=64, dropout=0.0): super(UnetGenerator, self).__init__() # Encoder self.encoder_channels = [] self.down_blocks = nn.ModuleList() self.down_blocks.append(UNetDownBlock(in_channels, base_filters, normalize=False)) self.encoder_channels.append(base_filters) filters = base_filters for i in range(1, num_downs): next_filters = min(filters * 2, 512) self.down_blocks.append(UNetDownBlock(filters, next_filters, normalize=True, dropout=dropout)) self.encoder_channels.append(next_filters) filters = next_filters # Decoder self.up_blocks = nn.ModuleList() for i in range(num_downs - 1): prev_filters = filters skip_channels = self.encoder_channels[-(i+2)] filters = max(filters // 2, base_filters) self.up_blocks.append(UNetUpBlock(prev_filters, skip_channels, filters, dropout=dropout)) self.final_conv = nn.Sequential( nn.ConvTranspose2d(filters + self.encoder_channels[0], out_channels, 4, stride=2, padding=1), nn.Tanh() ) def forward(self, x): down_results = [] cur = x for down in self.down_blocks: cur = down(cur) down_results.append(cur) for i, up in enumerate(self.up_blocks): cur = up(cur, down_results[-(i+2)]) out = self.final_conv(torch.cat((cur, down_results[0]), dim=1)) return out class CycleGAN(nn.Module): def __init__(self, generator_AtoB, generator_BtoA): super(CycleGAN, self).__init__() self.generator_AtoB = generator_AtoB self.generator_BtoA = generator_BtoA def forward(self, x): # Для инференса не используется return self.generator_AtoB(x)