esunAI committed on
Commit
321da93
·
verified ·
1 Parent(s): 37158e8

Add amp_flow_training_single_gpu_full_data.py

Browse files
src/amp_flow_training_single_gpu_full_data.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import glob
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb  # For logging (optional)
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Import your existing components
from cfg_dataset import CFGFlowDataset, create_cfg_dataloader
from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset
from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding

# ---------------- Optimized Configuration for H100 ----------------
# Embedding / compression geometry.
ESM_DIM = 1280                    # ESM-2 hidden dim (esm2_t33_650M_UR50D)
COMP_RATIO = 16                   # compression factor
COMP_DIM = ESM_DIM // COMP_RATIO  # width of the compressed representation (1280 // 16 == 80)
MAX_SEQ_LEN = 50                  # Actual sequence length from final_sequence_encoder.py

# Optimization hyperparameters (author's notes kept; memory/time figures are
# the author's estimates and are not verified by this file).
BATCH_SIZE = 512      # large batch chosen for a single H100 (author estimate: ~70GB memory)
EPOCHS = 2000         # number of passes over the dataloader
BASE_LR = 8e-4        # peak learning rate reached at the end of warmup
LR_MIN = 4e-4         # floor of the cosine schedule (eta_min)
WARMUP_STEPS = 4000   # optimizer steps of linear warmup before cosine annealing
GPU_ID = 0            # index of the CUDA device to train on

# Training optimizations
USE_MIXED_PRECISION = True    # run the forward pass under bfloat16 autocast
GRADIENT_CLIP_NORM = 0.5      # max gradient norm passed to clip_grad_norm_
WEIGHT_DECAY = 0.01           # AdamW weight decay
VALIDATION_INTERVAL = 5000    # run validation every N optimizer steps
CHECKPOINT_INTERVAL = 300     # save a checkpoint every N epochs
NUM_WORKERS = 32              # dataloader worker processes

# CFG training parameters
CFG_DROPOUT_RATE = 0.15       # probability that a whole batch is trained unconditionally


class AMPFlowTrainerSingleGPUFullData:
    """
    Single-GPU training pipeline for the CFG flow-matching model.

    Construction loads the peptide embeddings, computes and saves their
    normalization statistics, loads the pre-trained compressor/decompressor,
    builds the flow model, the CFG dataset/dataloader and the AdamW optimizer
    with a warmup + cosine learning-rate schedule. Call
    ``train_flow_matching()`` to run the training loop.

    All hyperparameters are read from the module-level constants
    (``BATCH_SIZE``, ``EPOCHS``, ``BASE_LR``, ...).
    """

    def __init__(self, embeddings_path, cfg_data_path, use_wandb=False):
        """
        Build every component needed for training.

        Args:
            embeddings_path: Directory holding the peptide embedding ``.pt`` files.
            cfg_data_path: Path handed to ``CFGFlowDataset`` as ``cfg_data_path``.
            use_wandb: When True, initialise a wandb run and log metrics to it.
        """
        self.device = torch.device(f'cuda:{GPU_ID}')
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_wandb = use_wandb

        # Allow TF32 matmuls / cuDNN kernels.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        print(f"Using GPU {GPU_ID} for optimized H100 training")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Target epochs: {EPOCHS}")
        print(f"Learning rate: {BASE_LR} -> {LR_MIN}")

        # The scaler only exists when mixed precision is on; the training loop
        # touches it only in that branch.
        # NOTE(review): loss scaling is primarily needed for float16; with
        # bfloat16 autocast it is likely unnecessary - kept to preserve the
        # original training behaviour.
        if USE_MIXED_PRECISION:
            self.scaler = GradScaler()
            print("✓ Mixed precision training enabled (BF16)")

        if self.use_wandb:
            wandb.init(
                project="amp-flow-training",
                config={
                    "batch_size": BATCH_SIZE,
                    "epochs": EPOCHS,
                    "base_lr": BASE_LR,
                    "lr_min": LR_MIN,
                    "warmup_steps": WARMUP_STEPS,
                    "mixed_precision": USE_MIXED_PRECISION,
                    "gradient_clip": GRADIENT_CLIP_NORM,
                    "weight_decay": WEIGHT_DECAY,
                },
            )

        print(f"Loading ALL AMP embeddings from {embeddings_path}...")
        self._load_all_embeddings()

        print("Computing preprocessing statistics...")
        self._compute_preprocessing_stats()

        # Order matters: the optimizer needs the flow model and the total step
        # count, which is derived from the dataloader.
        self._initialize_models()
        self._initialize_data()
        self._initialize_optimizer()

        print("✓ Optimized Single GPU training setup complete with FULL DATA!")

    def _load_all_embeddings(self):
        """
        Populate ``self.embeddings`` from disk.

        Prefers ``all_peptide_embeddings.pt`` inside ``self.embeddings_path``;
        if it does not exist, every other ``*.pt`` file in that directory is
        loaded and the 2-D ones are stacked along a new leading dimension.

        Raises:
            ValueError: If the fallback path finds no usable embedding file.
        """
        combined_name = "all_peptide_embeddings.pt"
        combined_path = os.path.join(self.embeddings_path, combined_name)

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            self.embeddings = torch.load(combined_path, map_location=self.device)
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")

        # Sorted so the stacking order does not depend on directory listing order.
        embedding_files = sorted(
            path
            for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith(combined_name)
        )
        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            # Best effort: an unreadable or malformed file is reported and skipped.
            try:
                # Same device placement as the combined-file path above.
                embedding = torch.load(file_path, map_location=self.device)
                if embedding.dim() == 2:  # (seq_len, hidden_dim)
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # torch.stack requires every kept embedding to share one shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _compute_preprocessing_stats(self):
        """
        Compute normalization statistics over ``self.embeddings``.

        Stores ``mean``/``std`` (per feature, over all positions) and the
        global scalar ``min``/``max`` in ``self.stats`` and writes the same
        dict to ``normalization_stats.pt`` in the current working directory.
        """
        # Collapse every leading dimension so each row is one ESM_DIM vector.
        flat_embeddings = self.embeddings.reshape(-1, ESM_DIM)

        mean = flat_embeddings.mean(dim=0)
        std = flat_embeddings.std(dim=0)
        min_val = flat_embeddings.min()
        max_val = flat_embeddings.max()

        self.stats = {
            'mean': mean,
            'std': std,
            'min': min_val,
            'max': max_val,
        }

        torch.save(self.stats, 'normalization_stats.pt')
        print("✓ Statistics computed and saved:")
        print(f" Total embeddings: {len(self.embeddings):,}")
        print(f" Mean: {mean.mean():.4f} ± {mean.std():.4f}")
        print(f" Std: {std.mean():.4f} ± {std.std():.4f}")
        print(f" Range: [{min_val:.4f}, {max_val:.4f}]")

    def _initialize_models(self):
        """
        Create the compressor, decompressor and flow model on ``self.device``.

        Compressor/decompressor weights are read from
        ``final_compressor_model.pth`` / ``final_decompressor_model.pth`` in
        the working directory. Both are frozen (eval mode, no gradients)
        because only the flow model is handed to the optimizer.
        """
        print("Initializing models...")

        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(
            torch.load('final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(
            torch.load('final_decompressor_model.pth', map_location=self.device))

        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=COMP_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=25,  # MAX_SEQ_LEN // 2 due to pooling
            use_cfg=True,
        ).to(self.device)

        # NOTE(review): torch.compile compiles lazily, so failures may only
        # surface on the first forward pass rather than inside this try block.
        try:
            self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead")
            print("✓ Model compiled with torch.compile for speedup")
        except Exception as e:
            print(f"⚠️ Model compilation failed: {e}")

        # The pre-trained autoencoder halves are frozen: eval mode keeps any
        # train-only layers (e.g. dropout) inactive, and disabling gradients
        # makes the "flow model only" optimizer setup explicit.
        for frozen in (self.compressor, self.decompressor):
            frozen.eval()
            frozen.requires_grad_(False)
        self.flow_model.train()

        print("✓ Models initialized:")
        print(f" Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}")
        print(f" Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}")
        print(f" Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")

    def _initialize_data(self):
        """
        Build the CFG dataset and dataloader and derive the step counts.

        Sets ``self.cfg_dataset``, ``self.dataloader``, ``self.total_steps``
        (batches per epoch times ``EPOCHS``) and ``self.validation_steps``.
        """
        print("Initializing datasets with FULL data...")

        self.cfg_dataset = CFGFlowDataset(
            embeddings_path=self.embeddings_path,
            cfg_data_path=self.cfg_data_path,
            use_masked_labels=True,
            max_seq_len=MAX_SEQ_LEN,
            device=self.device,
        )

        self.dataloader = create_cfg_dataloader(
            self.cfg_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
        )

        self.total_steps = len(self.dataloader) * EPOCHS
        self.validation_steps = VALIDATION_INTERVAL

        print("✓ Dataset initialized with FULL data:")
        print(f" Total samples: {len(self.cfg_dataset):,}")
        print(f" Batch size: {BATCH_SIZE}")
        print(f" Batches per epoch: {len(self.dataloader):,}")
        print(f" Total training steps: {self.total_steps:,}")
        print(f" Validation every: {self.validation_steps:,} steps")

    def _initialize_optimizer(self):
        """
        Create the AdamW optimizer and the warmup + cosine LR schedule.

        The schedule is stepped once per optimizer step: ``WARMUP_STEPS`` of
        linear warmup from 10% of ``BASE_LR``, then cosine annealing down to
        ``LR_MIN`` over the remaining steps.
        """
        print("Initializing optimizer and scheduler...")

        # Only the flow model is optimized (compressor/decompressor are frozen).
        self.optimizer = optim.AdamW(
            self.flow_model.parameters(),
            lr=BASE_LR,
            weight_decay=WEIGHT_DECAY,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=WARMUP_STEPS,
        )

        # Clamp so very short runs (total_steps <= WARMUP_STEPS) never give the
        # cosine schedule a zero or negative period.
        cosine_steps = max(1, self.total_steps - WARMUP_STEPS)
        main_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=cosine_steps,
            eta_min=LR_MIN,
        )

        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[WARMUP_STEPS],
        )

        print("✓ Optimizer initialized:")
        print(f" Base LR: {BASE_LR}")
        print(f" Min LR: {LR_MIN}")
        print(f" Warmup steps: {WARMUP_STEPS}")
        print(f" Weight decay: {WEIGHT_DECAY}")
        print(f" Gradient clip norm: {GRADIENT_CLIP_NORM}")

    def _preprocess_batch(self, batch):
        """
        Normalize and compress one batch.

        Args:
            batch: Mapping with ``'embeddings'`` and ``'labels'`` tensors.

        Returns:
            ``(compressed, labels)``: the compressor output for the normalized
            embeddings (computed without gradients) and the labels, both on
            ``self.device``.
        """
        embeddings = batch['embeddings'].to(self.device)  # (B, L, ESM_DIM)
        labels = batch['labels'].to(self.device)          # (B,)

        mean = self.stats['mean'].to(self.device)
        std = self.stats['std'].to(self.device)
        min_val = self.stats['min'].to(self.device)
        max_val = self.stats['max'].to(self.device)

        # Two-stage normalization: per-feature z-score, then a shift/scale by
        # the global min/max.
        # NOTE(review): min/max were measured on the raw embeddings, not on the
        # z-scored values; presumably this mirrors the preprocessing used when
        # the compressor was trained - confirm before changing.
        embeddings = (embeddings - mean) / (std + 1e-8)
        embeddings = (embeddings - min_val) / (max_val - min_val + 1e-8)

        with torch.no_grad():
            compressed = self.compressor(embeddings)  # (B, L, COMP_DIM)

        return compressed, labels

    def _compute_validation_metrics(self):
        """
        Estimate the flow-matching loss on a random subset of the dataset.

        Up to 1000 samples are drawn from ``self.cfg_dataset`` (the same
        dataset used for training), processed in chunks of ``BATCH_SIZE`` with
        the model in eval mode, and the mean of the per-chunk MSE losses is
        returned. The flow model is put back in train mode afterwards.
        """
        self.flow_model.eval()
        chunk_losses = []

        val_samples = min(1000, len(self.cfg_dataset))
        # Plain Python ints so dataset indexing does not depend on 0-d tensor support.
        val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples].tolist()

        try:
            with torch.no_grad():
                for start in range(0, val_samples, BATCH_SIZE):
                    items = [self.cfg_dataset[idx]
                             for idx in val_indices[start:start + BATCH_SIZE]]

                    # Manual collate; per-item keys are singular.
                    compressed, labels = self._preprocess_batch({
                        'embeddings': torch.stack([item['embedding'] for item in items]),
                        'labels': torch.stack([item['label'] for item in items]),
                    })

                    t = torch.rand(compressed.shape[0], device=self.device)
                    eps = torch.randn_like(compressed)

                    # Linear interpolation between data (t=0) and noise (t=1).
                    t_b = t.unsqueeze(-1).unsqueeze(-1)
                    xt = (1 - t_b) * compressed + t_b * eps

                    vt_pred = self.flow_model(xt, t, labels=labels)
                    vt_target = eps - compressed

                    chunk_losses.append(F.mse_loss(vt_pred, vt_target).item())
        finally:
            # Always hand the model back in training mode.
            self.flow_model.train()

        return np.mean(chunk_losses)

    def train_flow_matching(self):
        """
        Run the full training loop.

        For every batch: compress the embeddings, with probability
        ``CFG_DROPOUT_RATE`` replace all labels by -1 (unconditional batch),
        sample a time and a noise tensor, regress the model output onto
        ``eps - x_0`` with an MSE loss, clip gradients, and step the optimizer
        and the LR scheduler. Validation runs every ``self.validation_steps``
        optimizer steps and the best model (lowest validation loss) is saved;
        a step-named checkpoint is written every ``CHECKPOINT_INTERVAL``
        epochs and the final model at the end.

        Returns:
            ``(losses, val_losses)``: per-epoch mean training losses and the
            validation losses in the order they were measured.
        """
        print("🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA")
        print(f"GPU: {GPU_ID}")
        print(f"Total iterations: {EPOCHS}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Total samples: {len(self.cfg_dataset):,}")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print("Estimated time: ~8-10 hours (overnight training with ALL data)")
        print("=" * 60)

        best_loss = float('inf')
        losses = []
        val_losses = []
        global_step = 0
        start_time = time.time()

        for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"):
            epoch_losses = []
            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(self.dataloader):
                compressed, labels = self._preprocess_batch(batch)

                # CFG training: the whole batch becomes unconditional at once.
                if torch.rand(1).item() < CFG_DROPOUT_RATE:
                    labels = torch.full_like(labels, fill_value=-1)

                t = torch.rand(compressed.shape[0], device=self.device)  # (B,)
                eps = torch.randn_like(compressed)                        # (B, L, D)

                # x_t = (1 - t) * x_0 + t * eps
                t_b = t.unsqueeze(-1).unsqueeze(-1)
                xt = (1 - t_b) * compressed + t_b * eps
                vt_target = eps - compressed                              # (B, L, D)

                self.optimizer.zero_grad()

                if USE_MIXED_PRECISION:
                    with autocast(dtype=torch.bfloat16):
                        vt_pred = self.flow_model(xt, t, labels=labels)  # (B, L, D)
                        loss = F.mse_loss(vt_pred, vt_target)

                    self.scaler.scale(loss).backward()

                    # Gradients must be unscaled before clipping by norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    vt_pred = self.flow_model(xt, t, labels=labels)      # (B, L, D)
                    loss = F.mse_loss(vt_pred, vt_target)

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
                    self.optimizer.step()

                # The schedule is defined per optimizer step, not per epoch.
                self.scheduler.step()

                loss_value = loss.item()
                epoch_losses.append(loss_value)
                global_step += 1

                if batch_idx % 100 == 0:
                    current_lr = self.scheduler.get_last_lr()[0]
                    # Guard against a zero reading from a coarse clock.
                    elapsed_time = max(time.time() - start_time, 1e-9)
                    steps_per_sec = global_step / elapsed_time
                    eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600

                    print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | "
                          f"Loss: {loss_value:.6f} | LR: {current_lr:.2e} | "
                          f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h")

                    if self.use_wandb:
                        wandb.log({
                            'train/loss': loss_value,
                            'train/learning_rate': current_lr,
                            'train/steps_per_sec': steps_per_sec,
                            'train/global_step': global_step,
                        })

                if global_step % self.validation_steps == 0:
                    val_loss = self._compute_validation_metrics()
                    val_losses.append(val_loss)

                    print(f"Validation at step {global_step}: Loss = {val_loss:.6f}")

                    if self.use_wandb:
                        wandb.log({
                            'val/loss': val_loss,
                            'val/global_step': global_step,
                        })

                    # Keep the best model seen so far (no early stopping is performed).
                    if val_loss < best_loss:
                        best_loss = val_loss
                        self._save_checkpoint(epoch, val_loss, global_step,
                                              is_final=False, is_best=True)

            avg_loss = np.mean(epoch_losses)
            losses.append(avg_loss)
            epoch_time = time.time() - epoch_start_time

            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | "
                  f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                  f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}")

            # Periodic checkpoints get their own step-named file; only the
            # end-of-training save below writes the "final" file.
            if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
                self._save_checkpoint(epoch, avg_loss, global_step, is_final=False)

        # With zero epochs there is no loss to report; avoid indexing an empty list.
        final_loss = losses[-1] if losses else float('nan')
        self._save_checkpoint(EPOCHS - 1, final_loss, global_step, is_final=True)

        total_time = time.time() - start_time
        print("=" * 60)
        print("🎉 Optimized Training Complete with FULL DATA!")
        print(f"Best validation loss: {best_loss:.6f}")
        print(f"Total training time: {total_time/3600:.1f} hours")
        print(f"Total samples used: {len(self.cfg_dataset):,}")
        print("Final model saved as: amp_flow_model_final_optimized.pth")

        return losses, val_losses

    def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False):
        """
        Write a checkpoint to ``/data2/edwardsun/flow_checkpoints``.

        Args:
            step: Epoch index stored under ``'step'`` and used in the file name
                of periodic checkpoints.
            loss: Loss value recorded in the checkpoint.
            global_step: Number of optimizer steps taken so far.
            is_final: Save as ``amp_flow_model_final_optimized.pth``.
            is_best: Save as ``amp_flow_model_best_optimized.pth``; takes
                precedence over ``is_final``.
        """
        output_dir = '/data2/edwardsun/flow_checkpoints'
        os.makedirs(output_dir, exist_ok=True)

        if is_best:
            basename = 'amp_flow_model_best_optimized.pth'
        elif is_final:
            basename = 'amp_flow_model_final_optimized.pth'
        else:
            basename = f'amp_flow_checkpoint_optimized_step_{step:04d}.pth'
        filename = os.path.join(output_dir, basename)

        # NOTE(review): when torch.compile wrapped the model, state_dict keys
        # are typically prefixed with "_orig_mod." - confirm that the code
        # loading these checkpoints handles that prefix.
        checkpoint = {
            'step': step,
            'global_step': global_step,
            'loss': loss,
            'flow_model_state_dict': self.flow_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'stats': self.stats,
            'total_samples': len(self.cfg_dataset),
            'config': {
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'base_lr': BASE_LR,
                'lr_min': LR_MIN,
                'warmup_steps': WARMUP_STEPS,
                'mixed_precision': USE_MIXED_PRECISION,
                'gradient_clip': GRADIENT_CLIP_NORM,
                'weight_decay': WEIGHT_DECAY,
            },
        }

        torch.save(checkpoint, filename)
        print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})")


def main():
    """
    Command-line entry point: parse arguments, build the trainer and train.

    ``--batch_size`` and ``--epochs`` overwrite the module-level
    ``BATCH_SIZE`` / ``EPOCHS`` constants, which the trainer reads directly.
    """
    global BATCH_SIZE, EPOCHS

    parser = argparse.ArgumentParser(
        description='Optimized Single GPU AMP Flow Training with FULL DATA')
    parser.add_argument('--embeddings',
                        default='/data2/edwardsun/flow_project/peptide_embeddings/',
                        help='Path to peptide embeddings directory')
    parser.add_argument('--cfg_data',
                        default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
                        help='Path to FULL CFG data file')
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=EPOCHS,
                        help='Number of training epochs')

    args = parser.parse_args()

    # Reject values that would only fail later, deep inside data loading or
    # at the end of an empty training loop.
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')
    if args.epochs <= 0:
        parser.error('--epochs must be a positive integer')

    # The trainer reads these module-level constants, so override them here.
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs

    print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}")

    trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb)
    trainer.train_flow_matching()

    print("Optimized training completed successfully with FULL DATA!")


if __name__ == "__main__":
    main()