import streamlit as st import torch from PIL import Image from model import UnetGenerator, CycleGAN from utils import get_val_transform, de_normalize def load_model(checkpoint_path, device): generator_AtoB = UnetGenerator(in_channels=3, out_channels=3, num_downs=5, base_filters=64, dropout=0.0) generator_BtoA = UnetGenerator(in_channels=3, out_channels=3, num_downs=5, base_filters=64, dropout=0.0) model = CycleGAN(generator_AtoB, generator_BtoA) model.to(device) checkpoint = torch.load(checkpoint_path, map_location=device) state_dict = checkpoint['model_state_dict'] filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("discriminator")} model.load_state_dict(filtered_state_dict) model.eval() return model mean_a = [0.385, 0.409, 0.405] std_a = [0.263, 0.252, 0.292] mean_b = [0.403, 0.423, 0.453] std_b = [0.271, 0.266, 0.295] transform_a = get_val_transform(mean_a, std_a, image_size=256, crop_size=224) transform_b = get_val_transform(mean_b, std_b, image_size=256, crop_size=224) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @st.cache_resource def get_model(): checkpoint_path = "assets/cycle_gan#4.pt" model = load_model(checkpoint_path, device) return model model = get_model() st.title("CycleGAN Demo") with st.expander("Описание"): st.write("### Обучение генератора изображений") # ResNet st.write("**1. ResNet генератор**") st.write("Изначально я начал обучать ResNet генератор, но изображения получались слишком мутными.") st.write("Пример изображения:") st.image("assets/images/example_resnet.png", caption="Пример ResNet", use_container_width=True) st.write("Кажется, что обучение шло адекватно, поэтому как на резнете получить что-то приличное, я не очень понял:") st.image("assets/images/loss_resnet.png", caption="Loss ResNet", use_container_width=True) # UNet st.write("**2. Переход на UNet генератор**") st.write("После небольшого исследования я решил перейти на UNet генератор. Изображения стали гораздо четче, особенно реконструкция, хотя переведённые (translated) изображения оставляли желать лучшего.") st.write("Пример результатов в этой итерации:") st.image("assets/images/bad_examples_unet.png", caption="Bad Examples UNet", use_container_width=True) # аугментации st.write("**3. Улучшение аугментаций**") st.write("Так как транслированные изображения оставались мутными, я изменил аугментации, добавив:") st.write("- **RandomResized**") st.write("- **Salt and Pepper**") st.write("- **Random Elastic Transform**") st.write("После этих изменений тестовые изображения стали значительно лучше (если не считать небо из oblivion))):") st.image("assets/images/eamples1_unet.png", caption="Example 1 UNet", use_container_width=True) st.image("assets/images/examples2_unet.png", caption="Example 2 UNet", use_container_width=True) # датасет st.write("**4. Работа с датасетом**") st.write("Исходный датасет содержал много смешанных изображений (например, снежные горы и зелёный лес на одном изображении), что, возможно, негативно влияло на обучение. Я попробовал собрать зимние и летние датасеты — получилось около 4 тыс. изображений (3 тыс. зимних и 1 тыс. летних), которые я добавил к исходному датасету. Однако качество обучения не улучшилось, а на некоторых примерах даже ухудшилось, поэтому я решил остановиться на первоначальном варианте.") # Обучение в 512 st.write("**5. Обучение на разрешении 512**") st.write("При обучении в 512 реконструированные изображения получались лучше и четче, чем при 256, но зимние сцены превратились в белое пятно:") st.image("assets/images/example_unet_512.png", caption="Example UNet 512", use_container_width=True) st.write("Мне кажется, я просто все переобучил и надо было остановиться на эпохе 60ой") st.image("assets/images/loss_unet_512.png", caption="Loss UNet 512", use_container_width=True) st.write("**6. Вместо summary**") st.write("В целом конечно можно много чего улучшать: баловаться с гиперпараметрами, почистить получше изображения, потестировать другие лоссы и тд. Но потратить ещё столько же времени я уже физически не смогу, увы. Поэтому как есть. Спасибо за внимание!") st.write("Загрузите изображение и выберите тип преобразования:") direction = st.radio("Направление преобразования", ("Летний домен -> Зимний домен", "Зимний домен -> Летний домен")) uploaded_file = st.file_uploader("Загрузите изображение", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: image = Image.open(uploaded_file).convert("RGB") if st.button("Сгенерировать"): with st.spinner("Обработка..."): if direction == "Летний домен -> Зимний домен": input_transform = transform_a de_norm = lambda t: de_normalize(t, mean_a, std_a) input_tensor = input_transform(image).unsqueeze(0).to(device) with torch.no_grad(): fake = model.generator_AtoB(input_tensor) rec = model.generator_BtoA(fake) else: input_transform = transform_b de_norm = lambda t: de_normalize(t, mean_b, std_b) input_tensor = input_transform(image).unsqueeze(0).to(device) with torch.no_grad(): fake = model.generator_BtoA(input_tensor) rec = model.generator_AtoB(fake) fake_image = de_norm(fake.squeeze(0)) rec_image = de_norm(rec.squeeze(0)) cols = st.columns(3) cols[0].image(image, caption="Исходное изображение", width=255) cols[1].image(fake_image, caption="Сгенерированное изображение", width=255) cols[2].image(rec_image, caption="Реконструкция", width=255)