| from torchinfo import summary
|
| from model import build_transformer
|
| from util import create_resources
|
| import yaml
|
| import torch
|
| from tqdm import tqdm
|
| import os
|
| import wandb
|
| import matplotlib.pyplot as plt
|
| from matplotlib import font_manager
|
| import re
|
|
|
|
|
# Font file handed to matplotlib; the path is relative, so it is resolved
# against the current working directory when matplotlib reads it.
# NOTE(review): presumably used to render Devanagari text in plots -- neither
# name is referenced in the visible part of this file; confirm before removing.
mangal_font_path = "Mangal.TTf"

# FontProperties built from the font file above.
devanagari_font = font_manager.FontProperties(fname=mangal_font_path)
|
|
|
|
|
class NoamScheduler:
    """Learning-rate schedule with linear warmup and inverse-sqrt decay.

    For step ``s`` (counted from 1) the rate is::

        d_model ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)

    Each call to :meth:`step` advances the step counter and writes the new
    rate into the ``"lr"`` entry of every parameter group of the optimizer.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten by :meth:`step`.
        d_model: Scale of the schedule; must be positive.
        warmup_steps: Number of warmup steps; must be non-negative. ``0``
            disables warmup and leaves pure inverse-sqrt decay.

    Raises:
        ValueError: If ``d_model`` is not positive or ``warmup_steps`` is
            negative. Such values would otherwise only fail later, inside
            ``get_lr``, with a ZeroDivisionError or a complex-valued rate.
    """

    def __init__(self, optimizer, d_model, warmup_steps):
        if d_model <= 0:
            raise ValueError(f"d_model must be positive, got {d_model!r}")
        if warmup_steps < 0:
            raise ValueError(
                f"warmup_steps must be non-negative, got {warmup_steps!r}"
            )
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        # Number of completed step() calls; checkpointing code reads and
        # restores this attribute directly.
        self.step_num = 0

    def step(self):
        """Advance one step, apply the new rate to the optimizer, return it."""
        self.step_num += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        return lr

    def get_lr(self):
        """Return the rate for the current step without changing any state."""
        # Before the first step() the counter is 0; clamp so 0 ** -0.5 is
        # never evaluated.
        step = max(self.step_num, 1)
        decay = step ** (-0.5)
        if self.warmup_steps <= 0:
            # No warmup phase: 0 ** -1.5 would raise ZeroDivisionError.
            factor = decay
        else:
            factor = min(decay, step * (self.warmup_steps ** (-1.5)))
        return (self.d_model ** (-0.5)) * factor
|
|
|
|
|
class Trainer:
    """Training loop for an encoder-decoder model with per-epoch checkpoints.

    Args:
        model: Object providing ``encode``, ``decode``, ``project``,
            ``train`` and ``state_dict``.
        optimizer: Optimizer whose ``step``/``zero_grad``/``state_dict`` are
            used.
        scheduler: Object whose ``step()`` sets and returns the learning
            rate, with ``get_lr()`` and a ``step_num`` attribute.
        criterion: Callable ``criterion(logits, targets)`` returning a
            scalar loss tensor.
        device: Device every batch tensor is moved to.
        tokenizer_src: Source tokenizer (stored, not used by this class).
        tokenizer_tgt: Target tokenizer; ``get_vocab_size()`` gives the
            width used to flatten the logits.
        seq_len: Sequence length (stored, not used by this class).
    """

    # Pad id used for token counting when the criterion does not expose an
    # ``ignore_index`` attribute.
    _FALLBACK_PAD_ID = 1
    # Number of batches between two batch-level wandb log entries.
    _LOG_EVERY = 50

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        seq_len,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device
        self.tgt_tokenizer = tokenizer_tgt
        self.src_tokenizer = tokenizer_src
        self.seq_len = seq_len

    def _ignored_label_id(self):
        """Return the label id that is excluded from token counting.

        Taken from ``criterion.ignore_index`` when available, so the count
        always refers to the same id the criterion is configured to ignore.
        """
        ignore_index = getattr(self.criterion, "ignore_index", None)
        return self._FALLBACK_PAD_ID if ignore_index is None else ignore_index

    def train_step(self, batch):
        """Run forward, backward and one optimizer update for one batch.

        Args:
            batch: Mapping with the tensors ``"encoder_input"``,
                ``"decoder_input"``, ``"encoder_mask"``, ``"decoder_mask"``
                and ``"label"``.

        Returns:
            Tuple ``(loss_value, learning_rate, num_tokens)``: the loss as a
            Python float, the rate returned by the scheduler for this update,
            and the number of label entries that differ from the ignored id.
        """
        encoder_input = batch["encoder_input"].to(self.device)
        decoder_input = batch["decoder_input"].to(self.device)
        encoder_mask = batch["encoder_mask"].to(self.device)
        decoder_mask = batch["decoder_mask"].to(self.device)
        label = batch["label"].to(self.device)

        encoder_output = self.model.encode(encoder_input, encoder_mask)
        decoder_output = self.model.decode(
            decoder_input, encoder_output, encoder_mask, decoder_mask
        )
        projection_output = self.model.project(decoder_output)

        # Flatten to (tokens, vocab) vs (tokens,) for the criterion.
        loss = self.criterion(
            projection_output.view(-1, self.tgt_tokenizer.get_vocab_size()),
            label.view(-1),
        )
        loss.backward()

        # Set the scheduled rate BEFORE the update. The scheduler only writes
        # the rate inside step(), so stepping it after the optimizer would
        # run the first update at the optimizer's raw constructor rate
        # instead of the warmup rate.
        current_lr = self.scheduler.step()
        self.optimizer.step()
        self.optimizer.zero_grad()

        num_tokens = label.ne(self._ignored_label_id()).sum().item()
        return loss.item(), current_lr, num_tokens

    def train_epoch(self, dataloader):
        """Train for one pass over ``dataloader``.

        Returns:
            The token-weighted mean of the batch losses, or ``0.0`` when no
            tokens were counted.
        """
        self.model.train()
        torch.cuda.empty_cache()
        running_loss = 0.0
        total_tokens = 0
        progress_bar = tqdm(
            enumerate(dataloader), desc="Training", total=len(dataloader)
        )

        for batch_idx, batch in progress_bar:
            loss_value, current_lr, num_tokens = self.train_step(batch)

            # Weight each batch loss by its counted tokens so the epoch
            # figure is a per-token mean rather than a per-batch mean.
            running_loss += loss_value * num_tokens
            total_tokens += num_tokens

            if (batch_idx + 1) % self._LOG_EVERY == 0:
                wandb.log(
                    {
                        "batch_loss": loss_value,
                        "learning_rate": current_lr,
                        "batch": batch_idx + 1,
                    }
                )

        return running_loss / total_tokens if total_tokens > 0 else 0.0

    def save_checkpoint(self, epoch, output_dir):
        """Write ``model_epoch_<epoch>.pth`` into ``output_dir``.

        The directory is created when missing. The file stores the epoch
        number, the model and optimizer state dicts and the scheduler's
        ``step_num``.
        """
        os.makedirs(output_dir, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.step_num,
        }
        torch.save(checkpoint, os.path.join(output_dir, f"model_epoch_{epoch}.pth"))
        print(f"Checkpoint saved at epoch {epoch}")

    def run(self, train_loader, epochs, output_dir, start_epoch=1):
        """Train from ``start_epoch`` through ``epochs`` (inclusive).

        After every epoch the loss and current rate are logged to wandb and
        a checkpoint is saved. Nothing happens when ``start_epoch`` is
        greater than ``epochs``.
        """
        for epoch in range(start_epoch, epochs + 1):
            train_loss = self.train_epoch(train_loader)
            current_lr = self.scheduler.get_lr()

            wandb.log(
                {"epoch": epoch, "train_loss": train_loss, "learning_rate": current_lr}
            )

            self.save_checkpoint(epoch, output_dir)
|
|
|
|
|
def load_latest_checkpoint(model, optimizer, scheduler, model_directory, device):
    """Restore training state from the newest checkpoint in a directory.

    Only files named exactly ``model_epoch_<N>.pth`` are considered; the one
    with the largest ``N`` is loaded. Other ``.pth`` files that merely
    contain that pattern (for example ``backup_model_epoch_9.pth``) are
    ignored.

    Args:
        model: Receives ``ckpt["model_state_dict"]`` via ``load_state_dict``.
        optimizer: Receives ``ckpt["optimizer_state_dict"]`` via
            ``load_state_dict``.
        scheduler: Its ``step_num`` attribute is set to
            ``ckpt["scheduler_state"]``.
        model_directory: Directory that is searched for checkpoints.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(checkpoint, start_epoch)`` where ``start_epoch`` is the stored
        epoch plus one, or ``(None, 1)`` when the directory does not exist or
        holds no matching checkpoint.
    """
    if not os.path.isdir(model_directory):
        return None, 1

    checkpoint_files = []
    for filename in os.listdir(model_directory):
        # fullmatch anchors both ends, so only canonical names qualify.
        match = re.fullmatch(r"model_epoch_(\d+)\.pth", filename)
        if match:
            checkpoint_files.append((int(match.group(1)), filename))

    if not checkpoint_files:
        return None, 1

    # Compare epochs numerically so epoch 10 ranks above epoch 9.
    _, latest_filename = max(checkpoint_files, key=lambda item: item[0])
    ckpt_path = os.path.join(model_directory, latest_filename)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted directory.
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.step_num = ckpt["scheduler_state"]
    start_epoch = ckpt["epoch"] + 1
    print(f"Resuming Training from epoch {ckpt['epoch']}")
    return ckpt, start_epoch
|
|
|
|
|
def main():
    """Build data, model and optimizer from ``config.yaml`` and run training.

    Steps, in order: create the dataloaders and tokenizers, read
    ``config.yaml``, start a wandb run, build the model, set up loss,
    optimizer and scheduler, optionally resume from the newest checkpoint in
    ``config["model_directory"]``, train, and finish the wandb run.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the training loader is used here; the validation and test loaders
    # are returned by create_resources() but not consumed by this function.
    (
        train_dataloader,
        _valid_dataloader,
        _test_dataloader,
        tokenizer_src,
        tokenizer_tgt,
    ) = create_resources()
    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    with open("config.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    run = wandb.init(
        entity="training-transformers-vast",
        project="AttentionTranslate-sai",
        config=config,
    )

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )
    model = model.to(device)

    wandb.watch(model, log="all")

    # The loss is computed on target-side labels, so the padding id must come
    # from the target tokenizer; the source tokenizer's id is only correct
    # when both vocabularies happen to assign "[PAD]" the same id.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config["learning_rate"], betas=(0.9, 0.98), eps=1e-9
    )
    scheduler = NoamScheduler(optimizer, config["d_model"], config["warmup_steps"])

    start_epoch = 1
    # A config without the key means "do not resume" instead of a KeyError.
    if config.get("resume_training", False):
        _, start_epoch = load_latest_checkpoint(
            model, optimizer, scheduler, config["model_directory"], device
        )

    if start_epoch == 1:
        print("Training from scratch.")

    trainer = Trainer(
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        config["seq_len"],
    )
    trainer.run(
        train_dataloader, config["epochs"], config["model_directory"], start_epoch
    )

    run.finish()
|
|
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|