Initial upload

- README.md +185 -0
- get_parser.py +121 -0
- train.py +381 -0
- validation.ipynb +364 -0
README.md
ADDED
@@ -0,0 +1,185 @@
# CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation

**Official PyTorch Implementation**

This is a PyTorch/GPU implementation of the paper [CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation](https://arxiv.org/abs/2503.15617)

```
@article{ahmed2025cam,
  title={CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation},
  author={Ahmed, Masud and Hasan, Zahid and Haque, Syed Arefinul and Faridee, Abu Zaher Md and Purushotham, Sanjay and You, Suya and Roy, Nirmalya},
  journal={arXiv preprint arXiv:2503.15617},
  year={2025}
}
```

## Abstract
Traditional transformer-based semantic segmentation relies on quantized embeddings. However, our analysis reveals that autoencoder accuracy on segmentation masks using quantized embeddings (e.g. VQ-VAE) is 8\% lower than with continuous-valued embeddings (e.g. KL-VAE). Motivated by this, we propose a continuous-valued embedding framework for semantic segmentation. By reformulating semantic mask generation as a continuous image-to-embedding diffusion process, our approach eliminates the need for discrete latent representations while preserving fine-grained spatial and semantic details. Our key contribution includes a diffusion-guided autoregressive transformer that learns a continuous semantic embedding space by modeling long-range dependencies in image features. Our framework contains a unified architecture combining a VAE encoder for continuous feature extraction, a diffusion-guided transformer for conditioned embedding generation, and a VAE decoder for semantic mask reconstruction. Our setting facilitates zero-shot domain adaptation capabilities enabled by the continuity of the embedding space. Experiments across diverse datasets (e.g., Cityscapes and domain-shifted variants) demonstrate state-of-the-art robustness to distribution shifts, including adverse weather (e.g., fog, snow) and viewpoint variations. Our model also exhibits strong noise resilience, achieving robust performance ($\approx$ 95\% AP compared to baseline) under Gaussian noise, moderate motion blur, and moderate brightness/contrast variations, while experiencing only a moderate impact ($\approx$ 90\% AP compared to baseline) from 50\% salt and pepper noise, saturation and hue shifts.

## Result
Trained on the Cityscapes dataset and tested on the SemanticKITTI, ACDC, and CAD-EdgeTune datasets.
<p align="center">
  <img src="demo/qualitative.png" width="720">
</p>

Quantitative results of semantic segmentation under various noise conditions
<p align="center">
<table>
  <tr>
    <td align="center"><img src="demo/saltpepper_noise.png" width="200"/><br>Salt & Pepper Noise</td>
    <td align="center"><img src="demo/motion_blur.png" width="200"/><br>Motion Blur</td>
    <td align="center"><img src="demo/gaussian_noise.png" width="200"/><br>Gaussian Noise</td>
    <td align="center"><img src="demo/gaussian_blur.png" width="200"/><br>Gaussian Blur</td>
  </tr>
  <tr>
    <td align="center"><img src="demo/brightness.png" width="200"/><br>Brightness Variation</td>
    <td align="center"><img src="demo/contrast.png" width="200"/><br>Contrast Variation</td>
    <td align="center"><img src="demo/saturation.png" width="200"/><br>Saturation Variation</td>
    <td align="center"><img src="demo/hue.png" width="200"/><br>Hue Variation</td>
  </tr>
</table>
</p>

## Prerequisite
To install the docker environment, first edit `docker_env/Makefile`:
```
IMAGE=img_name/dl-aio
CONTAINER=container_name
AVAILABLE_GPUS='0,1,2,3'
LOCAL_JUPYTER_PORT=18888
LOCAL_TENSORBOARD_PORT=18006
PASSWORD=yourpassword
WORKSPACE=workspace_directory
```
- Edit `img_name`, `container_name`, `available_gpus`, `jupyter_port`, `tensorboard_port`, `password`, and `workspace_directory`.

1. For the first time, run the following commands in a terminal:
```
cd docker_env
make docker-build
make docker-run
```
2. For further use of the docker environment:
   - To stop the environment: `make docker-stop`
   - To resume the environment: `make docker-resume`

To start coding, open a web browser at `ip_address:jupyter_port`, e.g., `http://localhost:18888`

## Dataset
Four datasets are used in this work:
1. [Cityscapes Dataset](https://www.cityscapes-dataset.com/)
2. [KITTI Dataset](https://www.cvlibs.net/datasets/kitti/eval_step.php)
3. [ACDC Dataset](https://acdc.vision.ee.ethz.ch/)
4. [CAD-EdgeTune Dataset](https://ieee-dataport.org/documents/cad-edgetune)

**Modify the trainlist and vallist files to edit the train and test splits.**

### Dataset structure
- Cityscapes Dataset
```
|-CityScapes
|----leftImg8bit      #contains the RGB images
|----gtFine           #contains semantic segmentation labels
|----trainlist.txt    #image list used for training
|----vallist.txt      #image list used for testing
|----cityscape.yaml   #configuration file for Cityscapes dataset
```

- ACDC Dataset
```
|-ACDC
|----rgb_anon           #contains the RGB images
|----gt                 #contains semantic segmentation labels
|----vallist_fog.txt    #image list used for testing fog data
|----vallist_rain.txt   #image list used for testing rain data
|----vallist_snow.txt   #image list used for testing snow data
|----acdc.yaml          #configuration file for ACDC dataset
```

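For reference, the training script and the validation notebook point their data loaders at these list files roughly as follows (mirroring the calls that appear in `train.py` and `validation.ipynb`; the loaders in the `data` package are part of this repo):

```
import torchvision.transforms as transforms
from data import cityscapes

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Training split, as used in train.py
train_set = cityscapes.CityScapes('dataset/CityScapes/trainlist.txt',
                                  transform=transform, img_size=768)

# Validation split, as used in validation.ipynb (Block 6)
val_set = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set='val',
                                transform=transform, seed=1, img_size=768)
```
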
## Weights
To download the pretrained weights, please visit the [Hugging Face Repo](https://huggingface.co/mahmed10/CAM-Seg)
- **LDM model**: the pretrained model from Rombach et al.'s Latent Diffusion Models is used. [Link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt)
- **MAR model**: the following MAR model is used.

|Training Data|Model|Params|Link|
|-------------|-----|------|----|
|Cityscapes   |MAR-base|217M|[link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth)|

Download these weight files and organize them as follows:
```
|-pretrained_models
|----mar
|--------city768.16.pth
|----vae
|--------modelf16.ckpt
```

**Alternatively, run the following code to automatically download the pretrained weights**
```
import os
import requests

# Define URLs and file paths
files_to_download = {
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt":
        "pretrained_models/vae/modelf16.ckpt",
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth":
        "pretrained_models/mar/city768.16.pth"
}

for url, path in files_to_download.items():
    os.makedirs(os.path.dirname(path), exist_ok=True)

    print(f"Downloading from {url}...")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Saved to {path}")
    else:
        print(f"Failed to download from {url}, status code {response.status_code}")
```

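As a lighter-weight alternative, the same two files can also be fetched with the `huggingface_hub` client. This is only a sketch and assumes `huggingface_hub` is installed (it is not listed as a dependency of this repo); with `local_dir`, recent versions keep the files under the same `pretrained_models/...` relative paths.

```
from huggingface_hub import hf_hub_download

for filename in ["pretrained_models/vae/modelf16.ckpt",
                 "pretrained_models/mar/city768.16.pth"]:
    # Downloads from the Hugging Face repo referenced above, preserving the relative path.
    local_path = hf_hub_download(repo_id="mahmed10/CAM-Seg", filename=filename, local_dir=".")
    print("Saved to", local_path)
```
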
## Validation
Open the `validation.ipynb` file.

Edit **Block 6** to select which dataset to use for validation:

```
dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
# dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=36, img_size=768)
```

Then run all the blocks.

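For reference, the notebook scores predictions by accumulating a confusion matrix (`fast_hist`) over all validation images and reporting per-class precision. A minimal standalone sketch of that computation, using a toy 3-class example rather than the 20 Cityscapes ids (index 0 = unlabeled, which the notebook excludes), is:

```
import numpy as np

def fast_hist(pred, label, n):
    # n x n confusion matrix: rows = ground-truth id, columns = predicted id
    k = (label >= 0) & (label < n)
    return np.bincount(n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)

pred  = np.array([0, 1, 2, 2, 1])   # toy predictions
label = np.array([0, 1, 2, 1, 1])   # toy ground truth
cm = fast_hist(pred, label, 3)

# per-class precision = correct predictions / all predictions of that class
precision = np.diag(cm) / (cm.sum(axis=0) + 1e-10)
print(precision)
```
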
## Training

### From Scratch

Run the following command in a terminal:
```
torchrun --nproc_per_node=4 train.py
```

It will save checkpoints in an `output_dir/year.month.day.hour.min` folder, e.g. `output_dir/2025.05.09.02.27`.

### Resume Training

Run the following command in a terminal:
```
torchrun --nproc_per_node=4 train.py --resume year.month.day.hour.min
```

Here is an example:
```
torchrun --nproc_per_node=4 train.py --resume 2025.05.09.02.27
```

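When resuming, `train.py` simply picks the newest `checkpoint*.pth` inside `output_dir/<year.month.day.hour.min>`. A minimal sketch of that lookup (folder name here is only an illustrative example) is:

```
import glob
import os

run_id = "2025.05.09.02.27"          # folder created when training started
pattern = os.path.join("./output_dir", run_id, "checkpoint*.pth")

# train.py takes the lexicographically last match, which is the latest epoch
# because checkpoints are saved with zero-padded epoch numbers.
checkpoints = sorted(glob.glob(pattern))
print(checkpoints[-1] if checkpoints else "no checkpoint found")
```
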
## Acknowledgement
The code is developed on top of the following codebases:
1. [latent-diffusion](https://github.com/CompVis/latent-diffusion)
2. [mar](https://github.com/LTH14/mar)

get_parser.py
ADDED
@@ -0,0 +1,121 @@
import argparse
from pathlib import Path
import yaml


def get_args_parser():
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=2000, type=int)

    # Model parameters
    parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
                        help='model checkpoint path')

    # VAE parameters
    parser.add_argument('--img_size', default=768, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
                        help='vae checkpoint path')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')
    parser.add_argument('--config', default="ldm/config.yaml", type=str,
                        help='vae model configuration file')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=3000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')

    parser.add_argument('--grad_checkpointing', action='store_true')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # MAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=6)
    parser.add_argument('--diffloss_w', type=int, default=1024)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=4)
    parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default=None,  # 'pretrained_models/mar/mar_base'
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    return parser


args = get_args_parser()
args = args.parse_args()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
args.log_dir = args.output_dir

with open(args.config, "r") as f:
    config = yaml.safe_load(f)
args.ddconfig = config["ddconfig"]

train.py
ADDED
|
@@ -0,0 +1,381 @@
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import yaml
|
| 8 |
+
import glob
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.backends.cudnn as cudnn
|
| 12 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
import torchvision.datasets as datasets
|
| 15 |
+
from data import cityscapes
|
| 16 |
+
|
| 17 |
+
from util.crop import center_crop_arr
|
| 18 |
+
import util.misc as misc
|
| 19 |
+
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
| 20 |
+
from util.loader import CachedFolder
|
| 21 |
+
|
| 22 |
+
from models.vae import AutoencoderKL
|
| 23 |
+
from models import mar
|
| 24 |
+
import copy
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
import util.lr_sched as lr_sched
|
| 28 |
+
|
| 29 |
+
import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def update_ema(target_params, source_params, rate=0.99):
|
| 34 |
+
"""
|
| 35 |
+
Update target parameters to be closer to those of source parameters using
|
| 36 |
+
an exponential moving average.
|
| 37 |
+
|
| 38 |
+
:param target_params: the target parameter sequence.
|
| 39 |
+
:param source_params: the source parameter sequence.
|
| 40 |
+
:param rate: the EMA rate (closer to 1 means slower).
|
| 41 |
+
"""
|
| 42 |
+
for targ, src in zip(target_params, source_params):
|
| 43 |
+
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
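# Example: train.py calls update_ema(ema_params, model_params, rate=args.ema_rate) with the
# default ema_rate of 0.9999, so each EMA weight keeps 99.99% of its previous value and moves
# only 0.01% toward the current model weight per call: ema = rate * ema + (1 - rate) * param.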
|
| 44 |
+
|
| 45 |
+
def logger_file(path):
|
| 46 |
+
logger = logging.getLogger()
|
| 47 |
+
logger.setLevel(logging.DEBUG)
|
| 48 |
+
handler = logging.FileHandler(path, "w", encoding=None, delay=True)
|
| 49 |
+
handler.setLevel(logging.INFO)
|
| 50 |
+
formatter = logging.Formatter("%(message)s")
|
| 51 |
+
handler.setFormatter(formatter)
|
| 52 |
+
logger.addHandler(handler)
|
| 53 |
+
return logger
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_args_parser():
|
| 57 |
+
parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
|
| 58 |
+
parser.add_argument('--batch_size', default=2, type=int,
|
| 59 |
+
help='Batch size per GPU (effective batch size is batch_size * # gpus')
|
| 60 |
+
parser.add_argument('--epochs', default=2000, type=int)
|
| 61 |
+
|
| 62 |
+
# Model parameters
|
| 63 |
+
parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
|
| 64 |
+
help='Name of model to train')
|
| 65 |
+
parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
|
| 66 |
+
help='model checkpoint path')
|
| 67 |
+
|
| 68 |
+
# VAE parameters
|
| 69 |
+
parser.add_argument('--img_size', default=768, type=int,
|
| 70 |
+
help='images input size')
|
| 71 |
+
parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
|
| 72 |
+
help='images input size')
|
| 73 |
+
parser.add_argument('--vae_embed_dim', default=16, type=int,
|
| 74 |
+
help='vae output embedding dimension')
|
| 75 |
+
parser.add_argument('--vae_stride', default=16, type=int,
|
| 76 |
+
help='tokenizer stride, default use KL16')
|
| 77 |
+
parser.add_argument('--patch_size', default=1, type=int,
|
| 78 |
+
help='number of tokens to group as a patch.')
|
| 79 |
+
parser.add_argument('--config', default="ldm/config.yaml", type=str,
|
| 80 |
+
help='vae model configuration file')
|
| 81 |
+
|
| 82 |
+
# Generation parameters
|
| 83 |
+
parser.add_argument('--num_iter', default=64, type=int,
|
| 84 |
+
help='number of autoregressive iterations to generate an image')
|
| 85 |
+
parser.add_argument('--num_images', default=3000, type=int,
|
| 86 |
+
help='number of images to generate')
|
| 87 |
+
parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
|
| 88 |
+
parser.add_argument('--cfg_schedule', default="linear", type=str)
|
| 89 |
+
parser.add_argument('--label_drop_prob', default=0.1, type=float)
|
| 90 |
+
parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
|
| 91 |
+
parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
|
| 92 |
+
parser.add_argument('--online_eval', action='store_true')
|
| 93 |
+
parser.add_argument('--evaluate', action='store_true')
|
| 94 |
+
parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')
|
| 95 |
+
|
| 96 |
+
# Optimizer parameters
|
| 97 |
+
parser.add_argument('--weight_decay', type=float, default=0.02,
|
| 98 |
+
help='weight decay (default: 0.02)')
|
| 99 |
+
|
| 100 |
+
parser.add_argument('--grad_checkpointing', action='store_true')
|
| 101 |
+
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
| 102 |
+
help='learning rate (absolute lr)')
|
| 103 |
+
parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
|
| 104 |
+
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
| 105 |
+
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
|
| 106 |
+
help='lower lr bound for cyclic schedulers that hit 0')
|
| 107 |
+
parser.add_argument('--lr_schedule', type=str, default='constant',
|
| 108 |
+
help='learning rate schedule')
|
| 109 |
+
parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
|
| 110 |
+
help='epochs to warmup LR')
|
| 111 |
+
parser.add_argument('--ema_rate', default=0.9999, type=float)
|
| 112 |
+
|
| 113 |
+
# MAR params
|
| 114 |
+
parser.add_argument('--mask_ratio_min', type=float, default=0.7,
|
| 115 |
+
help='Minimum mask ratio')
|
| 116 |
+
parser.add_argument('--grad_clip', type=float, default=3.0,
|
| 117 |
+
help='Gradient clip')
|
| 118 |
+
parser.add_argument('--attn_dropout', type=float, default=0.1,
|
| 119 |
+
help='attention dropout')
|
| 120 |
+
parser.add_argument('--proj_dropout', type=float, default=0.1,
|
| 121 |
+
help='projection dropout')
|
| 122 |
+
parser.add_argument('--buffer_size', type=int, default=64)
|
| 123 |
+
|
| 124 |
+
# Diffusion Loss params
|
| 125 |
+
parser.add_argument('--diffloss_d', type=int, default=6)
|
| 126 |
+
parser.add_argument('--diffloss_w', type=int, default=1024)
|
| 127 |
+
parser.add_argument('--num_sampling_steps', type=str, default="100")
|
| 128 |
+
parser.add_argument('--diffusion_batch_mul', type=int, default=4)
|
| 129 |
+
parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
|
| 130 |
+
|
| 131 |
+
# Dataset parameters
|
| 132 |
+
parser.add_argument('--output_dir', default='./output_dir',
|
| 133 |
+
help='path where to save, empty for no saving')
|
| 134 |
+
parser.add_argument('--log_dir', default='./output_dir',
|
| 135 |
+
help='path where to tensorboard log')
|
| 136 |
+
parser.add_argument('--device', default='cuda',
|
| 137 |
+
help='device to use for training / testing')
|
| 138 |
+
parser.add_argument('--seed', default=1, type=int)
|
| 139 |
+
parser.add_argument('--resume', default=None,
|
| 140 |
+
help='resume from checkpoint')
|
| 141 |
+
|
| 142 |
+
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
| 143 |
+
help='start epoch')
|
| 144 |
+
parser.add_argument('--num_workers', default=10, type=int)
|
| 145 |
+
parser.add_argument('--pin_mem', action='store_true',
|
| 146 |
+
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
| 147 |
+
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
|
| 148 |
+
parser.set_defaults(pin_mem=True)
|
| 149 |
+
|
| 150 |
+
# distributed training parameters
|
| 151 |
+
parser.add_argument('--world_size', default=1, type=int,
|
| 152 |
+
help='number of distributed processes')
|
| 153 |
+
parser.add_argument('--local_rank', default=-1, type=int)
|
| 154 |
+
parser.add_argument('--dist_on_itp', action='store_true')
|
| 155 |
+
parser.add_argument('--dist_url', default='env://',
|
| 156 |
+
help='url used to set up distributed training')
|
| 157 |
+
|
| 158 |
+
# caching latents
|
| 159 |
+
parser.add_argument('--use_cached', action='store_true', dest='use_cached',
|
| 160 |
+
help='Use cached latents')
|
| 161 |
+
parser.set_defaults(use_cached=False)
|
| 162 |
+
parser.add_argument('--cached_path', default='', help='path to cached latents')
|
| 163 |
+
|
| 164 |
+
return parser
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def main(args):
|
| 168 |
+
misc.init_distributed_mode(args)
|
| 169 |
+
|
| 170 |
+
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
| 171 |
+
print("{}".format(args).replace(', ', ',\n'))
|
| 172 |
+
|
| 173 |
+
device = torch.device(args.device)
|
| 174 |
+
|
| 175 |
+
# fix the seed for reproducibility
|
| 176 |
+
seed = args.seed + misc.get_rank()
|
| 177 |
+
torch.manual_seed(seed)
|
| 178 |
+
np.random.seed(seed)
|
| 179 |
+
|
| 180 |
+
cudnn.benchmark = True
|
| 181 |
+
|
| 182 |
+
num_tasks = misc.get_world_size()
|
| 183 |
+
global_rank = misc.get_rank()
|
| 184 |
+
|
| 185 |
+
log_writer = None
|
| 186 |
+
|
| 187 |
+
# augmentation following DiT and ADM
|
| 188 |
+
transform_train = transforms.Compose([
|
| 189 |
+
transforms.ToTensor(),
|
| 190 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 191 |
+
])
|
| 192 |
+
|
| 193 |
+
dataset_train = cityscapes.CityScapes('dataset/CityScapes/trainlist.txt', transform=transform_train, img_size=args.img_size)
|
| 194 |
+
|
| 195 |
+
sampler_train = torch.utils.data.DistributedSampler(
|
| 196 |
+
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
| 197 |
+
)
|
| 198 |
+
print("Sampler_train = %s" % str(sampler_train))
|
| 199 |
+
|
| 200 |
+
data_loader_train = torch.utils.data.DataLoader(
|
| 201 |
+
dataset_train, sampler=sampler_train,
|
| 202 |
+
batch_size=args.batch_size,
|
| 203 |
+
num_workers=args.num_workers,
|
| 204 |
+
pin_memory=args.pin_mem,
|
| 205 |
+
drop_last=True,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# define the vae and mar model
|
| 209 |
+
with open(args.config, "r") as f:
|
| 210 |
+
config = yaml.safe_load(f)
|
| 211 |
+
args.ddconfig = config["ddconfig"]
|
| 212 |
+
print('config:', config)
|
| 213 |
+
|
| 214 |
+
vae = AutoencoderKL(
|
| 215 |
+
ddconfig=args.ddconfig,
|
| 216 |
+
embed_dim=args.vae_embed_dim,
|
| 217 |
+
ckpt_path=args.vae_path
|
| 218 |
+
).cuda().eval()
|
| 219 |
+
|
| 220 |
+
for param in vae.parameters():
|
| 221 |
+
param.requires_grad = False
|
| 222 |
+
|
| 223 |
+
model = mar.__dict__[args.model](
|
| 224 |
+
img_size=args.img_size,
|
| 225 |
+
vae_stride=args.vae_stride,
|
| 226 |
+
patch_size=args.patch_size,
|
| 227 |
+
vae_embed_dim=args.vae_embed_dim,
|
| 228 |
+
mask_ratio_min=args.mask_ratio_min,
|
| 229 |
+
label_drop_prob=args.label_drop_prob,
|
| 230 |
+
attn_dropout=args.attn_dropout,
|
| 231 |
+
proj_dropout=args.proj_dropout,
|
| 232 |
+
buffer_size=args.buffer_size,
|
| 233 |
+
diffloss_d=args.diffloss_d,
|
| 234 |
+
diffloss_w=args.diffloss_w,
|
| 235 |
+
num_sampling_steps=args.num_sampling_steps,
|
| 236 |
+
diffusion_batch_mul=args.diffusion_batch_mul,
|
| 237 |
+
grad_checkpointing=args.grad_checkpointing,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if args.ckpt_path:
|
| 241 |
+
checkpoint = torch.load(args.ckpt_path, map_location='cpu')
|
| 242 |
+
model.load_state_dict(checkpoint['model'])
|
| 243 |
+
|
| 244 |
+
print("Model = %s" % str(model))
|
| 245 |
+
# following timm: set wd as 0 for bias and norm layers
|
| 246 |
+
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 247 |
+
print("Number of trainable parameters: {}M".format(n_params / 1e6))
|
| 248 |
+
|
| 249 |
+
model.to(device)
|
| 250 |
+
model_without_ddp = model
|
| 251 |
+
|
| 252 |
+
eff_batch_size = args.batch_size * misc.get_world_size()
|
| 253 |
+
|
| 254 |
+
if args.lr is None: # only base_lr is specified
|
| 255 |
+
args.lr = args.blr
|
| 256 |
+
|
| 257 |
+
print("base lr: %.2e" % args.blr)
|
| 258 |
+
print("actual lr: %.2e" % args.lr)
|
| 259 |
+
print("effective batch size: %d" % eff_batch_size)
|
| 260 |
+
|
| 261 |
+
if args.distributed:
|
| 262 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
| 263 |
+
model_without_ddp = model.module
|
| 264 |
+
|
| 265 |
+
# no weight decay on bias, norm layers, and diffloss MLP
|
| 266 |
+
param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
|
| 267 |
+
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
|
| 268 |
+
print(optimizer)
|
| 269 |
+
loss_scaler = NativeScaler()
|
| 270 |
+
|
| 271 |
+
# resume training
|
| 272 |
+
if args.resume and glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')):
|
| 273 |
+
try:
|
| 274 |
+
checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-1], map_location='cpu')
|
| 275 |
+
model.load_state_dict(checkpoint['model'])
|
| 276 |
+
except Exception:
|
| 277 |
+
checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-2], map_location='cpu')
|
| 278 |
+
model.load_state_dict(checkpoint['model'])
|
| 279 |
+
state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model'].items()}
|
| 280 |
+
model_without_ddp.load_state_dict(state_dict)
|
| 281 |
+
model_params = list(model_without_ddp.parameters())
|
| 282 |
+
ema_params = copy.deepcopy(model_params)
|
| 283 |
+
ema_state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model_ema'].items()}
|
| 284 |
+
ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
|
| 285 |
+
print("Resume checkpoint %s" % args.resume)
|
| 286 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
| 287 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 288 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
| 289 |
+
if 'scaler' in checkpoint:
|
| 290 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
| 291 |
+
print("With optim & sched!")
|
| 292 |
+
del checkpoint
|
| 293 |
+
|
| 294 |
+
args.output_dir = os.path.join(args.output_dir, args.resume)
|
| 295 |
+
|
| 296 |
+
logger = logger_file(args.log_dir+'/'+args.resume+'.log')
|
| 297 |
+
if os.path.exists(args.log_dir+'/'+args.resume+'.log'):
|
| 298 |
+
with open(args.log_dir+'/'+args.resume+'.log', 'r') as infile:
|
| 299 |
+
for line in infile:
|
| 300 |
+
logger.info(line.rstrip())
|
| 301 |
+
else:
|
| 302 |
+
logger.info("All the arguments")
|
| 303 |
+
for k, v in vars(args).items():
|
| 304 |
+
logger.info(f"{k}: {v}")
|
| 305 |
+
logger.info("\n\n Loss information")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
else:
|
| 310 |
+
model_params = list(model_without_ddp.parameters())
|
| 311 |
+
ema_params = copy.deepcopy(model_params)
|
| 312 |
+
print("Training from scratch")
|
| 313 |
+
args.resume = datetime.datetime.now().strftime("%Y.%m.%d.%H.%M")
|
| 314 |
+
args.output_dir = os.path.join(args.output_dir, args.resume)
|
| 315 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 316 |
+
|
| 317 |
+
logger = logger_file(args.log_dir+'/'+args.resume+'.log')
|
| 318 |
+
logger.info("All the arguments")
|
| 319 |
+
for k, v in vars(args).items():
|
| 320 |
+
logger.info(f"{k}: {v}")
|
| 321 |
+
logger.info("\n\n Loss information")
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
print(f"Start training for {args.epochs} epochs")
|
| 325 |
+
start_time = time.time()
|
| 326 |
+
    for epoch in tqdm(range(args.start_epoch, args.epochs), desc="Training Progress"):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
|
| 333 |
+
model.train(True)
|
| 334 |
+
metric_logger = misc.MetricLogger(delimiter=" ")
|
| 335 |
+
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
| 336 |
+
header = 'Epoch: [{}]'.format(epoch)
|
| 337 |
+
print_freq = 20
|
| 338 |
+
|
| 339 |
+
optimizer.zero_grad()
|
| 340 |
+
|
| 341 |
+
for data_iter_step, (samples, labels, _) in enumerate(data_loader_train):
|
| 342 |
+
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
|
| 343 |
+
samples = samples.to(device, non_blocking=True)
|
| 344 |
+
labels = labels.to(device, non_blocking=True)
|
| 345 |
+
|
| 346 |
+
with torch.no_grad():
|
| 347 |
+
posterior_x = vae.encode(samples)
|
| 348 |
+
posterior_y = vae.encode(labels)
|
| 349 |
+
x = posterior_x.sample().mul_(0.2325)
|
| 350 |
+
y = posterior_y.sample().mul_(0.2325)
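# Note: 0.2325 rescales the KL-VAE latents before they enter the MAR model / diffusion loss;
# validation.ipynb divides the sampled tokens by the same constant before vae.decode().
# (Interpretation inferred from its usage here; the constant is not documented elsewhere in the repo.)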
|
| 351 |
+
with torch.cuda.amp.autocast():
|
| 352 |
+
loss = model(x,y)
|
| 353 |
+
loss_value = loss.item()
|
| 354 |
+
loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
|
| 355 |
+
optimizer.zero_grad()
|
| 356 |
+
torch.cuda.synchronize()
|
| 357 |
+
|
| 358 |
+
update_ema(ema_params, model_params, rate=args.ema_rate)
|
| 359 |
+
metric_logger.update(loss=loss_value)
|
| 360 |
+
|
| 361 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 362 |
+
metric_logger.update(lr=lr)
|
| 363 |
+
|
| 364 |
+
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
| 365 |
+
metric_logger.synchronize_between_processes()
|
| 366 |
+
logger.info(f"epoch: {epoch:4d}, Averaged stats: {metric_logger}")
|
| 367 |
+
if (epoch+1)% args.save_last_freq == 0:
|
| 368 |
+
misc.save_model(args=args, model=model, model_without_ddp=model, optimizer=optimizer,
|
| 369 |
+
loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params, epoch_name=str(epoch).zfill(5))
|
| 370 |
+
|
| 371 |
+
total_time = time.time() - start_time
|
| 372 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 373 |
+
print('Training time {}'.format(total_time_str))
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if __name__ == '__main__':
|
| 377 |
+
args = get_args_parser()
|
| 378 |
+
args = args.parse_args()
|
| 379 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 380 |
+
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
|
| 381 |
+
main(args)
|
validation.ipynb
ADDED
|
@@ -0,0 +1,364 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "c524f796-e657-4a59-abcf-540531a38995",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"tags": []
|
| 9 |
+
},
|
| 10 |
+
"outputs": [],
|
| 11 |
+
"source": [
|
| 12 |
+
"%run get_parser.py"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"id": "4c1cf01e-8229-4d28-bcb2-01c07fa641c2",
|
| 19 |
+
"metadata": {
|
| 20 |
+
"tags": []
|
| 21 |
+
},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"import os\n",
|
| 25 |
+
"import requests\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"# Define URLs and file paths\n",
|
| 28 |
+
"files_to_download = {\n",
|
| 29 |
+
" \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt\":\n",
|
| 30 |
+
" \"pretrained_models/vae/modelf16.ckpt\",\n",
|
| 31 |
+
" \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth\":\n",
|
| 32 |
+
" \"pretrained_models/mar/city768.16.pth\"\n",
|
| 33 |
+
"}\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"for url, path in files_to_download.items():\n",
|
| 36 |
+
" os.makedirs(os.path.dirname(path), exist_ok=True)\n",
|
| 37 |
+
" \n",
|
| 38 |
+
" if os.path.exists(path):\n",
|
| 39 |
+
" print(f\"File already exists: {path} β skipping download.\")\n",
|
| 40 |
+
" continue\n",
|
| 41 |
+
"\n",
|
| 42 |
+
" print(f\"Downloading from {url}...\")\n",
|
| 43 |
+
" response = requests.get(url, stream=True)\n",
|
| 44 |
+
" if response.status_code == 200:\n",
|
| 45 |
+
" with open(path, 'wb') as f:\n",
|
| 46 |
+
" for chunk in response.iter_content(chunk_size=8192):\n",
|
| 47 |
+
" f.write(chunk)\n",
|
| 48 |
+
" print(f\"Saved to {path}\")\n",
|
| 49 |
+
" else:\n",
|
| 50 |
+
" print(f\"Failed to download from {url}, status code {response.status_code}\")"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": 3,
|
| 56 |
+
"id": "3a7ac93b-1cbc-45f3-8ec5-8e8257a39786",
|
| 57 |
+
"metadata": {
|
| 58 |
+
"tags": []
|
| 59 |
+
},
|
| 60 |
+
"outputs": [],
|
| 61 |
+
"source": [
|
| 62 |
+
"import numpy as np\n",
|
| 63 |
+
"from tqdm import tqdm\n",
|
| 64 |
+
"from PIL import Image\n",
|
| 65 |
+
"import yaml\n",
|
| 66 |
+
"import math\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"import torch\n",
|
| 69 |
+
"import torch.backends.cudnn as cudnn\n",
|
| 70 |
+
"import torchvision.transforms as transforms\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"from data import cityscapes\n",
|
| 73 |
+
"import util.misc as misc\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"from models.vae import AutoencoderKL\n",
|
| 76 |
+
"from models import mar"
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"execution_count": 4,
|
| 82 |
+
"id": "e2bde6fd-9b39-40fd-8d4d-d0a5f9c8217a",
|
| 83 |
+
"metadata": {
|
| 84 |
+
"tags": []
|
| 85 |
+
},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"def mask_by_order(mask_len, order, bsz, seq_len):\n",
|
| 89 |
+
" masking = torch.zeros(bsz, seq_len).cuda()\n",
|
| 90 |
+
" masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()\n",
|
| 91 |
+
" return masking\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"def fast_hist(pred, label, n):\n",
|
| 94 |
+
" k = (label >= 0) & (label < n)\n",
|
| 95 |
+
" bin_count = np.bincount(\n",
|
| 96 |
+
" n * label[k].astype(int) + pred[k], minlength=n ** 2)\n",
|
| 97 |
+
" return bin_count[:n ** 2].reshape(n, n)\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"color_pallete = np.round(np.array([\n",
|
| 100 |
+
" 0, 0, 0,\n",
|
| 101 |
+
" 128, 64, 128,\n",
|
| 102 |
+
" 244, 35, 232,\n",
|
| 103 |
+
" 70, 70, 70,\n",
|
| 104 |
+
" 102, 102, 156,\n",
|
| 105 |
+
" 190, 153, 153,\n",
|
| 106 |
+
" 153, 153, 153,\n",
|
| 107 |
+
" 250, 170, 30,\n",
|
| 108 |
+
" 220, 220, 0,\n",
|
| 109 |
+
" 107, 142, 35,\n",
|
| 110 |
+
" 152, 251, 152,\n",
|
| 111 |
+
" 0, 130, 180,\n",
|
| 112 |
+
" 220, 20, 60,\n",
|
| 113 |
+
" 255, 0, 0,\n",
|
| 114 |
+
" 0, 0, 142,\n",
|
| 115 |
+
" 0, 0, 70,\n",
|
| 116 |
+
" 0, 60, 100,\n",
|
| 117 |
+
" 0, 80, 100,\n",
|
| 118 |
+
" 0, 0, 230,\n",
|
| 119 |
+
" 119, 11, 32,\n",
|
| 120 |
+
" ])/255.0, 4)\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"color_pallete = color_pallete.reshape(-1, 3)"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "code",
|
| 127 |
+
"execution_count": 5,
|
| 128 |
+
"id": "c189ac7b-ccff-4745-af56-460ec88770b4",
|
| 129 |
+
"metadata": {
|
| 130 |
+
"tags": []
|
| 131 |
+
},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"device = torch.device(args.device)\n",
|
| 135 |
+
"device = torch.device('cuda:0')\n",
|
| 136 |
+
"args.batch_size = 1\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"# fix the seed for reproducibility\n",
|
| 139 |
+
"seed = args.seed + misc.get_rank()\n",
|
| 140 |
+
"torch.manual_seed(seed)\n",
|
| 141 |
+
"np.random.seed(seed)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"cudnn.benchmark = True\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"num_tasks = misc.get_world_size()\n",
|
| 146 |
+
"global_rank = misc.get_rank()"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": 6,
|
| 152 |
+
"id": "28d13453-a3ac-4d2e-8906-0c179e85c2f9",
|
| 153 |
+
"metadata": {
|
| 154 |
+
"tags": []
|
| 155 |
+
},
|
| 156 |
+
"outputs": [],
|
| 157 |
+
"source": [
|
| 158 |
+
"transform_train = transforms.Compose([\n",
|
| 159 |
+
" transforms.ToTensor(),\n",
|
| 160 |
+
" transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
|
| 161 |
+
"])\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
|
| 164 |
+
"# dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
|
| 165 |
+
"# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
|
| 166 |
+
"# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
|
| 167 |
+
"\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=1, rank=0, shuffle=False)\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"data_loader_train = torch.utils.data.DataLoader(\n",
|
| 172 |
+
" dataset_train, sampler=sampler_train,\n",
|
| 173 |
+
" batch_size=args.batch_size,\n",
|
| 174 |
+
" num_workers=args.num_workers,\n",
|
| 175 |
+
" pin_memory=args.pin_mem,\n",
|
| 176 |
+
" drop_last=True,\n",
|
| 177 |
+
")"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "code",
|
| 182 |
+
"execution_count": null,
|
| 183 |
+
"id": "2e22d231-02db-4586-b489-01a97314aed9",
|
| 184 |
+
"metadata": {
|
| 185 |
+
"tags": []
|
| 186 |
+
},
|
| 187 |
+
"outputs": [],
|
| 188 |
+
"source": [
|
| 189 |
+
"vae = AutoencoderKL(\n",
|
| 190 |
+
" ddconfig=args.ddconfig,\n",
|
| 191 |
+
" embed_dim=args.vae_embed_dim,\n",
|
| 192 |
+
" ckpt_path=args.vae_path\n",
|
| 193 |
+
").to(device).eval()\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"for param in vae.parameters():\n",
|
| 196 |
+
" param.requires_grad = False\n",
|
| 197 |
+
" \n",
|
| 198 |
+
"model = mar.mar_base(\n",
|
| 199 |
+
" img_size=args.img_size,\n",
|
| 200 |
+
" vae_stride=args.vae_stride,\n",
|
| 201 |
+
" patch_size=args.patch_size,\n",
|
| 202 |
+
" vae_embed_dim=args.vae_embed_dim,\n",
|
| 203 |
+
" mask_ratio_min=args.mask_ratio_min,\n",
|
| 204 |
+
" label_drop_prob=args.label_drop_prob,\n",
|
| 205 |
+
" attn_dropout=args.attn_dropout,\n",
|
| 206 |
+
" proj_dropout=args.proj_dropout,\n",
|
| 207 |
+
" buffer_size=args.buffer_size,\n",
|
| 208 |
+
" diffloss_d=args.diffloss_d,\n",
|
| 209 |
+
" diffloss_w=args.diffloss_w,\n",
|
| 210 |
+
" num_sampling_steps=args.num_sampling_steps,\n",
|
| 211 |
+
" diffusion_batch_mul=args.diffusion_batch_mul,\n",
|
| 212 |
+
" grad_checkpointing=args.grad_checkpointing,\n",
|
| 213 |
+
")\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 216 |
+
"print(\"Number of trainable parameters: {}M\".format(n_params / 1e6))\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"checkpoint = torch.load(args.ckpt_path, map_location='cpu')\n",
|
| 220 |
+
"model.load_state_dict(checkpoint['model'])\n",
|
| 221 |
+
"model.to(device)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"eff_batch_size = args.batch_size * misc.get_world_size()\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"print(\"effective batch size: %d\" % eff_batch_size)"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": 8,
|
| 231 |
+
"id": "4c83c0eb-35a5-4241-b869-d52eb6cd31e0",
|
| 232 |
+
"metadata": {
|
| 233 |
+
"tags": []
|
| 234 |
+
},
|
| 235 |
+
"outputs": [
|
| 236 |
+
{
|
| 237 |
+
"name": "stderr",
|
| 238 |
+
"output_type": "stream",
|
| 239 |
+
"text": [
|
| 240 |
+
"Training Progress: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 500/500 [13:11<00:00, 1.58s/it]"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"name": "stdout",
|
| 245 |
+
"output_type": "stream",
|
| 246 |
+
"text": [
|
| 247 |
+
"road : 98.06\n",
|
| 248 |
+
"sidewalk : 86.32\n",
|
| 249 |
+
"building : 89.23\n",
|
| 250 |
+
"wall : 47.44\n",
|
| 251 |
+
"fence : 43.78\n",
|
| 252 |
+
"pole : 60.14\n",
|
| 253 |
+
"tlight : 63.16\n",
|
| 254 |
+
"tsign : 82.48\n",
|
| 255 |
+
"vtation : 92.72\n",
|
| 256 |
+
"terrain : 80.45\n",
|
| 257 |
+
"sky : 95.99\n",
|
| 258 |
+
"person : 70.83\n",
|
| 259 |
+
"rider : 64.25\n",
|
| 260 |
+
"car : 94.06\n",
|
| 261 |
+
"truck : 44.90\n",
|
| 262 |
+
"bus : 66.81\n",
|
| 263 |
+
"train : 44.04\n",
|
| 264 |
+
"motorcycle : 47.34\n",
|
| 265 |
+
"bicycle : 62.50\n",
|
| 266 |
+
"Avg Pre : 70.24\n"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"name": "stderr",
|
| 271 |
+
"output_type": "stream",
|
| 272 |
+
"text": [
|
| 273 |
+
"\n"
|
| 274 |
+
]
|
| 275 |
+
}
|
| 276 |
+
],
|
| 277 |
+
"source": [
|
| 278 |
+
"hist = []\n",
|
| 279 |
+
"model.eval()\n",
|
| 280 |
+
"for data_iter_step, (samples, labels, path) in enumerate(tqdm(data_loader_train, desc=\"Training Progress\")):\n",
|
| 281 |
+
" samples = samples.to(device, non_blocking=True)\n",
|
| 282 |
+
" labels = labels.to(device, non_blocking=True)\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" with torch.no_grad():\n",
|
| 285 |
+
" posterior_x = vae.encode(samples)\n",
|
| 286 |
+
" posterior_y = vae.encode(labels)\n",
|
| 287 |
+
" x = posterior_x.sample().mul_(0.2325)\n",
|
| 288 |
+
" y = posterior_y.sample().mul_(0.2325)\n",
|
| 289 |
+
" x = model.patchify(x)\n",
|
| 290 |
+
" y = model.patchify(y)\n",
|
| 291 |
+
" gt_latents = y.clone().detach()\n",
|
| 292 |
+
" cfg_iter = 1.0\n",
|
| 293 |
+
" temperature = 1.0\n",
|
| 294 |
+
" mask_actual = torch.cat([torch.zeros(args.batch_size, model.seq_len), torch.ones(args.batch_size, model.seq_len)], dim=1).cuda()\n",
|
| 295 |
+
" tokens = torch.zeros(args.batch_size, model.seq_len, model.token_embed_dim).cuda()\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" with torch.no_grad():\n",
|
| 298 |
+
" x1 = model.forward_mae_encoder(x, mask_actual, tokens)\n",
|
| 299 |
+
" z = model.forward_mae_decoder(x1, mask_actual)\n",
|
| 300 |
+
" z = z[0]\n",
|
| 301 |
+
" sampled_token_latent = model.diffloss.sample(z, temperature, cfg_iter)\n",
|
| 302 |
+
"\n",
|
| 303 |
+
" tokens[0] = sampled_token_latent[model.seq_len:]\n",
|
| 304 |
+
" tokens = model.unpatchify(tokens)\n",
|
| 305 |
+
" \n",
|
| 306 |
+
" sampled_images = vae.decode(tokens / 0.2325)\n",
|
| 307 |
+
" \n",
|
| 308 |
+
" image_tensor = labels[0] \n",
|
| 309 |
+
" image_tensor = image_tensor * 0.5 + 0.5\n",
|
| 310 |
+
" gt_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
|
| 311 |
+
" H, W, _ = gt_np.shape\n",
|
| 312 |
+
" pixels = gt_np.reshape(-1, 3)\n",
|
| 313 |
+
" distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
|
| 314 |
+
" output = np.argmin(distances, axis=1)\n",
|
| 315 |
+
" gt = output.reshape(H, W)\n",
|
| 316 |
+
" \n",
|
| 317 |
+
" image_tensor = sampled_images[0]\n",
|
| 318 |
+
" image_tensor = image_tensor * 0.5 + 0.5 \n",
|
| 319 |
+
" ss_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
|
| 320 |
+
" H, W, _ = ss_np.shape\n",
|
| 321 |
+
" pixels = ss_np.reshape(-1, 3)\n",
|
| 322 |
+
" distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
|
| 323 |
+
" output = np.argmin(distances, axis=1)\n",
|
| 324 |
+
" output = output.reshape(H, W)\n",
|
| 325 |
+
" \n",
|
| 326 |
+
" hist.append(fast_hist(output.reshape(-1), gt.reshape(-1), 20))\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"cm = np.sum(hist, axis=0)\n",
|
| 329 |
+
"\n",
|
| 330 |
+
"epsilon = 1e-10\n",
|
| 331 |
+
"class_precision = np.diag(cm[1:,1:]) / (np.sum(cm[1:,1:], axis=0) + epsilon)\n",
|
| 332 |
+
"class_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'tlight', 'tsign', \n",
|
| 333 |
+
" 'vtation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \n",
|
| 334 |
+
" 'motorcycle', 'bicycle']\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"for i in range(len(class_names)):\n",
|
| 337 |
+
" print(f\"{class_names[i]:<12}: {class_precision[i]*100:6.2f}\")\n",
|
| 338 |
+
"average_precision = np.mean(class_precision)\n",
|
| 339 |
+
"print(f\"{'Avg Pre':<12}: {average_precision*100:6.2f}\")"
|
| 340 |
+
]
|
| 341 |
+
}
|
| 342 |
+
],
|
| 343 |
+
"metadata": {
|
| 344 |
+
"kernelspec": {
|
| 345 |
+
"display_name": "Python 3 (ipykernel)",
|
| 346 |
+
"language": "python",
|
| 347 |
+
"name": "python3"
|
| 348 |
+
},
|
| 349 |
+
"language_info": {
|
| 350 |
+
"codemirror_mode": {
|
| 351 |
+
"name": "ipython",
|
| 352 |
+
"version": 3
|
| 353 |
+
},
|
| 354 |
+
"file_extension": ".py",
|
| 355 |
+
"mimetype": "text/x-python",
|
| 356 |
+
"name": "python",
|
| 357 |
+
"nbconvert_exporter": "python",
|
| 358 |
+
"pygments_lexer": "ipython3",
|
| 359 |
+
"version": "3.8.10"
|
| 360 |
+
}
|
| 361 |
+
},
|
| 362 |
+
"nbformat": 4,
|
| 363 |
+
"nbformat_minor": 5
|
| 364 |
+
}
|