import functools
import json
import logging
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import wandb
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    get_linear_schedule_with_warmup,
    set_seed,
)

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Training configuration for link token classification"""

    # Model
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2

    # Data
    train_file: str = "train_windows.jsonl"
    val_file: str = "val_windows.jsonl"
    max_length: int = 512

    # Optimization
    batch_size: int = 8
    gradient_accumulation_steps: int = 8
    num_epochs: int = 3
    learning_rate: float = 1e-6
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    label_smoothing: float = 0.0  # currently unused by the loss below

    # Hardware / reproducibility
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0
    seed: int = 42
    bf16: bool = True

    # Logging / evaluation cadence
    logging_steps: int = 1
    eval_steps: int = 5000
    save_steps: int = 10000  # currently unused; checkpoints are saved on eval improvements
    output_dir: str = "./deberta_link_output"

    # Weights & Biases
    wandb_project: str = "deberta-link-classification"
    wandb_name: str = "deberta-v3-large-link-tokens"

    # Early stopping
    patience: int = 2
    min_delta: float = 0.0001

    # Checkpoint retention
    max_checkpoints: int = 5
    protect_latest_epoch_step: bool = True

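# Each line of the train/val JSONL files is expected to carry a pre-tokenized
# window with the keys read by LinkTokenDataset below; an illustrative line
# (values shortened) would look like:
#   {"input_ids": [1, 279, ...], "attention_mask": [1, 1, ...], "labels": [0, 1, ...]}
# where labels use 1 for link tokens, 0 for non-link tokens, and -100 for
# positions that should be ignored by the loss and metrics.
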
class LinkTokenDataset(Dataset):
    """Dataset for link token classification"""

    def __init__(self, file_path: str, max_samples: Optional[int] = None):
        self.data = []

        logger.info(f"Loading data from {file_path}")
        seq_lengths = []

        with open(file_path, 'r') as f:
            for i, line in enumerate(f):
                if max_samples and i >= max_samples:
                    break
                sample = json.loads(line)

                seq_len = len(sample['input_ids'])
                seq_lengths.append(seq_len)

                sample['input_ids'] = torch.tensor(sample['input_ids'], dtype=torch.long)
                sample['attention_mask'] = torch.tensor(sample['attention_mask'], dtype=torch.long)
                sample['labels'] = torch.tensor(sample['labels'], dtype=torch.long)

                self.data.append(sample)

        logger.info(f"Loaded {len(self.data)} samples")
        if seq_lengths:
            logger.info(
                f"Sequence lengths - Min: {min(seq_lengths)}, "
                f"Max: {max(seq_lengths)}, Avg: {np.mean(seq_lengths):.1f}"
            )

        # Summarize the label distribution, skipping ignored (-100) positions
        total_labels = []
        for s in self.data:
            valid_labels = s['labels'][s['labels'] != -100]
            total_labels.append(valid_labels)

        if total_labels:
            total_labels = torch.cat(total_labels)
            num_link_tokens = (total_labels == 1).sum().item()
            num_non_link = (total_labels == 0).sum().item()

            logger.info(f"Label distribution - Non-link: {num_non_link}, Link: {num_link_tokens}")
            if (num_link_tokens + num_non_link) > 0:
                logger.info(f"Link token ratio: {num_link_tokens / (num_link_tokens + num_non_link):.4%}")
        else:
            logger.info("No valid labels found in the dataset.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch: List[Dict], max_seq_length: int) -> Dict[str, torch.Tensor]:
    """
    Custom collate function for batching with padding to a fixed max_seq_length.

    Args:
        batch (List[Dict]): A list of samples from the dataset.
        max_seq_length (int): The maximum sequence length to pad all samples to.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing stacked and padded tensors.
    """
    input_ids = []
    attention_mask = []
    labels = []

    for x in batch:
        # Truncate on local views so the dataset's cached tensors are not mutated
        ids = x['input_ids'][:max_seq_length]
        mask = x['attention_mask'][:max_seq_length]
        labs = x['labels'][:max_seq_length]

        padding_len = max_seq_length - len(ids)

        # Pad input_ids and attention_mask with 0; pad labels with -100 so
        # padded positions are ignored by the loss
        padded_input = torch.cat([ids, torch.zeros(padding_len, dtype=torch.long)])
        padded_mask = torch.cat([mask, torch.zeros(padding_len, dtype=torch.long)])
        padded_labels = torch.cat([labs, torch.full((padding_len,), -100, dtype=torch.long)])

        input_ids.append(padded_input)
        attention_mask.append(padded_mask)
        labels.append(padded_labels)

    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels)
    }

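# Minimal usage sketch (illustrative; the Trainer below wires this up the same way):
#   loader = DataLoader(dataset, batch_size=8,
#                       collate_fn=functools.partial(collate_fn, max_seq_length=512))
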
class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa model for token classification"""

    def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
        super().__init__()

        # Record num_labels on the config so forward() can rely on it
        self.config = DebertaV2Config.from_pretrained(model_name, num_labels=num_labels)
        self.deberta = DebertaV2Model.from_pretrained(model_name)

        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:

        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Hard-coded class weights to counter the link/non-link imbalance;
            # the 25x up-weighting of the link class is a tuning choice, not a
            # derived constant - adjust it to your data's actual label ratio
            weight = torch.tensor([1.0, 25.0]).to(logits.device)

            loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return {
            'loss': loss,
            'logits': logits
        }

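# Quick sanity check of the weighted loss above (illustrative numbers):
# >>> loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 25.0]), ignore_index=-100)
# >>> logits = torch.tensor([[2.0, -2.0], [2.0, -2.0]])  # both tokens predicted non-link
# >>> labels = torch.tensor([0, 1])
# >>> loss_fct(logits, labels)
# The mistake on the label-1 (link) token dominates the batch loss, so the
# model cannot cheaply minimize loss by predicting the majority class everywhere.
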
def compute_metrics(predictions: np.ndarray, labels: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
    """Compute metrics for token classification"""

    predictions_flat = predictions.flatten()
    labels_flat = labels.flatten()
    mask_flat = mask.flatten()

    # Score only real tokens: drop padding and positions labeled -100
    valid_indices = (labels_flat != -100) & (mask_flat == 1)

    preds_filtered = predictions_flat[valid_indices]
    labels_filtered = labels_flat[valid_indices]

    if len(labels_filtered) == 0:
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'f1_non_link': 0.0,
            'f1_link': 0.0,
            'precision_link': 0.0,
            'recall_link': 0.0,
            'num_valid_tokens': 0
        }

    accuracy = accuracy_score(labels_filtered, preds_filtered)

    # Binary metrics with the link class (1) as positive
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_filtered, preds_filtered, average='binary', pos_label=1, zero_division=0
    )

    # Per-class metrics in one call; zero_division=0 yields 0.0 for any class
    # absent from the filtered labels
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        labels_filtered, preds_filtered, labels=[0, 1], average=None, zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'f1_non_link': f1_per_class[0],
        'f1_link': f1_per_class[1],
        'precision_link': precision_per_class[1],
        'recall_link': recall_per_class[1],
        'num_valid_tokens': len(labels_filtered)
    }

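# Worked example of the masking in compute_metrics (illustrative):
#   predictions = [[1, 0, 1]], labels = [[1, -100, 0]], mask = [[1, 1, 1]]
# Position 1 is ignored (label -100), leaving preds [1, 1] vs labels [1, 0]:
# accuracy = 0.5, link precision = 0.5, link recall = 1.0, link F1 = 2/3.
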
class Trainer:
    """Trainer class for DeBERTa token classification"""

    def __init__(self, config: TrainingConfig):
        self.config = config
        set_seed(config.seed)

        wandb.init(
            project=config.wandb_project,
            name=config.wandb_name,
            config=vars(config)
        )

        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

        self.train_dataset = LinkTokenDataset(config.train_file)
        self.val_dataset = LinkTokenDataset(config.val_file)

        # Note: shuffle=False preserves the on-disk window order; shuffle the
        # JSONL upstream (or flip this flag) if the windows are not pre-shuffled
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
            pin_memory=True
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
            pin_memory=True
        )

        self.model = DeBERTaForTokenClassification(
            config.model_name,
            config.num_labels
        ).to(config.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")

        # No weight decay on biases and LayerNorm weights
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': config.weight_decay
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        self.optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=config.learning_rate,
            eps=1e-6
        )

        total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        self.global_step = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

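    # train_epoch below steps the optimizer once every gradient_accumulation_steps
    # batches, so the effective batch size is
    # batch_size * gradient_accumulation_steps (8 * 8 = 64 with the defaults).
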
    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        early_stop_triggered = False

        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            if self.config.bf16:
                with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = self.model(**batch)
                    loss = outputs['loss'] / self.config.gradient_accumulation_steps
            else:
                outputs = self.model(**batch)
                loss = outputs['loss'] / self.config.gradient_accumulation_steps

            if torch.isnan(loss) or torch.isinf(loss):
                logger.warning(f"NaN or Inf loss encountered at step {self.global_step}. Skipping backward pass.")
                self.optimizer.zero_grad()
                continue

            loss.backward()
            # Accumulate the unscaled loss so the epoch average is comparable
            # across gradient_accumulation_steps settings
            total_loss += outputs['loss'].item()

            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                if self.global_step % self.config.logging_steps == 0:
                    current_loss = loss.item() * self.config.gradient_accumulation_steps
                    wandb.log({
                        'train/loss': current_loss,
                        'train/learning_rate': self.scheduler.get_last_lr()[0],
                        'train/global_step': self.global_step,
                        'train/epoch': epoch
                    })
                    progress_bar.set_postfix({'loss': f'{current_loss:.4f}'})

                if self.global_step % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate()
                    logger.info(f"Step {self.global_step} - Eval metrics: {eval_metrics}")

                    current_val_loss = eval_metrics['loss']
                    if current_val_loss < self.best_val_loss - self.config.min_delta:
                        self.best_val_loss = current_val_loss
                        self.patience_counter = 0
                        self.save_model(f"best_model_step_{self.global_step}")
                        logger.info(f"New best validation loss: {self.best_val_loss:.4f}")
                    else:
                        self.patience_counter += 1
                        logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")
                        if self.patience_counter >= self.config.patience:
                            logger.info("Early stopping triggered mid-epoch!")
                            early_stop_triggered = True
                            break

            if early_stop_triggered:
                break

        return total_loss / len(self.train_loader) if len(self.train_loader) > 0 else 0.0

    def evaluate(self) -> Dict[str, float]:
        """Evaluate on validation set"""
        self.model.eval()

        all_predictions = []
        all_labels = []
        all_masks = []
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.config.device) for k, v in batch.items()}

                if self.config.bf16:
                    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                        outputs = self.model(**batch)
                else:
                    outputs = self.model(**batch)

                if outputs['loss'] is not None:
                    total_loss += outputs['loss'].item()
                    num_batches += 1

                predictions = torch.argmax(outputs['logits'], dim=-1)

                all_predictions.append(predictions.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
                all_masks.append(batch['attention_mask'].cpu().numpy())

        all_predictions = np.concatenate(all_predictions, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        all_masks = np.concatenate(all_masks, axis=0)

        metrics = compute_metrics(all_predictions, all_labels, all_masks)
        metrics['loss'] = total_loss / num_batches if num_batches > 0 else 0.0

        wandb.log({f'eval/{k}': v for k, v in metrics.items()}, step=self.global_step)

        self.model.train()
        return metrics

    def _enforce_checkpoint_limit(self):
        """
        Enforce checkpoint retention:
        - Count all subdirectories in output_dir except 'final_model'
        - Keep at most config.max_checkpoints
        - Delete oldest by modification time
        - Always protect:
            * 'final_model'
            * latest 'best_model_epoch_*'
            * latest 'best_model_step_*'
        """
        output_dir = Path(self.config.output_dir)
        if not output_dir.exists():
            return

        subdirs = [p for p in output_dir.iterdir() if p.is_dir()]
        if not subdirs:
            return

        protected = set()

        final_dir = output_dir / "final_model"
        if final_dir.exists() and final_dir.is_dir():
            protected.add(final_dir.resolve())

        if self.config.protect_latest_epoch_step:
            epoch_dirs = [d for d in subdirs if re.match(r"best_model_epoch_\d+$", d.name)]
            if epoch_dirs:
                latest_epoch = max(epoch_dirs, key=lambda d: d.stat().st_mtime)
                protected.add(latest_epoch.resolve())

            step_dirs = [d for d in subdirs if re.match(r"best_model_step_\d+$", d.name)]
            if step_dirs:
                latest_step = max(step_dirs, key=lambda d: d.stat().st_mtime)
                protected.add(latest_step.resolve())

        counted = [d for d in subdirs if d.resolve() != final_dir.resolve()]

        if len(counted) <= self.config.max_checkpoints:
            return

        # Delete oldest first, skipping protected directories. If everything
        # that remains is protected, the count may stay above the limit:
        # protected checkpoints are never deleted.
        counted_sorted = sorted(counted, key=lambda d: d.stat().st_mtime)

        to_delete = []
        current = len(counted)
        for d in counted_sorted:
            if current <= self.config.max_checkpoints:
                break
            if d.resolve() in protected:
                continue
            to_delete.append(d)
            current -= 1

        for d in to_delete:
            try:
                shutil.rmtree(d)
                logger.info(f"Deleted old checkpoint: {d}")
            except Exception as e:
                logger.warning(f"Failed to delete {d}: {e}")

    def save_model(self, name: str):
        """Save model checkpoint"""
        save_path = Path(self.config.output_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), save_path / 'pytorch_model.bin')

        with open(save_path / 'training_config.json', 'w') as f:
            json.dump(vars(self.config), f, indent=4)

        logger.info(f"Model saved to {save_path}")

        self._enforce_checkpoint_limit()

    def train(self):
        """Main training loop"""
        logger.info("Starting training...")
        logger.info(f"Training samples: {len(self.train_dataset)}")
        logger.info(f"Validation samples: {len(self.val_dataset)}")

        # The optimizer only steps on full accumulation windows, so each epoch
        # contributes floor(len(train_loader) / gradient_accumulation_steps) steps
        total_optimization_steps = (len(self.train_loader) // self.config.gradient_accumulation_steps) * self.config.num_epochs
        logger.info(f"Total optimization steps: {total_optimization_steps}")
        logger.info(f"Early stopping: monitoring validation loss with patience={self.config.patience}")

        for epoch in range(self.config.num_epochs):
            logger.info(f"\n{'='*50}")
            logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            avg_train_loss = self.train_epoch(epoch + 1)
            logger.info(f"Average training loss: {avg_train_loss:.4f}")

            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping during epoch.")
                break

            eval_metrics = self.evaluate()
            logger.info(f"Epoch {epoch + 1} - Eval metrics:")
            for key, value in eval_metrics.items():
                logger.info(f"  {key}: {value:.4f}")

            current_val_loss = eval_metrics['loss']
            if current_val_loss < self.best_val_loss - self.config.min_delta:
                self.best_val_loss = current_val_loss
                self.patience_counter = 0
                self.save_model(f"best_model_epoch_{epoch + 1}")
                logger.info(f"New best validation loss at epoch end: {self.best_val_loss:.4f}")
            else:
                self.patience_counter += 1
                logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")

            if self.patience_counter >= self.config.patience:
                logger.info("Training stopped due to early stopping")
                break

        self.save_model("final_model")

        logger.info("Training completed!")
        logger.info(f"Best validation loss: {self.best_val_loss:.4f}")
        wandb.finish()

def main():
    """Main function"""
    config = TrainingConfig()
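    # Override defaults here if needed; the field names are the ones defined
    # on TrainingConfig above (the values below are illustrative):
    #   config.train_file = "train_windows.jsonl"
    #   config.learning_rate = 2e-6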
    trainer = Trainer(config)
    trainer.train()


if __name__ == "__main__":
    main()