---
language: en
tags:
- pytorch
- gan
- conditional-generation
- mnist
- image-generation
- projection-discriminator
- generative-adversarial-network
---

# MNIST conditional GAN (cGAN)

**Class-conditional** synthesis of **28×28 grayscale** MNIST-style digits. The generator maps noise **z** and digit label **y** to an image; the discriminator uses a **projection discriminator** (Miyato & Koyama, ICLR 2018) with spectral normalization.

## Files in this model repo

| File | Description |
|------|-------------|
| `mnist_cgan_generator.pth` | Generator ``state_dict`` for inference (matches ``submission_digit_cgan.py``). |
| `training_.pt` | Original checkpoint file (full training state when applicable). |
| `cgan_architecture.py` | Copy of ``digit_cgan/model.py`` (**Generator** + **Discriminator** definitions). |
| `generator_config.json` | Inferred constructor kwargs and metadata. |

## Weights

**Checkpoint:** generator-only export (the training epoch is not recorded in the file).

**Inferred architecture (from tensors):**

- ``latent_dim=100``, ``embed_dim=100``, ``base_channels=384``, ``num_classes=10``
- Output shape: ``(B, 1, 28, 28)``, values in ``[-1, 1]`` (``tanh``).

Source file: ``mnist_cgan_generator.pth``.

### Load the generator (example)

```python
import sys

import torch
from huggingface_hub import hf_hub_download

# Make the package that defines Generator importable; adjust to your checkout.
sys.path.insert(0, "/path/to/week-06")
from digit_cgan.model import Generator

repo_id = ""  # set to the id of this model repo
weights = hf_hub_download(repo_id, "mnist_cgan_generator.pth")

# Constructor kwargs must match the values inferred from the checkpoint tensors.
G = Generator(
    latent_dim=100,
    embed_dim=100,
    base_channels=384,
    num_classes=10,
)
G.load_state_dict(torch.load(weights, map_location="cpu", weights_only=True))
G.eval()

with torch.no_grad():
    z = torch.randn(4, 100)          # one latent vector per sample
    y = torch.tensor([0, 1, 2, 3])   # digit labels to generate
    fake = G(z, y)                   # shape (4, 1, 28, 28), values in [-1, 1]
```

## Architecture (`cgan_architecture.py`)

- **Generator:** class embedding concatenated with **z**, a linear layer whose output is reshaped to **7×7** features, two **ConvTranspose2d** stages to **28×28**, then a conv to 1 channel + **tanh**.
- **Discriminator:** convolutional backbone with **spectral norm**, global pool, linear map to a feature vector; score is **unconditional linear term** plus **inner product** between features and a **class embedding** (projection term). See T. Miyato & M. Koyama, *cGANs with Projection Discriminator*, ICLR 2018.

## Training (typical)

``python -m digit_cgan.train``: hinge loss, Adam, optional **EMA** on the generator for sampling; **best FID** checkpoints use the EMA weights in ``best_generator.pth``. CLI defaults in ``train.py`` include ``latent_dim=100``, ``embed_dim=100``; ``base_channels_g`` / ``base_channels_d`` / ``feature_dim`` may differ per run, so always use ``generator_config.json`` or infer from weights as above.

## Limitations

MNIST is a simple benchmark; generalization to out-of-distribution digit styles is not guaranteed.

## References

- Takeru Miyato, Masanori Koyama, *cGANs with Projection Discriminator*, https://arxiv.org/abs/1802.05637
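
## Appendix: hinge loss (illustrative sketch)

The Training section names a hinge loss without spelling it out. The sketch below assumes the standard GAN hinge formulation and is not taken from ``train.py``. It uses plain Python lists of discriminator scores so that it runs without PyTorch; the function names ``d_hinge_loss`` and ``g_hinge_loss`` are hypothetical.

```python
# Assumed formulation (standard GAN hinge loss), not copied from train.py:
#   L_D = mean(max(0, 1 - D(x, y))) + mean(max(0, 1 + D(G(z, y), y)))
#   L_G = -mean(D(G(z, y), y))


def d_hinge_loss(real_scores, fake_scores):
    """Discriminator hinge loss over lists of scalar scores."""
    real_term = sum(max(0.0, 1.0 - s) for s in real_scores) / len(real_scores)
    fake_term = sum(max(0.0, 1.0 + s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def g_hinge_loss(fake_scores):
    """Generator hinge loss: push discriminator scores on fakes upward."""
    return -sum(fake_scores) / len(fake_scores)


# Toy scores: a real sample scored 2.0 and a fake scored -2.0 are already
# outside the margin and contribute nothing to the discriminator loss.
real_scores = [2.0, 0.0]
fake_scores = [-2.0, 0.0]

d_loss = d_hinge_loss(real_scores, fake_scores)
g_loss = g_hinge_loss(fake_scores)
print(d_loss, g_loss)  # 1.0 1.0
```

With tensors, the same expressions would typically be written with ``torch.relu`` and ``.mean()``; only samples inside the margin contribute gradient to the discriminator.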