Add amp_flow_training_single_gpu_full_data.py
Browse files
src/amp_flow_training_single_gpu_full_data.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 14 |
+
import wandb # For logging (optional)
|
| 15 |
+
|
| 16 |
+
# Import your existing components
|
| 17 |
+
from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset
|
| 18 |
+
from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding
|
| 19 |
+
from cfg_dataset import CFGFlowDataset, create_cfg_dataloader
|
| 20 |
+
|
| 21 |
+
# ---------------- Optimized Configuration for H100 ----------------
|
| 22 |
+
ESM_DIM = 1280 # ESM-2 hidden dim (esm2_t33_650M_UR50D)
|
| 23 |
+
COMP_RATIO = 16 # compression factor
|
| 24 |
+
COMP_DIM = ESM_DIM // COMP_RATIO
|
| 25 |
+
MAX_SEQ_LEN = 50 # Actual sequence length from final_sequence_encoder.py
|
| 26 |
+
|
| 27 |
+
# OPTIMIZED H100 hyperparameters - HIGH THROUGHPUT + STABLE TRAINING
|
| 28 |
+
BATCH_SIZE = 512 # PUSH H100 TO LIMITS - using ~70GB memory
|
| 29 |
+
EPOCHS = 2000 # Slightly more epochs with safer LR for same 5-6 hour target
|
| 30 |
+
BASE_LR = 8e-4 # SAFE but effective LR - 2x original, not 4x
|
| 31 |
+
LR_MIN = 4e-4 # Conservative minimum learning rate
|
| 32 |
+
WARMUP_STEPS = 4000 # Gentler warmup to avoid explosion
|
| 33 |
+
GPU_ID = 0 # Use GPU 0
|
| 34 |
+
|
| 35 |
+
# Training optimizations
|
| 36 |
+
USE_MIXED_PRECISION = True # BF16 for H100
|
| 37 |
+
GRADIENT_CLIP_NORM = 0.5 # TIGHTER gradient clipping for flow matching stability
|
| 38 |
+
WEIGHT_DECAY = 0.01 # Weight decay for regularization
|
| 39 |
+
VALIDATION_INTERVAL = 5000 # Validate every 5K steps (more frequent)
|
| 40 |
+
CHECKPOINT_INTERVAL = 300 # Save checkpoint every 300 epochs (more frequent)
|
| 41 |
+
NUM_WORKERS = 32 # MAXIMIZED data loading workers for H100
|
| 42 |
+
|
| 43 |
+
# CFG training parameters
|
| 44 |
+
CFG_DROPOUT_RATE = 0.15 # 15% of batches as unconditional for CFG
|
| 45 |
+
|
| 46 |
+
class AMPFlowTrainerSingleGPUFullData:
|
| 47 |
+
"""
|
| 48 |
+
Optimized Single GPU training pipeline for AMP generation using ProtFlow methodology.
|
| 49 |
+
Uses ALL available data with H100-optimized settings for overnight training.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, embeddings_path, cfg_data_path, use_wandb=False):
|
| 53 |
+
self.device = torch.device(f'cuda:{GPU_ID}')
|
| 54 |
+
self.embeddings_path = embeddings_path
|
| 55 |
+
self.cfg_data_path = cfg_data_path
|
| 56 |
+
self.use_wandb = use_wandb
|
| 57 |
+
|
| 58 |
+
# Enable H100 optimizations
|
| 59 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 60 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 61 |
+
|
| 62 |
+
print(f"Using GPU {GPU_ID} for optimized H100 training")
|
| 63 |
+
print(f"Mixed precision: {USE_MIXED_PRECISION}")
|
| 64 |
+
print(f"Batch size: {BATCH_SIZE}")
|
| 65 |
+
print(f"Target epochs: {EPOCHS}")
|
| 66 |
+
print(f"Learning rate: {BASE_LR} -> {LR_MIN}")
|
| 67 |
+
|
| 68 |
+
# Initialize mixed precision training
|
| 69 |
+
if USE_MIXED_PRECISION:
|
| 70 |
+
self.scaler = GradScaler()
|
| 71 |
+
print("✓ Mixed precision training enabled (BF16)")
|
| 72 |
+
|
| 73 |
+
# Initialize wandb if requested
|
| 74 |
+
if self.use_wandb:
|
| 75 |
+
wandb.init(
|
| 76 |
+
project="amp-flow-training",
|
| 77 |
+
config={
|
| 78 |
+
"batch_size": BATCH_SIZE,
|
| 79 |
+
"epochs": EPOCHS,
|
| 80 |
+
"base_lr": BASE_LR,
|
| 81 |
+
"lr_min": LR_MIN,
|
| 82 |
+
"warmup_steps": WARMUP_STEPS,
|
| 83 |
+
"mixed_precision": USE_MIXED_PRECISION,
|
| 84 |
+
"gradient_clip": GRADIENT_CLIP_NORM,
|
| 85 |
+
"weight_decay": WEIGHT_DECAY
|
| 86 |
+
}
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
print(f"Loading ALL AMP embeddings from {embeddings_path}...")
|
| 90 |
+
|
| 91 |
+
# Load ALL embeddings (use the combined file instead of individual files)
|
| 92 |
+
self._load_all_embeddings()
|
| 93 |
+
|
| 94 |
+
# Compute normalization statistics
|
| 95 |
+
print("Computing preprocessing statistics...")
|
| 96 |
+
self._compute_preprocessing_stats()
|
| 97 |
+
|
| 98 |
+
# Initialize models
|
| 99 |
+
self._initialize_models()
|
| 100 |
+
|
| 101 |
+
# Initialize datasets and dataloaders
|
| 102 |
+
self._initialize_data()
|
| 103 |
+
|
| 104 |
+
# Initialize optimizer and scheduler
|
| 105 |
+
self._initialize_optimizer()
|
| 106 |
+
|
| 107 |
+
print("✓ Optimized Single GPU training setup complete with FULL DATA!")
|
| 108 |
+
|
| 109 |
+
def _load_all_embeddings(self):
|
| 110 |
+
"""Load ALL peptide embeddings from the combined file."""
|
| 111 |
+
# Try to load the combined embeddings file first
|
| 112 |
+
combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
|
| 113 |
+
|
| 114 |
+
if os.path.exists(combined_path):
|
| 115 |
+
print(f"Loading combined embeddings from {combined_path}...")
|
| 116 |
+
self.embeddings = torch.load(combined_path, map_location=self.device)
|
| 117 |
+
print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
|
| 118 |
+
else:
|
| 119 |
+
print("Combined embeddings file not found, loading individual files...")
|
| 120 |
+
# Fallback to individual files
|
| 121 |
+
import glob
|
| 122 |
+
|
| 123 |
+
embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
|
| 124 |
+
embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json') and not f.endswith('all_peptide_embeddings.pt')]
|
| 125 |
+
|
| 126 |
+
print(f"Found {len(embedding_files)} individual embedding files")
|
| 127 |
+
|
| 128 |
+
# Load and stack all embeddings
|
| 129 |
+
embeddings_list = []
|
| 130 |
+
for file_path in embedding_files:
|
| 131 |
+
try:
|
| 132 |
+
embedding = torch.load(file_path)
|
| 133 |
+
if embedding.dim() == 2: # (seq_len, hidden_dim)
|
| 134 |
+
embeddings_list.append(embedding)
|
| 135 |
+
else:
|
| 136 |
+
print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f"Warning: Could not load {file_path}: {e}")
|
| 139 |
+
|
| 140 |
+
if not embeddings_list:
|
| 141 |
+
raise ValueError("No valid embeddings found!")
|
| 142 |
+
|
| 143 |
+
self.embeddings = torch.stack(embeddings_list)
|
| 144 |
+
print(f"Loaded {len(self.embeddings)} embeddings from individual files")
|
| 145 |
+
|
| 146 |
+
def _compute_preprocessing_stats(self):
|
| 147 |
+
"""Compute normalization statistics for embeddings."""
|
| 148 |
+
# Flatten all embeddings
|
| 149 |
+
flat_embeddings = self.embeddings.reshape(-1, ESM_DIM)
|
| 150 |
+
|
| 151 |
+
# Compute statistics
|
| 152 |
+
mean = flat_embeddings.mean(dim=0)
|
| 153 |
+
std = flat_embeddings.std(dim=0)
|
| 154 |
+
min_val = flat_embeddings.min()
|
| 155 |
+
max_val = flat_embeddings.max()
|
| 156 |
+
|
| 157 |
+
self.stats = {
|
| 158 |
+
'mean': mean,
|
| 159 |
+
'std': std,
|
| 160 |
+
'min': min_val,
|
| 161 |
+
'max': max_val
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# Save statistics
|
| 165 |
+
torch.save(self.stats, 'normalization_stats.pt')
|
| 166 |
+
print(f"✓ Statistics computed and saved:")
|
| 167 |
+
print(f" Total embeddings: {len(self.embeddings):,}")
|
| 168 |
+
print(f" Mean: {mean.mean():.4f} ± {mean.std():.4f}")
|
| 169 |
+
print(f" Std: {std.mean():.4f} ± {std.std():.4f}")
|
| 170 |
+
print(f" Range: [{min_val:.4f}, {max_val:.4f}]")
|
| 171 |
+
|
| 172 |
+
def _initialize_models(self):
|
| 173 |
+
"""Initialize compressor, decompressor, and flow model."""
|
| 174 |
+
print("Initializing models...")
|
| 175 |
+
|
| 176 |
+
# Load pre-trained compressor and decompressor
|
| 177 |
+
self.compressor = Compressor().to(self.device)
|
| 178 |
+
self.decompressor = Decompressor().to(self.device)
|
| 179 |
+
|
| 180 |
+
self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device))
|
| 181 |
+
self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device))
|
| 182 |
+
|
| 183 |
+
# Initialize flow model with CFG
|
| 184 |
+
self.flow_model = AMPFlowMatcherCFGConcat(
|
| 185 |
+
hidden_dim=480,
|
| 186 |
+
compressed_dim=COMP_DIM,
|
| 187 |
+
n_layers=12,
|
| 188 |
+
n_heads=16,
|
| 189 |
+
dim_ff=3072,
|
| 190 |
+
max_seq_len=25, # MAX_SEQ_LEN // 2 due to pooling
|
| 191 |
+
use_cfg=True
|
| 192 |
+
).to(self.device)
|
| 193 |
+
|
| 194 |
+
# Compile model for PyTorch 2.x speedup (if available)
|
| 195 |
+
try:
|
| 196 |
+
self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead")
|
| 197 |
+
print("✓ Model compiled with torch.compile for speedup")
|
| 198 |
+
except Exception as e:
|
| 199 |
+
print(f"⚠️ Model compilation failed: {e}")
|
| 200 |
+
|
| 201 |
+
# Set models to training mode
|
| 202 |
+
self.compressor.train()
|
| 203 |
+
self.decompressor.train()
|
| 204 |
+
self.flow_model.train()
|
| 205 |
+
|
| 206 |
+
print(f"✓ Models initialized:")
|
| 207 |
+
print(f" Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}")
|
| 208 |
+
print(f" Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}")
|
| 209 |
+
print(f" Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")
|
| 210 |
+
|
| 211 |
+
def _initialize_data(self):
|
| 212 |
+
"""Initialize datasets and dataloaders with FULL data."""
|
| 213 |
+
print("Initializing datasets with FULL data...")
|
| 214 |
+
|
| 215 |
+
# Create CFG dataset with FULL UniProt data
|
| 216 |
+
self.cfg_dataset = CFGFlowDataset(
|
| 217 |
+
embeddings_path=self.embeddings_path,
|
| 218 |
+
cfg_data_path=self.cfg_data_path,
|
| 219 |
+
use_masked_labels=True,
|
| 220 |
+
max_seq_len=MAX_SEQ_LEN,
|
| 221 |
+
device=self.device
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Create dataloader with optimized settings
|
| 225 |
+
self.dataloader = create_cfg_dataloader(
|
| 226 |
+
self.cfg_dataset,
|
| 227 |
+
batch_size=BATCH_SIZE,
|
| 228 |
+
shuffle=True,
|
| 229 |
+
num_workers=NUM_WORKERS
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Calculate total steps and validation intervals
|
| 233 |
+
self.total_steps = len(self.dataloader) * EPOCHS
|
| 234 |
+
self.validation_steps = VALIDATION_INTERVAL
|
| 235 |
+
|
| 236 |
+
print(f"✓ Dataset initialized with FULL data:")
|
| 237 |
+
print(f" Total samples: {len(self.cfg_dataset):,}")
|
| 238 |
+
print(f" Batch size: {BATCH_SIZE}")
|
| 239 |
+
print(f" Batches per epoch: {len(self.dataloader):,}")
|
| 240 |
+
print(f" Total training steps: {self.total_steps:,}")
|
| 241 |
+
print(f" Validation every: {self.validation_steps:,} steps")
|
| 242 |
+
|
| 243 |
+
def _initialize_optimizer(self):
|
| 244 |
+
"""Initialize optimizer and learning rate scheduler."""
|
| 245 |
+
print("Initializing optimizer and scheduler...")
|
| 246 |
+
|
| 247 |
+
# Optimizer for flow model only (compressor/decompressor are frozen)
|
| 248 |
+
self.optimizer = optim.AdamW(
|
| 249 |
+
self.flow_model.parameters(),
|
| 250 |
+
lr=BASE_LR,
|
| 251 |
+
weight_decay=WEIGHT_DECAY,
|
| 252 |
+
betas=(0.9, 0.98), # Optimized betas for flow matching
|
| 253 |
+
eps=1e-6 # Lower epsilon for numerical stability
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Learning rate scheduler with proper warmup and cosine annealing
|
| 257 |
+
warmup_scheduler = LinearLR(
|
| 258 |
+
self.optimizer,
|
| 259 |
+
start_factor=0.1,
|
| 260 |
+
end_factor=1.0,
|
| 261 |
+
total_iters=WARMUP_STEPS
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
main_scheduler = CosineAnnealingLR(
|
| 265 |
+
self.optimizer,
|
| 266 |
+
T_max=self.total_steps - WARMUP_STEPS,
|
| 267 |
+
eta_min=LR_MIN
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.scheduler = SequentialLR(
|
| 271 |
+
self.optimizer,
|
| 272 |
+
schedulers=[warmup_scheduler, main_scheduler],
|
| 273 |
+
milestones=[WARMUP_STEPS]
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
print(f"✓ Optimizer initialized:")
|
| 277 |
+
print(f" Base LR: {BASE_LR}")
|
| 278 |
+
print(f" Min LR: {LR_MIN}")
|
| 279 |
+
print(f" Warmup steps: {WARMUP_STEPS}")
|
| 280 |
+
print(f" Weight decay: {WEIGHT_DECAY}")
|
| 281 |
+
print(f" Gradient clip norm: {GRADIENT_CLIP_NORM}")
|
| 282 |
+
|
| 283 |
+
def _preprocess_batch(self, batch):
|
| 284 |
+
"""Preprocess a batch of data for training."""
|
| 285 |
+
# Extract data
|
| 286 |
+
embeddings = batch['embeddings'].to(self.device) # (B, L, ESM_DIM)
|
| 287 |
+
labels = batch['labels'].to(self.device) # (B,)
|
| 288 |
+
|
| 289 |
+
# Normalize embeddings
|
| 290 |
+
m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device)
|
| 291 |
+
mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device)
|
| 292 |
+
|
| 293 |
+
embeddings = (embeddings - m) / (s + 1e-8)
|
| 294 |
+
embeddings = (embeddings - mn) / (mx - mn + 1e-8)
|
| 295 |
+
|
| 296 |
+
# Compress embeddings
|
| 297 |
+
with torch.no_grad():
|
| 298 |
+
compressed = self.compressor(embeddings) # (B, L, COMP_DIM)
|
| 299 |
+
|
| 300 |
+
return compressed, labels
|
| 301 |
+
|
| 302 |
+
def _compute_validation_metrics(self):
|
| 303 |
+
"""Compute validation metrics on a subset of data."""
|
| 304 |
+
self.flow_model.eval()
|
| 305 |
+
val_losses = []
|
| 306 |
+
|
| 307 |
+
# Use a subset of data for validation
|
| 308 |
+
val_samples = min(1000, len(self.cfg_dataset))
|
| 309 |
+
val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples]
|
| 310 |
+
|
| 311 |
+
with torch.no_grad():
|
| 312 |
+
for i in range(0, val_samples, BATCH_SIZE):
|
| 313 |
+
batch_indices = val_indices[i:i+BATCH_SIZE]
|
| 314 |
+
batch_data = [self.cfg_dataset[idx] for idx in batch_indices]
|
| 315 |
+
|
| 316 |
+
# Collate batch
|
| 317 |
+
embeddings = torch.stack([item['embedding'] for item in batch_data])
|
| 318 |
+
labels = torch.stack([item['label'] for item in batch_data])
|
| 319 |
+
|
| 320 |
+
# Preprocess
|
| 321 |
+
compressed, labels = self._preprocess_batch({
|
| 322 |
+
'embeddings': embeddings,
|
| 323 |
+
'labels': labels
|
| 324 |
+
})
|
| 325 |
+
|
| 326 |
+
B, L, D = compressed.shape
|
| 327 |
+
|
| 328 |
+
# Sample random time
|
| 329 |
+
t = torch.rand(B, device=self.device)
|
| 330 |
+
|
| 331 |
+
# Sample random noise
|
| 332 |
+
eps = torch.randn_like(compressed)
|
| 333 |
+
|
| 334 |
+
# Compute target
|
| 335 |
+
xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps
|
| 336 |
+
|
| 337 |
+
# Predict vector field
|
| 338 |
+
vt_pred = self.flow_model(xt, t, labels=labels)
|
| 339 |
+
|
| 340 |
+
# Target vector field
|
| 341 |
+
vt_target = eps - compressed
|
| 342 |
+
|
| 343 |
+
# Compute loss
|
| 344 |
+
loss = F.mse_loss(vt_pred, vt_target)
|
| 345 |
+
val_losses.append(loss.item())
|
| 346 |
+
|
| 347 |
+
self.flow_model.train()
|
| 348 |
+
return np.mean(val_losses)
|
| 349 |
+
|
| 350 |
+
def train_flow_matching(self):
|
| 351 |
+
"""Train the flow matching model with FULL data and optimizations."""
|
| 352 |
+
print(f"🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA")
|
| 353 |
+
print(f"GPU: {GPU_ID}")
|
| 354 |
+
print(f"Total iterations: {EPOCHS}")
|
| 355 |
+
print(f"Batch size: {BATCH_SIZE}")
|
| 356 |
+
print(f"Total samples: {len(self.cfg_dataset):,}")
|
| 357 |
+
print(f"Mixed precision: {USE_MIXED_PRECISION}")
|
| 358 |
+
print(f"Estimated time: ~8-10 hours (overnight training with ALL data)")
|
| 359 |
+
print("=" * 60)
|
| 360 |
+
|
| 361 |
+
# Training loop
|
| 362 |
+
best_loss = float('inf')
|
| 363 |
+
losses = []
|
| 364 |
+
val_losses = []
|
| 365 |
+
global_step = 0
|
| 366 |
+
start_time = time.time()
|
| 367 |
+
|
| 368 |
+
for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"):
|
| 369 |
+
epoch_losses = []
|
| 370 |
+
epoch_start_time = time.time()
|
| 371 |
+
|
| 372 |
+
for batch_idx, batch in enumerate(self.dataloader):
|
| 373 |
+
# Preprocess batch
|
| 374 |
+
compressed, labels = self._preprocess_batch(batch)
|
| 375 |
+
B, L, D = compressed.shape
|
| 376 |
+
|
| 377 |
+
# CFG training: randomly mask some labels for unconditional training
|
| 378 |
+
if torch.rand(1).item() < CFG_DROPOUT_RATE:
|
| 379 |
+
labels = torch.full_like(labels, fill_value=-1) # Unconditional
|
| 380 |
+
|
| 381 |
+
# Sample random time
|
| 382 |
+
t = torch.rand(B, device=self.device) # (B,)
|
| 383 |
+
|
| 384 |
+
# Sample random noise
|
| 385 |
+
eps = torch.randn_like(compressed) # (B, L, D)
|
| 386 |
+
|
| 387 |
+
# Compute target: x_t = (1-t) * x_0 + t * eps
|
| 388 |
+
xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps
|
| 389 |
+
|
| 390 |
+
# Forward pass with mixed precision
|
| 391 |
+
if USE_MIXED_PRECISION:
|
| 392 |
+
with autocast(dtype=torch.bfloat16):
|
| 393 |
+
vt_pred = self.flow_model(xt, t, labels=labels) # (B, L, D)
|
| 394 |
+
vt_target = eps - compressed # (B, L, D)
|
| 395 |
+
loss = F.mse_loss(vt_pred, vt_target)
|
| 396 |
+
|
| 397 |
+
# Backward pass with gradient scaling
|
| 398 |
+
self.optimizer.zero_grad()
|
| 399 |
+
self.scaler.scale(loss).backward()
|
| 400 |
+
|
| 401 |
+
# Gradient clipping
|
| 402 |
+
self.scaler.unscale_(self.optimizer)
|
| 403 |
+
torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
|
| 404 |
+
|
| 405 |
+
self.scaler.step(self.optimizer)
|
| 406 |
+
self.scaler.update()
|
| 407 |
+
else:
|
| 408 |
+
# Standard training
|
| 409 |
+
vt_pred = self.flow_model(xt, t, labels=labels) # (B, L, D)
|
| 410 |
+
vt_target = eps - compressed # (B, L, D)
|
| 411 |
+
loss = F.mse_loss(vt_pred, vt_target)
|
| 412 |
+
|
| 413 |
+
# Backward pass
|
| 414 |
+
self.optimizer.zero_grad()
|
| 415 |
+
loss.backward()
|
| 416 |
+
|
| 417 |
+
# Gradient clipping
|
| 418 |
+
torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
|
| 419 |
+
|
| 420 |
+
self.optimizer.step()
|
| 421 |
+
|
| 422 |
+
# Update learning rate
|
| 423 |
+
self.scheduler.step()
|
| 424 |
+
|
| 425 |
+
epoch_losses.append(loss.item())
|
| 426 |
+
global_step += 1
|
| 427 |
+
|
| 428 |
+
# Logging
|
| 429 |
+
if batch_idx % 100 == 0:
|
| 430 |
+
current_lr = self.scheduler.get_last_lr()[0]
|
| 431 |
+
elapsed_time = time.time() - start_time
|
| 432 |
+
steps_per_sec = global_step / elapsed_time
|
| 433 |
+
eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600
|
| 434 |
+
|
| 435 |
+
print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | "
|
| 436 |
+
f"Loss: {loss.item():.6f} | LR: {current_lr:.2e} | "
|
| 437 |
+
f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h")
|
| 438 |
+
|
| 439 |
+
# Log to wandb
|
| 440 |
+
if self.use_wandb:
|
| 441 |
+
wandb.log({
|
| 442 |
+
'train/loss': loss.item(),
|
| 443 |
+
'train/learning_rate': current_lr,
|
| 444 |
+
'train/steps_per_sec': steps_per_sec,
|
| 445 |
+
'train/global_step': global_step
|
| 446 |
+
})
|
| 447 |
+
|
| 448 |
+
# Validation
|
| 449 |
+
if global_step % self.validation_steps == 0:
|
| 450 |
+
val_loss = self._compute_validation_metrics()
|
| 451 |
+
val_losses.append(val_loss)
|
| 452 |
+
|
| 453 |
+
print(f"Validation at step {global_step}: Loss = {val_loss:.6f}")
|
| 454 |
+
|
| 455 |
+
if self.use_wandb:
|
| 456 |
+
wandb.log({
|
| 457 |
+
'val/loss': val_loss,
|
| 458 |
+
'val/global_step': global_step
|
| 459 |
+
})
|
| 460 |
+
|
| 461 |
+
# Early stopping check
|
| 462 |
+
if val_loss < best_loss:
|
| 463 |
+
best_loss = val_loss
|
| 464 |
+
self._save_checkpoint(epoch, val_loss, global_step, is_final=False, is_best=True)
|
| 465 |
+
|
| 466 |
+
# Compute epoch statistics
|
| 467 |
+
avg_loss = np.mean(epoch_losses)
|
| 468 |
+
losses.append(avg_loss)
|
| 469 |
+
epoch_time = time.time() - epoch_start_time
|
| 470 |
+
|
| 471 |
+
print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | "
|
| 472 |
+
f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
|
| 473 |
+
f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}")
|
| 474 |
+
|
| 475 |
+
# Save checkpoint
|
| 476 |
+
if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
|
| 477 |
+
self._save_checkpoint(epoch, avg_loss, global_step, is_final=True)
|
| 478 |
+
|
| 479 |
+
# Save final model
|
| 480 |
+
self._save_checkpoint(EPOCHS - 1, losses[-1], global_step, is_final=True)
|
| 481 |
+
|
| 482 |
+
total_time = time.time() - start_time
|
| 483 |
+
print("=" * 60)
|
| 484 |
+
print("🎉 Optimized Training Complete with FULL DATA!")
|
| 485 |
+
print(f"Best validation loss: {best_loss:.6f}")
|
| 486 |
+
print(f"Total training time: {total_time/3600:.1f} hours")
|
| 487 |
+
print(f"Total samples used: {len(self.cfg_dataset):,}")
|
| 488 |
+
print(f"Final model saved as: amp_flow_model_final_optimized.pth")
|
| 489 |
+
|
| 490 |
+
return losses, val_losses
|
| 491 |
+
|
| 492 |
+
def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False):
|
| 493 |
+
"""Save model checkpoint."""
|
| 494 |
+
# Create output directory if it doesn't exist
|
| 495 |
+
output_dir = '/data2/edwardsun/flow_checkpoints'
|
| 496 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 497 |
+
|
| 498 |
+
if is_best:
|
| 499 |
+
filename = os.path.join(output_dir, 'amp_flow_model_best_optimized.pth')
|
| 500 |
+
elif is_final:
|
| 501 |
+
filename = os.path.join(output_dir, 'amp_flow_model_final_optimized.pth')
|
| 502 |
+
else:
|
| 503 |
+
filename = os.path.join(output_dir, f'amp_flow_checkpoint_optimized_step_{step:04d}.pth')
|
| 504 |
+
|
| 505 |
+
checkpoint = {
|
| 506 |
+
'step': step,
|
| 507 |
+
'global_step': global_step,
|
| 508 |
+
'loss': loss,
|
| 509 |
+
'flow_model_state_dict': self.flow_model.state_dict(),
|
| 510 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 511 |
+
'scheduler_state_dict': self.scheduler.state_dict(),
|
| 512 |
+
'stats': self.stats,
|
| 513 |
+
'total_samples': len(self.cfg_dataset),
|
| 514 |
+
'config': {
|
| 515 |
+
'batch_size': BATCH_SIZE,
|
| 516 |
+
'epochs': EPOCHS,
|
| 517 |
+
'base_lr': BASE_LR,
|
| 518 |
+
'lr_min': LR_MIN,
|
| 519 |
+
'warmup_steps': WARMUP_STEPS,
|
| 520 |
+
'mixed_precision': USE_MIXED_PRECISION,
|
| 521 |
+
'gradient_clip': GRADIENT_CLIP_NORM,
|
| 522 |
+
'weight_decay': WEIGHT_DECAY
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
torch.save(checkpoint, filename)
|
| 527 |
+
print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})")
|
| 528 |
+
|
| 529 |
+
def main():
|
| 530 |
+
"""Main training function."""
|
| 531 |
+
global BATCH_SIZE, EPOCHS
|
| 532 |
+
|
| 533 |
+
parser = argparse.ArgumentParser(description='Optimized Single GPU AMP Flow Training with FULL DATA')
|
| 534 |
+
parser.add_argument('--embeddings', default='/data2/edwardsun/flow_project/peptide_embeddings/',
|
| 535 |
+
help='Path to peptide embeddings directory')
|
| 536 |
+
parser.add_argument('--cfg_data', default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
|
| 537 |
+
help='Path to FULL CFG data file')
|
| 538 |
+
parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
|
| 539 |
+
parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training')
|
| 540 |
+
parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of training epochs')
|
| 541 |
+
|
| 542 |
+
args = parser.parse_args()
|
| 543 |
+
|
| 544 |
+
# Update global variables if provided
|
| 545 |
+
if args.batch_size != BATCH_SIZE:
|
| 546 |
+
BATCH_SIZE = args.batch_size
|
| 547 |
+
if args.epochs != EPOCHS:
|
| 548 |
+
EPOCHS = args.epochs
|
| 549 |
+
|
| 550 |
+
print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}")
|
| 551 |
+
|
| 552 |
+
# Initialize trainer
|
| 553 |
+
trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb)
|
| 554 |
+
|
| 555 |
+
# Start training
|
| 556 |
+
losses, val_losses = trainer.train_flow_matching()
|
| 557 |
+
|
| 558 |
+
print("Optimized training completed successfully with FULL DATA!")
|
| 559 |
+
|
| 560 |
+
if __name__ == "__main__":
|
| 561 |
+
main()
|