OliverPerrin committed on
Commit · 590a604
1 Parent(s): 7977c7d
Full training run, code cleanup, mypy/ruff fixes
- Completed 3-epoch full training run (21 hours)
- Emotion F1: 94.6%, Topic Accuracy: 94.2%, ROUGE-like: 0.36
- Code simplification and standardized docstrings across all modules
- Fixed all mypy type errors (47 files pass)
- Fixed all ruff linting errors and reformatted code
- All 75 tests passing
- Added GPU optimizations: TF32, Flash Attention, memory-efficient SDP (see the sketch after this list)
- Training configs optimized for RTX 4070 12GB
- tqdm progress bars for training and evaluation
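
The GPU optimizations called out above map to a handful of global PyTorch flags. A minimal sketch of the settings involved (illustrative only; the actual calls appear in the scripts/train.py diff below):

import torch

# TF32 tensor cores on Ampere+ GPUs (compute capability >= 8): ~2x matmul throughput
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True               # auto-tune conv algorithms
    torch.backends.cuda.enable_flash_sdp(True)          # Flash Attention kernels for SDPA
    torch.backends.cuda.enable_mem_efficient_sdp(True)  # memory-efficient SDPA fallback

TF32 trades a few mantissa bits for roughly 2x matmul throughput on Ampere-class cards, which is why it is gated on compute capability >= 8.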
- configs/config.yaml +7 -0
- configs/training/dev.yaml +10 -13
- configs/training/full.yaml +7 -7
- configs/training/medium.yaml +9 -13
- outputs/evaluation_report.json +25 -25
- outputs/training_history.json +51 -13
- scripts/download_data.py +9 -1
- scripts/eval_rouge.py +9 -1
- scripts/evaluate.py +125 -143
- scripts/export_model.py +9 -1
- scripts/export_tokenizer.py +9 -1
- scripts/inference.py +9 -1
- scripts/preprocess_data.py +10 -1
- scripts/train.py +157 -241
- src/api/app.py +8 -1
- src/api/dependencies.py +8 -1
- src/api/routes.py +9 -1
- src/api/schemas.py +8 -1
- src/data/dataloader.py +43 -34
- src/data/dataset.py +10 -1
- src/data/preprocessing.py +54 -70
- src/data/tokenization.py +10 -1
- src/inference/factory.py +9 -1
- src/inference/pipeline.py +96 -72
- src/inference/postprocessing.py +8 -1
- src/models/decoder.py +14 -13
- src/models/encoder.py +16 -17
- src/models/factory.py +11 -1
- src/models/feedforward.py +8 -2
- src/models/heads.py +14 -14
- src/models/multitask.py +9 -15
- src/models/positional_encoding.py +6 -4
- src/training/metrics.py +10 -1
- src/training/trainer.py +248 -435
- src/training/utils.py +9 -1
- src/utils/config.py +8 -1
- src/utils/io.py +8 -1
- src/utils/labels.py +9 -1
- src/utils/logging.py +8 -1
- src/utils/random.py +8 -1
- tests/test_training/test_trainer.py +1 -1
configs/config.yaml
CHANGED
@@ -4,6 +4,13 @@ defaults:
   - training: default
   - _self_
 
+# Hydra config - prevent output dir conflicts
+hydra:
+  run:
+    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: outputs/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
 checkpoint_out: "checkpoints/best.pt"
 labels_out: "artifacts/labels.json"
 history_out: "outputs/training_history.json"
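
The `${now:...}` interpolations above give every run its own timestamped output directory, so repeated or concurrent runs cannot clobber each other's artifacts. Roughly, Hydra resolves them the way this Python snippet does (illustrative only):

from datetime import datetime

now = datetime.now()
run_dir = f"outputs/{now:%Y-%m-%d}/{now:%H-%M-%S}"  # e.g. outputs/2025-12-14/09-41-07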
configs/training/dev.yaml
CHANGED
@@ -1,35 +1,32 @@
 # Development/Testing Configuration for FLAN-T5-base
 # Fast iteration for debugging and testing changes
-# Training time: ~
+# Training time: ~3-5 minutes on RTX 4070 12GB
 # Use: python scripts/train.py training=dev
 
 dataloader:
-  batch_size: 8
+  batch_size: 8  # Safe for 12GB VRAM - no shared memory spillover
   shuffle: true
   num_workers: 4
   pin_memory: true
 
 optimizer:
   name: adamw
-  lr: 5.0e-5
+  lr: 5.0e-5  # Higher LR for fast convergence
   weight_decay: 0.01
 
 scheduler:
   name: cosine
   warmup_steps: 50
 
 trainer:
   max_epochs: 1
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: 1  # No accumulation
+  gradient_accumulation_steps: 1  # No accumulation - maximize throughput
   validation_max_length: 64
   label_smoothing: 0.1
   task_weights:
     summarization: 1.0
     emotion: 1.0
     topic: 1.0
-
-
-  max_train_samples: 2000  # Reduced for faster iteration
+  max_train_samples: 2000
   max_val_samples: 200
-  validation_frequency: 1000  # Validate once during training
configs/training/full.yaml
CHANGED
@@ -1,12 +1,12 @@
 # Full Training Configuration for FLAN-T5-base
 # Complete training run on all data
-# Training time: ~6-8 hours on RTX 4070
+# Training time: ~6-8 hours on RTX 4070 12GB
 # Use: python scripts/train.py training=full
 
 dataloader:
-  batch_size:
+  batch_size: 6  # Optimized for 12GB VRAM
   shuffle: true
-  num_workers:
+  num_workers: 6
   pin_memory: true
 
 optimizer:
@@ -16,12 +16,12 @@ optimizer:
 
 scheduler:
   name: cosine
-  warmup_steps:
+  warmup_steps: 500  # ~3% of steps
 
 trainer:
-  max_epochs:
-  gradient_clip_norm: 0
-  gradient_accumulation_steps: 6  # Effective batch
+  max_epochs: 3  # 3 epochs usually sufficient, avoids overfit
+  gradient_clip_norm: 1.0
+  gradient_accumulation_steps: 6  # Effective batch = 36
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
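
The effective batch of 36 above is batch_size 6 x gradient_accumulation_steps 6: six micro-batch gradients are summed before each optimizer step. A minimal sketch of the pattern (assuming generic model/loader/optimizer/compute_loss placeholders, not the repo's Trainer, which handles this internally):

import torch

accum_steps = 6
optimizer.zero_grad()
for step, batch in enumerate(loader):
    loss = compute_loss(model, batch) / accum_steps  # scale so summed grads match one large batch
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()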
configs/training/medium.yaml
CHANGED
@@ -1,36 +1,32 @@
 # Medium Configuration for FLAN-T5-base
 # Balanced approach - good results in reasonable time
-# Training time: ~2-3 hours on RTX 4070
+# Training time: ~2-3 hours on RTX 4070 12GB
 # Use: python scripts/train.py training=medium
-# Note: FLAN-T5-base has 12 layers (vs BART's 6), may need smaller batch
 
 dataloader:
-  batch_size:
+  batch_size: 6  # Optimized for 12GB VRAM with accumulation
   shuffle: true
-  num_workers:
+  num_workers: 6
   pin_memory: true
 
 optimizer:
   name: adamw
-  lr:
+  lr: 3.0e-5  # Slightly higher - compensates for effective batch
   weight_decay: 0.01
 
 scheduler:
   name: cosine
-  warmup_steps:
+  warmup_steps: 300  # ~5% of steps
 
 trainer:
   max_epochs: 3
-  gradient_clip_norm: 0
-  gradient_accumulation_steps:
-  validation_max_length:
+  gradient_clip_norm: 1.0
+  gradient_accumulation_steps: 3  # Effective batch = 18
+  validation_max_length: 96
   label_smoothing: 0.1
   task_weights:
     summarization: 1.0
     emotion: 1.0
     topic: 1.0
-
-  # Medium dataset - good representative sample
   max_train_samples: 50000
   max_val_samples: 5000
-  validation_frequency: 5000
outputs/evaluation_report.json
CHANGED
@@ -1,44 +1,44 @@
 {
-  "split": "
+  "split": "val",
   "summarization": {
-    "rouge_like": 0.
-    "bleu": 0.
+    "rouge_like": 0.35947467920968945,
+    "bleu": 0.09027012433010549
   },
   "emotion": {
-    "f1_macro": 0.
+    "f1_macro": 0.9455000162124634
   },
   "topic": {
-    "accuracy": 0.
+    "accuracy": 0.94175,
     "classification_report": {
       "Business": {
-        "precision": 0.
-        "recall": 0.
-        "f1-score": 0.
-        "support":
+        "precision": 0.9319045973038369,
+        "recall": 0.8986666666666666,
+        "f1-score": 0.9149838791786866,
+        "support": 3000
       },
       "Sci/Tech": {
-        "precision": 0.
-        "recall": 0.
-        "f1-score": 0.
-        "support":
+        "precision": 0.9055627425614489,
+        "recall": 0.9333333333333333,
+        "f1-score": 0.9192383453709784,
+        "support": 3000
       },
       "Sports": {
-        "precision": 0.
-        "recall": 0.
-        "f1-score": 0.
-        "support":
+        "precision": 0.9856475300400535,
+        "recall": 0.9843333333333333,
+        "f1-score": 0.9849899933288859,
+        "support": 3000
       },
       "World": {
-        "precision": 0.
-        "recall": 0.
-        "f1-score": 0.
-        "support":
+        "precision": 0.9446836700894335,
+        "recall": 0.9506666666666667,
+        "f1-score": 0.9476657252035222,
+        "support": 3000
      },
      "macro avg": {
-        "precision": 0.
-        "recall": 0.
-        "f1-score": 0.
-        "support":
+        "precision": 0.9419496349986932,
+        "recall": 0.94175,
+        "f1-score": 0.9417194857705183,
+        "support": 12000
      }
    }
  }
outputs/training_history.json
CHANGED
@@ -1,21 +1,59 @@
 {
   "train_epoch_1": {
-    "summarization_loss": 3.
-    "summarization_rouge_like": 0.
-    "emotion_loss": 0.
-    "emotion_f1": 0.
-    "topic_loss":
-    "topic_accuracy": 0.
-    "total_loss":
+    "summarization_loss": 3.222269726091524,
+    "summarization_rouge_like": 0.4348834303103812,
+    "emotion_loss": 0.2681197640352259,
+    "emotion_f1": 0.4939010590246358,
+    "topic_loss": 0.2817161389551497,
+    "topic_accuracy": 0.9126178087058748,
+    "total_loss": 3.7721057520380095,
     "epoch": 1.0
   },
   "val_epoch_1": {
-    "summarization_loss":
-    "summarization_rouge_like": 0.
-    "emotion_loss": 0.
-    "emotion_f1": 0.
-    "topic_loss": 0.
-    "topic_accuracy": 0.
+    "summarization_loss": 2.9376416314440097,
+    "summarization_rouge_like": 0.4621969238397049,
+    "emotion_loss": 0.07456208207925424,
+    "emotion_f1": 0.922451647864638,
+    "topic_loss": 0.18789680490184146,
+    "topic_accuracy": 0.9368641532016696,
     "epoch": 1.0
+  },
+  "train_epoch_2": {
+    "summarization_loss": 3.0815064049717713,
+    "summarization_rouge_like": 0.44604443152864864,
+    "emotion_loss": 0.04770229796717623,
+    "emotion_f1": 0.9407868445694336,
+    "topic_loss": 0.1507136240392336,
+    "topic_accuracy": 0.9498742677227413,
+    "total_loss": 3.279922429068798,
+    "epoch": 2.0
+  },
+  "val_epoch_2": {
+    "summarization_loss": 2.8898715693603942,
+    "summarization_rouge_like": 0.4654528613816311,
+    "emotion_loss": 0.05001389549380918,
+    "emotion_f1": 0.9344953305524384,
+    "topic_loss": 0.1755385091801308,
+    "topic_accuracy": 0.9435966487133395,
+    "epoch": 2.0
+  },
+  "train_epoch_3": {
+    "summarization_loss": 3.0340622767404044,
+    "summarization_rouge_like": 0.4502876682264882,
+    "emotion_loss": 0.025708710505635942,
+    "emotion_f1": 0.9647584015837614,
+    "topic_loss": 0.11707986947991166,
+    "topic_accuracy": 0.9614479064357344,
+    "total_loss": 3.176850952874497,
+    "epoch": 3.0
+  },
+  "val_epoch_3": {
+    "summarization_loss": 2.865455434181104,
+    "summarization_rouge_like": 0.46790124713702563,
+    "emotion_loss": 0.05574661032417156,
+    "emotion_f1": 0.940105742034193,
+    "topic_loss": 0.19245651335709887,
+    "topic_accuracy": 0.942998204667858,
+    "epoch": 3.0
   }
 }
scripts/download_data.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+Dataset download script for LexiMind.
+
+Downloads training datasets from various sources including HuggingFace Hub,
+Kaggle, and Project Gutenberg. Handles automatic conversion to JSONL format.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
scripts/eval_rouge.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+ROUGE evaluation script for LexiMind.
+
+Computes ROUGE-1, ROUGE-2, and ROUGE-L scores on summarization outputs
+with support for batched inference and customizable metrics.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
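
For reference, ROUGE-N measures n-gram overlap between a candidate summary and a reference. A bare-bones ROUGE-1 F1 looks roughly like this (an illustration of the idea, not the repo's implementation in src/training/metrics.py):

from collections import Counter

def rouge1_f1(candidate: str, reference: str) -> float:
    cand, ref = Counter(candidate.split()), Counter(reference.split())
    overlap = sum((cand & ref).values())  # clipped unigram matches
    if not overlap:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)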
scripts/evaluate.py
CHANGED
@@ -1,6 +1,11 @@
 """
+Evaluation script for LexiMind.
+
+Computes ROUGE/BLEU for summarization, multi-label F1 for emotion,
+and accuracy with confusion matrix for topic classification.
+
+Author: Oliver Perrin
+Date: December 2025
 """
 
 from __future__ import annotations
@@ -8,9 +13,12 @@ from __future__ import annotations
 import argparse
 import json
 import sys
+import time
 from pathlib import Path
-from typing import Any,
+from typing import Any, Callable, List
 
+import matplotlib.pyplot as plt
+import seaborn as sns
 import torch
 from sklearn.preprocessing import MultiLabelBinarizer
 from tqdm import tqdm
@@ -19,14 +27,7 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
 if str(PROJECT_ROOT) not in sys.path:
     sys.path.insert(0, str(PROJECT_ROOT))
 
-import seaborn as sns
-
-from src.data.dataset import (
-    load_emotion_jsonl,
-    load_summarization_jsonl,
-    load_topic_jsonl,
-)
+from src.data.dataset import load_emotion_jsonl, load_summarization_jsonl, load_topic_jsonl
 from src.inference.factory import create_inference_pipeline
 from src.training.metrics import (
     accuracy,
@@ -38,80 +39,67 @@ from src.training.metrics import (
 )
 from src.utils.config import load_yaml
 
-    "train": ("train",),
-    "val": ("val", "validation"),
-    "test": ("test",),
-}
+# --------------- Data Loading ---------------
 
+SPLIT_ALIASES = {"train": ("train",), "val": ("val", "validation"), "test": ("test",)}
 
+
+def load_split(root: Path, split: str, loader: Callable[[str], List[Any]]) -> List[Any]:
+    """Load a dataset split, checking aliases."""
+    for alias in SPLIT_ALIASES.get(split, (split,)):
         for ext in ("jsonl", "json"):
-        if
-            return
-    raise FileNotFoundError(f"Missing {split} split
+            path = root / f"{alias}.{ext}"
+            if path.exists():
+                return list(loader(str(path)))
+    raise FileNotFoundError(f"Missing {split} split in {root}")
 
 
-def
-        default="val",
-        choices=["train", "val", "test"],
-        help="Dataset split to evaluate.",
-    )
-    parser.add_argument(
-        "--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint."
-    )
-    parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON.")
-    parser.add_argument(
-        "--data-config", default="configs/data/datasets.yaml", help="Data configuration YAML."
-    )
-    parser.add_argument(
-        "--model-config", default="configs/model/base.yaml", help="Model architecture YAML."
-    )
-    parser.add_argument(
-        "--device",
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for evaluation.",
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=16,
-        help="Batch size for generation/classification during evaluation.",
-    )
-    parser.add_argument(
-        "--output-dir", default="outputs", help="Directory to save evaluation artifacts."
-    )
-    return parser.parse_args()
-
-
-    for start in range(0, len(items), size):
-        yield items[start : start + size]
+def chunks(items: List, size: int):
+    """Yield batches of items."""
+    for i in range(0, len(items), size):
+        yield items[i : i + size]
 
 
-def plot_confusion_matrix(cm, labels,
+# --------------- Visualization ---------------
+
+
+def plot_confusion_matrix(cm, labels, path: Path) -> None:
+    """Save confusion matrix heatmap."""
     plt.figure(figsize=(10, 8))
     sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
     plt.xlabel("Predicted")
     plt.ylabel("True")
     plt.title("Topic Classification Confusion Matrix")
     plt.tight_layout()
-    plt.savefig(
+    plt.savefig(path)
     plt.close()
 
 
+# --------------- Main ---------------
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(description="Evaluate LexiMind")
+    p.add_argument("--split", default="val", choices=["train", "val", "test"])
+    p.add_argument("--checkpoint", default="checkpoints/best.pt")
+    p.add_argument("--labels", default="artifacts/labels.json")
+    p.add_argument("--data-config", default="configs/data/datasets.yaml")
+    p.add_argument("--model-config", default="configs/model/base.yaml")
+    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+    p.add_argument("--batch-size", type=int, default=148)  # Larger batch for inference (no grads)
+    p.add_argument("--output-dir", default="outputs")
+    return p.parse_args()
+
+
 def main() -> None:
     args = parse_args()
+    start_time = time.perf_counter()
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
 
+    # Load pipeline
+    print("Loading model...")
     pipeline, metadata = create_inference_pipeline(
         checkpoint_path=args.checkpoint,
         labels_path=args.labels,
@@ -120,100 +108,94 @@ def main() -> None:
         device=args.device,
     )
 
-    for batch in tqdm(
-    )
+    # Load data
+    data_cfg = load_yaml(args.data_config).data
+    summ_data = load_split(
+        Path(data_cfg["processed"]["summarization"]), args.split, load_summarization_jsonl
+    )
+    emot_data = load_split(Path(data_cfg["processed"]["emotion"]), args.split, load_emotion_jsonl)
+    topic_data = load_split(Path(data_cfg["processed"]["topic"]), args.split, load_topic_jsonl)
+
+    print(f"\nEvaluating on {args.split} split:")
+    print(f"  Summarization: {len(summ_data)} samples")
+    print(f"  Emotion: {len(emot_data)} samples")
+    print(f"  Topic: {len(topic_data)} samples")
+
+    # --------------- Summarization ---------------
+
+    print("\nSummarization...")
+    preds, refs = [], []
+    for batch in tqdm(list(chunks(summ_data, args.batch_size)), desc="Summarization", unit="batch"):
+        preds.extend(pipeline.summarize([ex.source for ex in batch]))
+        refs.extend([ex.summary for ex in batch])
+
+    rouge = rouge_like(preds, refs)
+    bleu = calculate_bleu(preds, refs)
+    print(f"  ROUGE-like: {rouge:.4f}, BLEU: {bleu:.4f}")
+
+    # --------------- Emotion ---------------
+
+    print("\nEmotion Classification...")
+    binarizer = MultiLabelBinarizer(classes=metadata.emotion)
+    binarizer.fit([[label] for label in metadata.emotion])
+    label_idx = {label: i for i, label in enumerate(metadata.emotion)}
+
+    pred_vecs, target_vecs = [], []
+    for batch in tqdm(list(chunks(emot_data, args.batch_size)), desc="Emotion", unit="batch"):
+        emotion_results = pipeline.predict_emotions([ex.text for ex in batch], threshold=0.3)
+        targets = binarizer.transform([list(ex.emotions) for ex in batch])
+
+        for pred, target in zip(emotion_results, targets, strict=False):
+            vec = torch.zeros(len(metadata.emotion))
+            for lbl in pred.labels:
+                if lbl in label_idx:
+                    vec[label_idx[lbl]] = 1.0
+            pred_vecs.append(vec)
+            target_vecs.append(torch.tensor(target, dtype=torch.float32))
+
+    emotion_f1 = multilabel_f1(torch.stack(pred_vecs), torch.stack(target_vecs))
+    print(f"  F1 (macro): {emotion_f1:.4f}")
 
-    #
-    print("Evaluating Topic Classification...")
-    topic_preds = []
-    topic_targets = []
-    total_batches = (len(topic_examples) + args.batch_size - 1) // args.batch_size
-    for batch in tqdm(
-        chunks(topic_examples, args.batch_size), total=total_batches, desc="Topic", unit="batch"
-    ):
-        inputs = [example.text for example in batch]
-        topic_predictions = pipeline.predict_topics(inputs)
-        topic_preds.extend([pred.label for pred in topic_predictions])
-        topic_targets.extend([example.topic for example in batch])
-
-    topic_accuracy = accuracy(topic_preds, topic_targets)
-    topic_report = classification_report_dict(topic_preds, topic_targets, labels=metadata.topic)
-    topic_cm = get_confusion_matrix(topic_preds, topic_targets, labels=metadata.topic)
-
-    # Save Confusion Matrix
+    # --------------- Topic ---------------
+
+    print("\nTopic Classification...")
+    topic_pred_labels: List[str] = []
+    topic_true_labels: List[str] = []
+    for batch in tqdm(list(chunks(topic_data, args.batch_size)), desc="Topic", unit="batch"):
+        topic_results = pipeline.predict_topics([ex.text for ex in batch])
+        topic_pred_labels.extend([r.label for r in topic_results])
+        topic_true_labels.extend([ex.topic for ex in batch])
+
+    topic_acc = accuracy(topic_pred_labels, topic_true_labels)
+    topic_report = classification_report_dict(
+        topic_pred_labels, topic_true_labels, labels=metadata.topic
    )
+    topic_cm = get_confusion_matrix(topic_pred_labels, topic_true_labels, labels=metadata.topic)
+    print(f"  Accuracy: {topic_acc:.4f}")
 
+    # Save confusion matrix
     cm_path = output_dir / "topic_confusion_matrix.png"
     plot_confusion_matrix(topic_cm, metadata.topic, cm_path)
-    print(f"Confusion matrix saved
+    print(f"  Confusion matrix saved: {cm_path}")
+
+    # --------------- Save Results ---------------
 
     results = {
         "split": args.split,
-        "summarization": {"rouge_like":
+        "summarization": {"rouge_like": rouge, "bleu": bleu},
         "emotion": {"f1_macro": emotion_f1},
-        "topic": {"accuracy":
+        "topic": {"accuracy": topic_acc, "classification_report": topic_report},
     }
 
     report_path = output_dir / "evaluation_report.json"
-    with open(report_path, "w"
+    with open(report_path, "w") as f:
         json.dump(results, f, indent=2)
 
+    total_time = time.perf_counter() - start_time
+    print(f"\n{'=' * 50}")
+    print(f"Evaluation complete in {total_time:.1f}s")
+    print(f"Report saved: {report_path}")
+    print(f"{'=' * 50}")
     print(json.dumps(results, indent=2))
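
The emotion task above is multi-label, so predictions and targets are binary indicator vectors and macro F1 averages per-class F1 scores. A small sketch of that computation (an illustration of the idea, not the repo's multilabel_f1):

import torch

def macro_f1(preds: torch.Tensor, targets: torch.Tensor, eps: float = 1e-8) -> float:
    # preds, targets: (num_samples, num_classes) binary {0, 1} tensors
    tp = (preds * targets).sum(dim=0)
    precision = tp / (preds.sum(dim=0) + eps)
    recall = tp / (targets.sum(dim=0) + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1.mean().item()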
scripts/export_model.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+Model export script for LexiMind.
+
+Rebuilds the multitask model from configuration and exports trained weights
+for deployment or distribution.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
scripts/export_tokenizer.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+Tokenizer export script for LexiMind.
+
+Saves the FLAN-T5 tokenizer to the artifacts directory for reproducible
+inference without requiring network access.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
scripts/inference.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+Inference script for the LexiMind multitask model.
+
+Command-line interface for running summarization, emotion detection, and topic
+classification on arbitrary text inputs.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
scripts/preprocess_data.py
CHANGED
@@ -1,4 +1,13 @@
-"""
+"""
+Data preprocessing script for LexiMind.
+
+Transforms raw datasets into standardized JSONL splits for training. Handles
+summarization, emotion classification, topic classification, and book paragraph
+extraction with text cleaning.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
scripts/train.py
CHANGED
|
@@ -1,13 +1,20 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
-
import platform
|
| 7 |
import sys
|
| 8 |
-
import
|
| 9 |
from pathlib import Path
|
| 10 |
-
from typing import Any, Dict, Sequence
|
| 11 |
|
| 12 |
import hydra
|
| 13 |
import torch
|
|
@@ -37,8 +44,7 @@ from src.training.utils import set_seed
|
|
| 37 |
from src.utils.io import save_state
|
| 38 |
from src.utils.labels import LabelMetadata, save_label_metadata
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
SPLIT_ALIASES: Dict[str, Sequence[str]] = {
|
| 44 |
"train": ("train",),
|
|
@@ -47,286 +53,214 @@ SPLIT_ALIASES: Dict[str, Sequence[str]] = {
|
|
| 47 |
}
|
| 48 |
|
| 49 |
|
| 50 |
-
def
|
| 51 |
-
splits
|
| 52 |
-
|
| 53 |
-
|
| 54 |
for alias in aliases:
|
| 55 |
-
for
|
| 56 |
-
|
| 57 |
-
if
|
| 58 |
-
splits[
|
| 59 |
-
found = True
|
| 60 |
break
|
| 61 |
-
if
|
| 62 |
break
|
| 63 |
-
if not
|
| 64 |
-
raise FileNotFoundError(f"Missing {
|
| 65 |
return splits
|
| 66 |
|
| 67 |
|
| 68 |
-
def
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
limit = int(max_train)
|
| 76 |
-
if original_len > limit:
|
| 77 |
-
splits["train"] = splits["train"][:limit]
|
| 78 |
-
print(f"Limited 'train' split from {original_len} to {limit} samples")
|
| 79 |
|
| 80 |
-
if max_val is not None and "val" in splits:
|
| 81 |
-
original_len = len(splits["val"])
|
| 82 |
-
limit = int(max_val)
|
| 83 |
-
if original_len > limit:
|
| 84 |
-
splits["val"] = splits["val"][:limit]
|
| 85 |
-
print(f"Limited 'val' split from {original_len} to {limit} samples")
|
| 86 |
|
|
|
|
| 87 |
|
| 88 |
-
def compile_model_safe(model: torch.nn.Module) -> Tuple[Any, str]:
|
| 89 |
-
"""
|
| 90 |
-
Safely compile model with best available backend.
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
# NOTE: The 'inductor' backend causes NaN gradients during backward pass with
|
| 98 |
-
# bfloat16 autocast on the decoder (seq2seq tasks). This is a known issue.
|
| 99 |
-
# Use 'aot_eager' which provides graph optimization without inductor's codegen.
|
| 100 |
-
# See: debug_compile_config.py and test_compile_modes.py for investigation.
|
| 101 |
|
| 102 |
-
|
| 103 |
-
try:
|
| 104 |
-
print("Attempting to compile with 'aot_eager' backend...")
|
| 105 |
-
compiled_model = torch.compile(model, backend="aot_eager")
|
| 106 |
-
print("✓ Successfully compiled with 'aot_eager' backend")
|
| 107 |
-
return cast(torch.nn.Module, compiled_model), "aot_eager"
|
| 108 |
-
except Exception as e:
|
| 109 |
-
warnings.warn(f"aot_eager backend failed: {e}", stacklevel=2)
|
| 110 |
-
|
| 111 |
-
# Fallback: Try other backends (inductor may work for encoder-only tasks)
|
| 112 |
-
backends_to_try = ["eager"]
|
| 113 |
-
if system != "Windows":
|
| 114 |
-
# On Linux, inductor might work for some configurations
|
| 115 |
-
backends_to_try = ["eager", "inductor"]
|
| 116 |
-
|
| 117 |
-
for backend in backends_to_try:
|
| 118 |
-
try:
|
| 119 |
-
print(f"Attempting to compile with '{backend}' backend...")
|
| 120 |
-
compiled_model = torch.compile(model, backend=backend)
|
| 121 |
-
# Trigger a dummy run or just return? torch.compile is lazy.
|
| 122 |
-
# I assume it works if the call succeeds, runtime errors handled later.
|
| 123 |
-
print(f"✓ Successfully compiled with '{backend}' backend")
|
| 124 |
-
return cast(torch.nn.Module, compiled_model), backend
|
| 125 |
-
except Exception as e:
|
| 126 |
-
print(f"✗ '{backend}' backend failed: {e}")
|
| 127 |
-
continue
|
| 128 |
-
|
| 129 |
-
# No compilation worked, return original model
|
| 130 |
-
warnings.warn("All torch.compile backends failed, using uncompiled model", stacklevel=2)
|
| 131 |
-
return model, "none"
|
| 132 |
|
| 133 |
|
| 134 |
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
| 135 |
def main(cfg: DictConfig) -> None:
|
|
|
|
| 136 |
print(OmegaConf.to_yaml(cfg))
|
| 137 |
set_seed(cfg.seed)
|
| 138 |
|
| 139 |
-
# Enable TF32 for Ampere
|
| 140 |
-
# This provides significant speedup on RTX 4070
|
| 141 |
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
| 142 |
-
print("
|
| 143 |
torch.set_float32_matmul_precision("high")
|
| 144 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 145 |
torch.backends.cudnn.allow_tf32 = True
|
| 146 |
-
torch.backends.cudnn.benchmark = True # Auto-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
# Access configs directly from Hydra cfg object
|
| 149 |
data_cfg = cfg.data
|
| 150 |
-
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
num_decoder_layers=cfg.model.num_decoder_layers,
|
| 157 |
-
num_attention_heads=cfg.model.num_attention_heads,
|
| 158 |
-
ffn_dim=cfg.model.ffn_dim,
|
| 159 |
-
dropout=cfg.model.dropout,
|
| 160 |
-
use_pretrained=cfg.model.use_pretrained,
|
| 161 |
-
pretrained_model_name=cfg.model.pretrained_model_name,
|
| 162 |
-
activation=getattr(cfg.model, "activation", "gelu"),
|
| 163 |
-
use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
|
| 164 |
-
)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
summarization_splits = _read_examples(summarization_dir, load_summarization_jsonl)
|
| 171 |
-
emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
|
| 172 |
-
topic_splits = _read_examples(topic_dir, load_topic_jsonl)
|
| 173 |
-
|
| 174 |
-
# Apply sample limits if configured (e.g. for dev/medium runs)
|
| 175 |
-
trainer_cfg = training_cfg.get("trainer", {})
|
| 176 |
-
print("\nApplying dataset limits...")
|
| 177 |
-
_limit_samples(summarization_splits, trainer_cfg)
|
| 178 |
-
_limit_samples(emotion_splits, trainer_cfg)
|
| 179 |
-
_limit_samples(topic_splits, trainer_cfg)
|
| 180 |
-
print("Dataset limits applied.\n")
|
| 181 |
-
|
| 182 |
-
tokenizer_section = data_cfg.get("tokenizer", {})
|
| 183 |
-
tokenizer_config = TokenizerConfig(
|
| 184 |
-
pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
|
| 185 |
-
max_length=int(tokenizer_section.get("max_length", 512)),
|
| 186 |
-
lower=bool(tokenizer_section.get("lower", False)),
|
| 187 |
-
)
|
| 188 |
-
tokenizer = Tokenizer(tokenizer_config)
|
| 189 |
|
| 190 |
-
|
| 191 |
-
summarization_val = SummarizationDataset(summarization_splits["val"])
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
topic_train = TopicDataset(topic_splits["train"])
|
| 197 |
topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
num_workers = int(
|
| 204 |
-
pin_memory = bool(
|
| 205 |
-
|
| 206 |
|
| 207 |
train_loaders = {
|
| 208 |
"summarization": build_summarization_dataloader(
|
| 209 |
-
|
| 210 |
tokenizer,
|
|
|
|
|
|
|
|
|
|
| 211 |
batch_size=batch_size,
|
| 212 |
-
shuffle=shuffle,
|
| 213 |
-
max_source_length=max_length,
|
| 214 |
-
max_target_length=max_length,
|
| 215 |
num_workers=num_workers,
|
| 216 |
pin_memory=pin_memory,
|
| 217 |
),
|
| 218 |
"emotion": build_emotion_dataloader(
|
| 219 |
-
|
| 220 |
tokenizer,
|
|
|
|
|
|
|
| 221 |
batch_size=batch_size,
|
| 222 |
-
shuffle=shuffle,
|
| 223 |
-
max_length=max_length,
|
| 224 |
num_workers=num_workers,
|
| 225 |
pin_memory=pin_memory,
|
| 226 |
),
|
| 227 |
"topic": build_topic_dataloader(
|
| 228 |
topic_train,
|
| 229 |
tokenizer,
|
|
|
|
|
|
|
| 230 |
batch_size=batch_size,
|
| 231 |
-
shuffle=shuffle,
|
| 232 |
-
max_length=max_length,
|
| 233 |
num_workers=num_workers,
|
| 234 |
pin_memory=pin_memory,
|
| 235 |
),
|
| 236 |
}
|
| 237 |
-
|
| 238 |
val_loaders = {
|
| 239 |
"summarization": build_summarization_dataloader(
|
| 240 |
-
|
| 241 |
tokenizer,
|
| 242 |
-
batch_size=batch_size,
|
| 243 |
shuffle=False,
|
| 244 |
-
max_source_length=
|
| 245 |
-
max_target_length=
|
|
|
|
| 246 |
num_workers=num_workers,
|
| 247 |
pin_memory=pin_memory,
|
| 248 |
),
|
| 249 |
"emotion": build_emotion_dataloader(
|
| 250 |
-
|
| 251 |
tokenizer,
|
| 252 |
-
batch_size=batch_size,
|
| 253 |
shuffle=False,
|
| 254 |
-
max_length=
|
|
|
|
| 255 |
num_workers=num_workers,
|
| 256 |
pin_memory=pin_memory,
|
| 257 |
),
|
| 258 |
"topic": build_topic_dataloader(
|
| 259 |
topic_val,
|
| 260 |
tokenizer,
|
| 261 |
-
batch_size=batch_size,
|
| 262 |
shuffle=False,
|
| 263 |
-
max_length=
|
|
|
|
| 264 |
num_workers=num_workers,
|
| 265 |
pin_memory=pin_memory,
|
| 266 |
),
|
| 267 |
}
|
| 268 |
|
|
|
|
|
|
|
|
|
|
| 269 |
device = torch.device(cfg.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
model = build_multitask_model(
|
| 271 |
tokenizer,
|
| 272 |
-
num_emotions=len(
|
| 273 |
num_topics=len(topic_train.topic_classes),
|
| 274 |
config=model_cfg,
|
| 275 |
).to(device)
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
#
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
else:
|
| 293 |
-
backend_used = "disabled"
|
| 294 |
-
if use_compile and model.decoder is not None:
|
| 295 |
-
# Compile decoder.forward but keep step/greedy_decode uncompiled for generation
|
| 296 |
-
model.decoder, _ = compile_model_safe(model.decoder)
|
| 297 |
-
|
| 298 |
-
# Compile heads
|
| 299 |
-
if use_compile:
|
| 300 |
-
for name, head in model.heads.items():
|
| 301 |
-
compiled_head, _ = compile_model_safe(head)
|
| 302 |
-
model.heads[name] = compiled_head
|
| 303 |
-
# Update the registered module as well to ensure parameters are tracked correctly
|
| 304 |
-
setattr(model, f"head_{name}", compiled_head)
|
| 305 |
-
|
| 306 |
-
print(f"Using compilation backend: {backend_used}")
|
| 307 |
-
|
| 308 |
-
# Verify weights loaded correctly (check for NaNs/Infs)
|
| 309 |
-
print("\n=== Weight Loading Verification ===")
|
| 310 |
-
has_issues = False
|
| 311 |
-
for name, param in model.named_parameters():
|
| 312 |
-
if torch.isnan(param).any():
|
| 313 |
-
print(f"WARNING: NaN in {name}")
|
| 314 |
-
has_issues = True
|
| 315 |
-
if torch.isinf(param).any():
|
| 316 |
-
print(f"WARNING: Inf in {name}")
|
| 317 |
-
has_issues = True
|
| 318 |
-
if not has_issues:
|
| 319 |
-
print("✓ No NaNs or Infs found in model parameters.")
|
| 320 |
-
print("=== Verification Complete ===\n")
|
| 321 |
-
|
| 322 |
-
trainer_cfg = training_cfg.get("trainer", {})
|
| 323 |
trainer = Trainer(
|
| 324 |
model=model,
|
| 325 |
optimizer=optimizer,
|
| 326 |
config=TrainerConfig(
|
| 327 |
max_epochs=int(trainer_cfg.get("max_epochs", 1)),
|
| 328 |
gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
|
| 329 |
-
logging_interval=int(trainer_cfg.get("logging_interval", 50)),
|
| 330 |
task_weights=trainer_cfg.get("task_weights"),
|
| 331 |
label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
|
| 332 |
gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
|
|
@@ -335,61 +269,43 @@ def main(cfg: DictConfig) -> None:
|
|
| 335 |
tokenizer=tokenizer,
|
| 336 |
)
|
| 337 |
|
| 338 |
-
#
|
| 339 |
-
|
| 340 |
-
def
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
save_state(model, str(
|
| 344 |
-
print(f"Checkpoint saved: {epoch_path}")
|
| 345 |
|
| 346 |
-
|
|
|
|
| 347 |
|
| 348 |
-
|
| 349 |
-
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
| 350 |
-
save_state(model, str(checkpoint_path))
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
labels_path = Path(cfg.labels_out)
|
| 353 |
save_label_metadata(
|
| 354 |
-
LabelMetadata(
|
| 355 |
-
emotion=emotion_train.emotion_classes,
|
| 356 |
-
topic=topic_train.topic_classes,
|
| 357 |
-
),
|
| 358 |
labels_path,
|
| 359 |
)
|
| 360 |
|
|
|
|
| 361 |
history_path = Path(cfg.history_out)
|
| 362 |
history_path.parent.mkdir(parents=True, exist_ok=True)
|
| 363 |
-
with history_path.open("w"
|
| 364 |
-
json.dump(history,
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
print(f"
|
| 368 |
-
print(f"
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
print("
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
try:
|
| 375 |
-
subprocess.run(
|
| 376 |
-
[
|
| 377 |
-
sys.executable,
|
| 378 |
-
"scripts/evaluate.py",
|
| 379 |
-
"--split",
|
| 380 |
-
"test", # Evaluate on test set
|
| 381 |
-
"--checkpoint",
|
| 382 |
-
str(checkpoint_path),
|
| 383 |
-
"--labels",
|
| 384 |
-
str(labels_path),
|
| 385 |
-
"--output-dir",
|
| 386 |
-
"outputs",
|
| 387 |
-
],
|
| 388 |
-
check=True,
|
| 389 |
-
)
|
| 390 |
-
print("Evaluation pipeline completed successfully.")
|
| 391 |
-
except subprocess.CalledProcessError as e:
|
| 392 |
-
print(f"Evaluation pipeline failed with error: {e}")
|
| 393 |
|
| 394 |
|
| 395 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for LexiMind.
|
| 3 |
+
|
| 4 |
+
Orchestrates dataset loading, model construction, torch.compile optimization,
|
| 5 |
+
and multi-task training with checkpoint management.
|
| 6 |
+
|
| 7 |
+
Author: Oliver Perrin
|
| 8 |
+
Date: December 2025
|
| 9 |
+
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
import json
|
|
|
|
| 14 |
import sys
|
| 15 |
+
import time
|
| 16 |
from pathlib import Path
|
| 17 |
+
from typing import Any, Dict, Sequence
|
| 18 |
|
| 19 |
import hydra
|
| 20 |
import torch
|
|
|
|
| 44 |
from src.utils.io import save_state
|
| 45 |
from src.utils.labels import LabelMetadata, save_label_metadata
|
| 46 |
|
| 47 |
+
# --------------- Data Loading ---------------
|
|
|
|
| 48 |
|
| 49 |
SPLIT_ALIASES: Dict[str, Sequence[str]] = {
|
| 50 |
"train": ("train",),
|
|
|
|
| 53 |
}
|
| 54 |
|
| 55 |
|
| 56 |
+
def load_splits(data_dir: Path, loader) -> Dict[str, list]:
|
| 57 |
+
"""Load train/val/test splits from data directory."""
|
| 58 |
+
splits = {}
|
| 59 |
+
for name, aliases in SPLIT_ALIASES.items():
|
| 60 |
for alias in aliases:
|
| 61 |
+
for ext in ("jsonl", "json"):
|
| 62 |
+
path = data_dir / f"{alias}.{ext}"
|
| 63 |
+
if path.exists():
|
| 64 |
+
splits[name] = loader(str(path))
|
|
|
|
| 65 |
break
|
| 66 |
+
if name in splits:
|
| 67 |
break
|
| 68 |
+
if name not in splits:
|
| 69 |
+
raise FileNotFoundError(f"Missing {name} split in {data_dir}")
|
| 70 |
return splits
|
| 71 |
|
| 72 |
|
| 73 |
+
def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
|
| 74 |
+
"""Apply sample limits for dev/debug runs."""
|
| 75 |
+
for split, key in [("train", "max_train_samples"), ("val", "max_val_samples")]:
|
| 76 |
+
limit = cfg.get(key)
|
| 77 |
+
if limit and split in splits and len(splits[split]) > limit:
|
| 78 |
+
splits[split] = splits[split][: int(limit)]
|
| 79 |
+
print(f" {split}: limited to {limit} samples")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
# --------------- Model Compilation ---------------
|
| 83 |
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
def compile_model(model: torch.nn.Module) -> Any:
|
| 86 |
+
"""Compile model with aot_eager backend (stable, avoids inductor NaN issues)."""
|
| 87 |
+
try:
|
| 88 |
+
compiled = torch.compile(model, backend="aot_eager")
|
| 89 |
+
print("✓ Compiled with aot_eager")
|
| 90 |
+
return compiled
|
| 91 |
+
except Exception:
|
| 92 |
+
return model
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
# --------------- Main ---------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
| 99 |
def main(cfg: DictConfig) -> None:
|
| 100 |
+
start_time = time.perf_counter()
|
| 101 |
print(OmegaConf.to_yaml(cfg))
|
| 102 |
set_seed(cfg.seed)
|
| 103 |
|
| 104 |
+
# Enable TF32 for Ampere+ GPUs (RTX 30xx/40xx) - ~2x matmul speedup
|
|
|
|
| 105 |
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
| 106 |
+
print("✓ TF32 enabled for Ampere GPU")
|
| 107 |
torch.set_float32_matmul_precision("high")
|
| 108 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 109 |
torch.backends.cudnn.allow_tf32 = True
|
| 110 |
+
torch.backends.cudnn.benchmark = True # Auto-tune convolutions
|
| 111 |
+
torch.backends.cuda.enable_flash_sdp(True) # Flash attention if available
|
| 112 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True) # Memory-efficient attention
|
| 113 |
+
|
| 114 |
+
# Disable debug APIs for max speed
|
| 115 |
+
torch.autograd.set_detect_anomaly(False)
|
| 116 |
+
torch.autograd.profiler.profile(False)
|
| 117 |
+
torch.autograd.profiler.emit_nvtx(False)
|
| 118 |
+
|
| 119 |
+
# --------------- Load Data ---------------
|
| 120 |
|
|
|
|
| 121 |
data_cfg = cfg.data
|
| 122 |
+
trainer_cfg = cfg.training.get("trainer", {})
|
| 123 |
|
| 124 |
+
print("\nLoading datasets...")
|
| 125 |
+
summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
|
| 126 |
+
emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
|
| 127 |
+
topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
+
# Apply dev/debug sample limits
|
| 130 |
+
for splits in [summ_splits, emot_splits, topic_splits]:
|
| 131 |
+
limit_samples(splits, trainer_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
# --------------- Tokenizer & Datasets ---------------
|
|
|
|
| 134 |
|
| 135 |
+
tok_cfg = data_cfg.get("tokenizer", {})
|
| 136 |
+
tokenizer = Tokenizer(
|
| 137 |
+
TokenizerConfig(
|
| 138 |
+
pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
|
| 139 |
+
max_length=int(tok_cfg.get("max_length", 512)),
|
| 140 |
+
lower=bool(tok_cfg.get("lower", False)),
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
|
| 144 |
+
summ_train = SummarizationDataset(summ_splits["train"])
|
| 145 |
+
summ_val = SummarizationDataset(summ_splits["val"])
|
| 146 |
+
emot_train = EmotionDataset(emot_splits["train"])
|
| 147 |
+
emot_val = EmotionDataset(emot_splits["val"], binarizer=emot_train.binarizer)
|
| 148 |
topic_train = TopicDataset(topic_splits["train"])
|
| 149 |
topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
|
| 150 |
|
| 151 |
+
# --------------- DataLoaders ---------------
|
| 152 |
+
|
| 153 |
+
dl_cfg = cfg.training.get("dataloader", {})
|
| 154 |
+
batch_size = int(dl_cfg.get("batch_size", 8))
|
| 155 |
+
num_workers = int(dl_cfg.get("num_workers", 4))
|
| 156 |
+
pin_memory = bool(dl_cfg.get("pin_memory", True))
|
| 157 |
+
max_len = tokenizer.config.max_length
|
| 158 |
|
| 159 |
train_loaders = {
|
| 160 |
"summarization": build_summarization_dataloader(
|
| 161 |
+
summ_train,
|
| 162 |
tokenizer,
|
| 163 |
+
shuffle=True,
|
| 164 |
+
max_source_length=max_len,
|
| 165 |
+
max_target_length=max_len,
|
| 166 |
batch_size=batch_size,
|
|
|
|
|
|
|
|
|
|
| 167 |
num_workers=num_workers,
|
| 168 |
pin_memory=pin_memory,
|
| 169 |
),
|
| 170 |
"emotion": build_emotion_dataloader(
|
| 171 |
+
emot_train,
|
| 172 |
tokenizer,
|
| 173 |
+
shuffle=True,
|
| 174 |
+
max_length=max_len,
|
| 175 |
batch_size=batch_size,
|
|
|
|
|
|
|
| 176 |
num_workers=num_workers,
|
| 177 |
pin_memory=pin_memory,
|
| 178 |
),
|
| 179 |
"topic": build_topic_dataloader(
|
| 180 |
topic_train,
|
| 181 |
tokenizer,
|
| 182 |
+
shuffle=True,
|
| 183 |
+
max_length=max_len,
|
| 184 |
batch_size=batch_size,
|
|
|
|
|
|
|
| 185 |
num_workers=num_workers,
|
| 186 |
pin_memory=pin_memory,
|
| 187 |
),
|
| 188 |
}
|
|
|
|
| 189 |
val_loaders = {
|
| 190 |
"summarization": build_summarization_dataloader(
|
| 191 |
+
summ_val,
|
| 192 |
tokenizer,
|
|
|
|
| 193 |
shuffle=False,
|
| 194 |
+
max_source_length=max_len,
|
| 195 |
+
max_target_length=max_len,
|
| 196 |
+
batch_size=batch_size,
|
| 197 |
num_workers=num_workers,
|
| 198 |
pin_memory=pin_memory,
|
| 199 |
),
|
| 200 |
"emotion": build_emotion_dataloader(
|
| 201 |
+
emot_val,
|
| 202 |
tokenizer,
|
|
|
|
| 203 |
shuffle=False,
|
| 204 |
+
max_length=max_len,
|
| 205 |
+
batch_size=batch_size,
|
| 206 |
num_workers=num_workers,
|
| 207 |
pin_memory=pin_memory,
|
| 208 |
),
|
| 209 |
"topic": build_topic_dataloader(
|
| 210 |
topic_val,
|
| 211 |
tokenizer,
|
|
|
|
| 212 |
shuffle=False,
|
| 213 |
+
max_length=max_len,
|
| 214 |
+
batch_size=batch_size,
|
| 215 |
num_workers=num_workers,
|
| 216 |
pin_memory=pin_memory,
|
| 217 |
),
|
| 218 |
}
|
| 219 |
|
| 220 |
+
# --------------- Model ---------------
|
| 221 |
+
|
| 222 |
+
print("\nBuilding model...")
|
| 223 |
device = torch.device(cfg.device)
|
| 224 |
+
model_cfg = ModelConfig(
|
| 225 |
+
d_model=cfg.model.d_model,
|
| 226 |
+
num_encoder_layers=cfg.model.num_encoder_layers,
|
| 227 |
+
num_decoder_layers=cfg.model.num_decoder_layers,
|
| 228 |
+
num_attention_heads=cfg.model.num_attention_heads,
|
scripts/train.py (continued)

+        ffn_dim=cfg.model.ffn_dim,
+        dropout=cfg.model.dropout,
+        use_pretrained=cfg.model.use_pretrained,
+        pretrained_model_name=cfg.model.pretrained_model_name,
+        activation=getattr(cfg.model, "activation", "gelu"),
+        use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
+    )
     model = build_multitask_model(
         tokenizer,
+        num_emotions=len(emot_train.emotion_classes),
         num_topics=len(topic_train.topic_classes),
         config=model_cfg,
     ).to(device)
 
+    # Compile encoder/decoder for faster training (skip heads - small overhead)
+    if model.encoder is not None:
+        model.encoder = compile_model(model.encoder)
+    if model.decoder is not None:
+        model.decoder = compile_model(model.decoder)
+
+    # --------------- Optimizer & Trainer ---------------
+
+    opt_cfg = cfg.training.get("optimizer", {})
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=float(opt_cfg.get("lr", 3e-5)),
+        weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
+    )
+
     trainer = Trainer(
         model=model,
         optimizer=optimizer,
         config=TrainerConfig(
             max_epochs=int(trainer_cfg.get("max_epochs", 1)),
             gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
             task_weights=trainer_cfg.get("task_weights"),
             label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
             gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
         ),
         tokenizer=tokenizer,
     )
 
+    # --------------- Train ---------------
+
+    def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
+        path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
+        path.parent.mkdir(parents=True, exist_ok=True)
+        save_state(model, str(path))
 
+    print("\nStarting training...")
+    history = trainer.fit(train_loaders, val_loaders, checkpoint_callback=save_checkpoint)
 
+    # --------------- Save Outputs ---------------
 
+    # Best checkpoint
+    ckpt_path = Path(cfg.checkpoint_out)
+    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state(model, str(ckpt_path))
+
+    # Labels
     labels_path = Path(cfg.labels_out)
     save_label_metadata(
+        LabelMetadata(emotion=emot_train.emotion_classes, topic=topic_train.topic_classes),
         labels_path,
     )
 
+    # History
     history_path = Path(cfg.history_out)
     history_path.parent.mkdir(parents=True, exist_ok=True)
+    with history_path.open("w") as f:
+        json.dump(history, f, indent=2)
+
+    total_time = time.perf_counter() - start_time
+    print(f"\n{'=' * 50}")
+    print(f"Training complete in {total_time:.1f}s")
+    print(f"  Checkpoint: {ckpt_path}")
+    print(f"  Labels: {labels_path}")
+    print(f"  History: {history_path}")
+    print(f"{'=' * 50}")
 
 
 if __name__ == "__main__":
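The checkpoint callback contract used above is simple: trainer.fit invokes callback(epoch, model, history) once per epoch. A minimal sketch of that interaction (not project code; FakeTrainer is a hypothetical stand-in for the real Trainer):

# Sketch of the per-epoch checkpoint hook that trainer.fit drives.
from pathlib import Path
from typing import Callable, Dict

import torch


class FakeTrainer:  # hypothetical stand-in for src.training.trainer.Trainer
    def __init__(self, model: torch.nn.Module, max_epochs: int) -> None:
        self.model = model
        self.max_epochs = max_epochs

    def fit(self, callback: Callable[[int, torch.nn.Module, Dict], None]) -> Dict:
        history: Dict[str, Dict[str, float]] = {}
        for epoch in range(1, self.max_epochs + 1):
            history[f"train_epoch_{epoch}"] = {"loss": 1.0 / epoch}  # placeholder metrics
            callback(epoch, self.model, history)  # checkpoint hook fires here
        return history


def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
    path = Path("checkpoints") / f"epoch_{epoch}.pt"
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)


FakeTrainer(torch.nn.Linear(4, 2), max_epochs=2).fit(save_checkpoint)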
src/api/app.py
CHANGED
@@ -1,4 +1,11 @@
-"""
+"""
+FastAPI application factory for LexiMind.
+
+Creates and configures the REST API application.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from fastapi import FastAPI
 
src/api/dependencies.py
CHANGED
@@ -1,4 +1,11 @@
-"""
+"""
+FastAPI dependency providers for LexiMind.
+
+Manages lazy initialization and caching of the inference pipeline.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
src/api/routes.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+API routes for LexiMind.
+
+Defines REST endpoints for text analysis including summarization,
+emotion detection, and topic classification.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from typing import cast
 
src/api/schemas.py
CHANGED
@@ -1,4 +1,11 @@
-"""
+"""
+Pydantic schemas for LexiMind API.
+
+Defines request and response models for the REST API.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from pydantic import BaseModel
 
src/data/dataloader.py
CHANGED
@@ -1,8 +1,15 @@
-"""
+"""
+DataLoader builders for LexiMind.
+
+Task-specific collators and factory functions for summarization, emotion, and topic.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
-from typing import List
+from typing import Dict, List
 
 import torch
 from torch.utils.data import DataLoader
@@ -17,9 +24,11 @@ from .dataset import (
 )
 from .tokenization import Tokenizer
 
+# --------------- Collators ---------------
+
 
 class SummarizationCollator:
-    """Prepare encoder-decoder batches for
+    """Prepare encoder-decoder batches for seq2seq summarization."""
 
     def __init__(
         self,
@@ -32,36 +41,24 @@ class SummarizationCollator:
         self.max_source_length = max_source_length
         self.max_target_length = max_target_length
 
-    def __call__(self, batch: List[SummarizationExample]) ->
-        sources = [
-        targets = [
-
-        # labels (target): [A, B, EOS, PAD] (drop first BOS)
-        ids = target_enc["input_ids"]
-        mask = target_enc["attention_mask"]
-
-        # Slice to create shifted inputs/targets
-        # tgt_ids: everything except the last token
+    def __call__(self, batch: List[SummarizationExample]) -> Dict[str, torch.Tensor]:
+        sources = [ex.source for ex in batch]
+        targets = [ex.summary for ex in batch]
+
+        src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
+        tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
+
+        # Shift targets: tgt_ids = [BOS, A, B], labels = [A, B, EOS]
+        ids = tgt_enc["input_ids"]
+        mask = tgt_enc["attention_mask"]
         tgt_ids = ids[:, :-1]
-
-        # labels: everything except the first token (BOS)
         labels = ids[:, 1:].clone()
-
-        # Adjust mask for labels to ignore padding
-        # The mask corresponds to the original ids. We slice it to match labels.
-        labels_mask = mask[:, 1:]
-        labels[labels_mask == 0] = -100
+        labels[mask[:, 1:] == 0] = -100  # Mask padding in loss
 
         return {
-            "src_ids":
-            "src_mask":
+            "src_ids": src_enc["input_ids"],
+            "src_mask": src_enc["attention_mask"],
            "tgt_ids": tgt_ids,
            "labels": labels,
        }
@@ -77,11 +74,13 @@ class EmotionCollator:
         self.binarizer = dataset.binarizer
         self.max_length = max_length
 
-    def __call__(self, batch: List[EmotionExample]) ->
-        texts = [
+    def __call__(self, batch: List[EmotionExample]) -> Dict[str, torch.Tensor]:
+        texts = [ex.text for ex in batch]
         encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
+        labels = torch.as_tensor(
+            self.binarizer.transform([ex.emotions for ex in batch]),
+            dtype=torch.float32,
+        )
         return {
             "input_ids": encoded["input_ids"],
             "attention_mask": encoded["attention_mask"],
@@ -90,7 +89,7 @@ class EmotionCollator:
 
 
 class TopicCollator:
-    """Prepare batches for topic classification
+    """Prepare batches for single-label topic classification."""
 
     def __init__(
         self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
@@ -99,11 +98,12 @@ class TopicCollator:
         self.encoder = dataset.encoder
         self.max_length = max_length
 
-    def __call__(self, batch: List[TopicExample]) ->
-        texts = [
+    def __call__(self, batch: List[TopicExample]) -> Dict[str, torch.Tensor]:
+        texts = [ex.text for ex in batch]
         encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
         labels = torch.as_tensor(
-            self.encoder.transform([
+            self.encoder.transform([ex.topic for ex in batch]),
+            dtype=torch.long,
         )
         return {
             "input_ids": encoded["input_ids"],
@@ -112,6 +112,9 @@ class TopicCollator:
         }
 
 
+# --------------- Factory Functions ---------------
+
+
 def build_summarization_dataloader(
     dataset: SummarizationDataset,
     tokenizer: Tokenizer,
@@ -123,6 +126,7 @@ def build_summarization_dataloader(
     num_workers: int = 0,
     pin_memory: bool = False,
 ) -> DataLoader:
+    """Create dataloader for summarization task."""
     collator = SummarizationCollator(
         tokenizer,
         max_source_length=max_source_length,
@@ -135,6 +139,7 @@ def build_summarization_dataloader(
         collate_fn=collator,
         num_workers=num_workers,
         pin_memory=pin_memory,
+        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
     )
 
 
@@ -148,6 +153,7 @@ def build_emotion_dataloader(
     num_workers: int = 0,
     pin_memory: bool = False,
 ) -> DataLoader:
+    """Create dataloader for emotion classification task."""
     collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
     return DataLoader(
         dataset,
@@ -156,6 +162,7 @@ def build_emotion_dataloader(
         collate_fn=collator,
         num_workers=num_workers,
         pin_memory=pin_memory,
+        persistent_workers=num_workers > 0,
     )
 
 
@@ -169,6 +176,7 @@ def build_topic_dataloader(
     num_workers: int = 0,
     pin_memory: bool = False,
 ) -> DataLoader:
+    """Create dataloader for topic classification task."""
     collator = TopicCollator(tokenizer, dataset, max_length=max_length)
     return DataLoader(
         dataset,
@@ -177,4 +185,5 @@ def build_topic_dataloader(
         collate_fn=collator,
         num_workers=num_workers,
         pin_memory=pin_memory,
+        persistent_workers=num_workers > 0,
    )
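The target-shifting logic in SummarizationCollator is easy to sanity-check on a toy tensor (a standalone sketch, not project code; the token ids are made up):

# Demonstrates the shift used above: tgt_ids drops the last token,
# labels drop the BOS and mask padding with -100.
import torch

BOS, A, B, EOS, PAD = 0, 11, 12, 1, 2
ids = torch.tensor([[BOS, A, B, EOS, PAD]])
mask = torch.tensor([[1, 1, 1, 1, 0]])

tgt_ids = ids[:, :-1]            # [[BOS, A, B, EOS]] fed to the decoder
labels = ids[:, 1:].clone()      # [[A, B, EOS, PAD]] predicted at each step
labels[mask[:, 1:] == 0] = -100  # padding positions ignored by the loss

print(tgt_ids.tolist())  # [[0, 11, 12, 1]]
print(labels.tolist())   # [[11, 12, 1, -100]]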
src/data/dataset.py
CHANGED
@@ -1,4 +1,13 @@
-"""
+"""
+Dataset definitions for the LexiMind multitask training pipeline.
+
+Defines PyTorch Dataset classes and data loading utilities for summarization,
+emotion classification, and topic classification tasks. Supports both JSON
+array and JSONL file formats.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
src/data/preprocessing.py
CHANGED
@@ -1,52 +1,64 @@
-"""
+"""
+Text preprocessing for LexiMind.
+
+Lightweight text cleaning and tokenization pipeline for model input preparation.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
 from dataclasses import dataclass, replace
-from typing import
+from typing import List, Sequence
 
 import torch
-from sklearn.base import BaseEstimator, TransformerMixin
 
 from .tokenization import Tokenizer, TokenizerConfig
 
+# --------------- Text Cleaning ---------------
 
-class BasicTextCleaner(BaseEstimator, TransformerMixin):
-    """Minimal text cleaner following scikit-learn conventions."""
 
+class TextCleaner:
+    """Basic text normalization."""
+
+    def __init__(self, lowercase: bool = True) -> None:
         self.lowercase = lowercase
-        self.strip = strip
 
-    def
-        if self.lowercase:
-            item = item.lower()
-        return " ".join(item.split())
+    def clean(self, text: str) -> str:
+        """Strip, normalize whitespace, optionally lowercase."""
+        text = text.strip()
+        if self.lowercase:
+            text = text.lower()
+        return " ".join(text.split())
+
+    def clean_batch(self, texts: Sequence[str]) -> List[str]:
+        """Clean multiple texts."""
+        return [self.clean(t) for t in texts]
+
+    # Backwards compatibility alias
+    def transform(self, texts: Sequence[str]) -> List[str]:
+        """Alias for clean_batch (sklearn-style interface)."""
+        return self.clean_batch(texts)
+
+
+# --------------- Batch Output ---------------
 
 
 @dataclass
 class Batch:
-    """
+    """Tokenized batch ready for model consumption."""
 
     input_ids: torch.Tensor
     attention_mask: torch.Tensor
     lengths: List[int]
 
 
-    """Coordinate lightweight text cleaning and tokenization.
-    """
+# --------------- Preprocessor ---------------
+
+
+class TextPreprocessor:
+    """Combines text cleaning with tokenization."""
 
     def __init__(
         self,
@@ -56,19 +68,10 @@ class TextPreprocessor:
         tokenizer_name: str = "google/flan-t5-base",
         max_length: int | None = None,
         lowercase: bool = True,
-        remove_stopwords: bool = False,
-        sklearn_transformer: TransformerMixin | None = None,
     ) -> None:
-        self.cleaner =
-        self.lowercase = lowercase
-        if remove_stopwords:
-            raise ValueError(
-                "Stop-word removal is not supported because it conflicts with subword tokenizers; "
-                "clean the text externally before initializing TextPreprocessor."
-            )
-        self._stop_words = None
-        self._sklearn_transformer = sklearn_transformer
+        self.cleaner = TextCleaner(lowercase=lowercase)
 
+        # Initialize or validate tokenizer
         if tokenizer is None:
             cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name)
             if max_length is not None:
@@ -78,52 +81,33 @@ class TextPreprocessor:
             self.tokenizer = tokenizer
             if max_length is not None and max_length != tokenizer.config.max_length:
                 raise ValueError(
-                    "
-                    "
+                    "max_length conflicts with tokenizer config - "
+                    "initialize tokenizer with desired settings"
                 )
 
         self.max_length = max_length or self.tokenizer.config.max_length
 
     def clean_text(self, text: str) -> str:
-        return self.
-
-    def _normalize_tokens(self, text: str) -> str:
-        """Apply token-level normalization and optional stop-word filtering."""
-        # Note: Pre-tokenization word-splitting is incompatible with subword tokenizers.
-        # Stop-word filtering should be done post-tokenization or not at all for transformers.
-        return text
-
-    def _apply_sklearn_transform(self, texts: List[str]) -> List[str]:
-        if self._sklearn_transformer is None:
-            return texts
-
-        transform = getattr(self._sklearn_transformer, "transform", None)
-        if transform is None:
-            raise AttributeError("Provided sklearn transformer must implement a 'transform' method")
-        transformed = transform(texts)
-        if isinstance(transformed, list):
-            return transformed  # assume downstream type is already list[str]
-        if hasattr(transformed, "tolist"):
-            transformed = transformed.tolist()
-
-        result = list(transformed)
-        if not all(isinstance(item, str) for item in result):
-            result = [str(item) for item in result]
-        return result
-
-    def _prepare_texts(self, texts: Sequence[str]) -> List[str]:
-        cleaned = self.cleaner.transform(texts)
-        normalized = [self._normalize_tokens(text) for text in cleaned]
-        return self._apply_sklearn_transform(normalized)
+        """Clean a single text."""
+        return self.cleaner.clean(text)
 
     def batch_encode(self, texts: Sequence[str]) -> Batch:
+        """Clean and tokenize texts into a batch."""
+        cleaned = self.cleaner.clean_batch(texts)
         encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length)
+
+        input_ids = encoded["input_ids"]
+        attention_mask = encoded["attention_mask"].to(dtype=torch.bool)
         lengths = attention_mask.sum(dim=1).tolist()
+
         return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
 
     def __call__(self, texts: Sequence[str]) -> Batch:
+        """Alias for batch_encode."""
         return self.batch_encode(texts)
+
+
+# --------------- Backwards Compatibility ---------------
+
+# Keep old name for any imports
+BasicTextCleaner = TextCleaner
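The cleaner's behavior is easy to verify in isolation (usage sketch; the import path assumes the package layout above):

from src.data.preprocessing import TextCleaner

cleaner = TextCleaner(lowercase=True)
print(cleaner.clean("  Hello   WORLD \n"))              # "hello world"
print(cleaner.clean_batch(["A  B", " C "]))             # ["a b", "c"]
print(cleaner.transform(["Alias", "for clean_batch"]))  # sklearn-style alias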
src/data/tokenization.py
CHANGED
@@ -1,4 +1,13 @@
-"""
+"""
+Tokenizer facade for LexiMind.
+
+Wraps HuggingFace tokenizers with a simplified interface that handles
+special token management, batch encoding, and T5-specific conventions
+for decoder input preparation.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
src/inference/factory.py
CHANGED
@@ -1,4 +1,12 @@
-"""
+"""
+Inference pipeline factory for LexiMind.
+
+Assembles a complete inference pipeline from saved checkpoints, tokenizer
+artifacts, and label metadata. Handles model loading and configuration.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
src/inference/pipeline.py
CHANGED
@@ -1,9 +1,17 @@
-"""
+"""
+Inference pipeline for LexiMind.
+
+Unified interface for summarization, emotion detection, and topic classification
+with batched processing and device management.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
 from dataclasses import dataclass, fields, replace
-from typing import Any,
+from typing import Any, Dict, List, Sequence, cast
 
 import torch
 import torch.nn.functional as F
@@ -11,10 +19,12 @@ import torch.nn.functional as F
 from ..data.preprocessing import Batch, TextPreprocessor
 from ..data.tokenization import Tokenizer
 
+# --------------- Configuration ---------------
+
 
 @dataclass
 class InferenceConfig:
-    """
+    """Pipeline settings."""
 
     summary_max_length: int = 128
     emotion_threshold: float = 0.5
@@ -33,8 +43,11 @@ class TopicPrediction:
     confidence: float
 
 
+# --------------- Pipeline ---------------
+
+
 class InferencePipeline:
-    """
+    """Multi-task inference with batched processing."""
 
     def __init__(
         self,
@@ -50,50 +63,49 @@ class InferencePipeline:
         self.model = model
         self.tokenizer = tokenizer
         self.config = config or InferenceConfig()
+
+        # Resolve device
+        chosen = device or self.config.device
+        if chosen is None:
+            param = next(model.parameters(), None)
+            chosen = param.device if param else "cpu"
+        self.device = torch.device(chosen)
+
         self.model.to(self.device)
         self.model.eval()
 
         self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
-        self.emotion_labels = list(emotion_labels) if emotion_labels
-        self.topic_labels = list(topic_labels) if topic_labels
+        self.emotion_labels = list(emotion_labels) if emotion_labels else None
+        self.topic_labels = list(topic_labels) if topic_labels else None
+
+    # --------------- Summarization ---------------
 
     def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
+        """Generate summaries for input texts."""
         if not texts:
             return []
+
+        batch = self._to_device(self.preprocessor.batch_encode(texts))
         src_ids = batch.input_ids
         src_mask = batch.attention_mask
         max_len = max_length or self.config.summary_max_length
 
-        if not hasattr(self.model, "encoder") or not hasattr(self.model, "decoder"):
-            raise RuntimeError(
-                "Model must expose encoder and decoder attributes for summarization."
-            )
-
-        # Cast to Any to allow access to dynamic attributes encoder and decoder
         model = cast(Any, self.model)
+        if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
+            raise RuntimeError("Model must have encoder and decoder for summarization")
 
         with torch.inference_mode():
-            memory = model.encoder(src_ids, mask=
-            min_len = 10
-
-            unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
-            if isinstance(unk_id, int):
-                ban_token_ids.append(unk_id)
-            ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
+            # Encode
+            enc_mask = (
                 src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
             )
+            memory = model.encoder(src_ids, mask=enc_mask)
+
+            # Decode with constraints to improve quality
+            ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
+            unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
+            if isinstance(unk, int):
+                ban_ids.append(unk)
 
             generated = model.decoder.greedy_decode(
                 memory=memory,
@@ -101,16 +113,15 @@ class InferencePipeline:
                 start_token_id=self.tokenizer.bos_token_id,
                 end_token_id=self.tokenizer.eos_token_id,
                 device=self.device,
-                min_len=
-                ban_token_ids=
+                min_len=10,
+                ban_token_ids=[i for i in ban_ids if i is not None],
                 no_repeat_ngram_size=3,
                 memory_mask=src_mask,
             )
 
-        final_summaries = decoded_list
+        return self.tokenizer.decode_batch(generated.tolist())
 
+    # --------------- Emotion ---------------
 
     def predict_emotions(
         self,
@@ -118,78 +129,91 @@ class InferencePipeline:
        *,
        threshold: float | None = None,
    ) -> List[EmotionPrediction]:
+        """Predict emotions for input texts."""
        if not texts:
            return []
-        if
-            raise RuntimeError("emotion_labels
+        if not self.emotion_labels:
+            raise RuntimeError("emotion_labels required for emotion prediction")
 
-        batch = self.
+        batch = self._to_device(self.preprocessor.batch_encode(texts))
+        inputs = self._model_inputs(batch)
+        thresh = threshold or self.config.emotion_threshold
 
        with torch.inference_mode():
-            logits = self.model.forward("emotion",
+            logits = self.model.forward("emotion", inputs)
            probs = torch.sigmoid(logits)
 
+        results = []
        for row in probs.cpu():
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
-                if score >=
+                if score >= thresh
            ]
+            results.append(
+                EmotionPrediction(
+                    labels=[label for label, _ in pairs],
+                    scores=[score for _, score in pairs],
+                )
+            )
+        return results
+
+    # --------------- Topic ---------------
 
    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
+        """Predict topic for input texts."""
        if not texts:
            return []
-        if
-            raise RuntimeError("topic_labels
+        if not self.topic_labels:
+            raise RuntimeError("topic_labels required for topic prediction")
 
-        batch = self.
+        batch = self._to_device(self.preprocessor.batch_encode(texts))
+        inputs = self._model_inputs(batch)
 
        with torch.inference_mode():
-            logits = self.model.forward("topic",
+            logits = self.model.forward("topic", inputs)
            probs = F.softmax(logits, dim=-1)
 
-        results
+        results = []
        for row in probs.cpu():
-            best_index = int(row.argmax().item())
+            idx = int(row.argmax().item())
            results.append(
-                TopicPrediction(
+                TopicPrediction(
+                    label=self.topic_labels[idx],
+                    confidence=row[idx].item(),
+                )
            )
        return results
 
+    # --------------- Batch Prediction ---------------
+
+    def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
+        """Run all three tasks on input texts."""
+        if not self.emotion_labels or not self.topic_labels:
+            raise RuntimeError("Both emotion_labels and topic_labels required")
+
        text_list = list(texts)
-        if self.emotion_labels is None or not self.emotion_labels:
-            raise RuntimeError("emotion_labels must be provided for batch predictions")
-        if self.topic_labels is None or not self.topic_labels:
-            raise RuntimeError("topic_labels must be provided for batch predictions")
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }
 
+    # --------------- Helpers ---------------
+
+    def _to_device(self, batch: Batch) -> Batch:
+        """Move batch tensors to device with non_blocking for speed."""
+        updates = {}
+        for f in fields(batch):
+            val = getattr(batch, f.name)
+            if torch.is_tensor(val):
+                updates[f.name] = val.to(self.device, non_blocking=True)
+        return replace(batch, **updates) if updates else batch
 
    @staticmethod
-    def
-        inputs
+    def _model_inputs(batch: Batch) -> Dict[str, torch.Tensor]:
+        """Extract model inputs from batch."""
+        inputs = {"input_ids": batch.input_ids}
        if batch.attention_mask is not None:
            inputs["attention_mask"] = batch.attention_mask
        return inputs
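The _to_device helper relies on dataclasses.fields/replace, which works on any dataclass holding tensors; a self-contained sketch of the same pattern (ToyBatch is a stand-in, not the project's Batch class):

import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class ToyBatch:  # stand-in for src.data.preprocessing.Batch
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    lengths: List[int]


def to_device(batch: ToyBatch, device: torch.device) -> ToyBatch:
    # Copy only tensor fields to the target device; leave other fields untouched.
    updates = {
        f.name: getattr(batch, f.name).to(device, non_blocking=True)
        for f in dataclasses.fields(batch)
        if torch.is_tensor(getattr(batch, f.name))
    }
    return dataclasses.replace(batch, **updates) if updates else batch


b = ToyBatch(torch.ones(1, 4, dtype=torch.long), torch.ones(1, 4, dtype=torch.bool), [4])
print(to_device(b, torch.device("cpu")).input_ids.device)  # cpu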
src/inference/postprocessing.py
CHANGED
@@ -1,4 +1,11 @@
-"""
+"""
+Output postprocessing utilities for LexiMind.
+
+Provides text cleaning helpers for model outputs.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from typing import List
 
src/models/decoder.py
CHANGED
@@ -1,16 +1,17 @@
-"""
-
- - Masks are boolean: True =
-
+"""Transformer Decoder implementation (Pre-LN).
+
+This module implements the decoder component of the Transformer architecture:
+- create_causal_mask: Generate causal attention masks
+- TransformerDecoderLayer: Single decoder block with self-attn + cross-attn + FFN
+- TransformerDecoder: Full stack with embeddings, positional encoding, and generation
+
+Design notes:
+- Pre-LN with RMSNorm for training stability
+- Masks are boolean: True = attend, False = mask
+- Supports T5-style relative position bias
+
+Author: Oliver Perrin
+Date: 2025-10-23
 """
 
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
src/models/encoder.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
- Optionally collect attention weights by passing collect_attn=True to forward().
|
| 15 |
"""
|
| 16 |
|
| 17 |
from typing import List, Literal, Optional, Tuple, Union
|
|
@@ -213,9 +212,9 @@ class TransformerEncoder(nn.Module):
|
|
| 213 |
Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
|
| 214 |
True indicates valid positions; False indicates masked (pad).
|
| 215 |
"""
|
| 216 |
-
assert (
|
| 217 |
-
|
| 218 |
-
)
|
| 219 |
# mask shape: (batch, seq) where True = token kept (non-pad)
|
| 220 |
pad_mask = input_ids != self.pad_token_id
|
| 221 |
# Convert to (batch, seq_q, seq_k) by outer product broadcasting
|
|
|
|
| 1 |
+
"""Transformer Encoder implementation (Pre-LN).
|
| 2 |
+
|
| 3 |
+
This module implements the encoder component of the Transformer architecture:
|
| 4 |
+
- TransformerEncoderLayer: Single encoder block with self-attention + FFN
|
| 5 |
+
- TransformerEncoder: Full stack with embeddings and positional encoding
|
| 6 |
+
|
| 7 |
+
Design notes:
|
| 8 |
+
- Pre-LN with RMSNorm for training stability
|
| 9 |
+
- Masks are boolean: True = attend, False = mask
|
| 10 |
+
- Supports T5-style relative position bias
|
| 11 |
+
|
| 12 |
+
Author: Oliver Perrin
|
| 13 |
+
Date: 2025-10-23
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
from typing import List, Literal, Optional, Tuple, Union
|
|
|
|
| 212 |
Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
|
| 213 |
True indicates valid positions; False indicates masked (pad).
|
| 214 |
"""
|
| 215 |
+
assert self.pad_token_id is not None, (
|
| 216 |
+
"pad_token_id must be set to build padding mask from ids."
|
| 217 |
+
)
|
| 218 |
# mask shape: (batch, seq) where True = token kept (non-pad)
|
| 219 |
pad_mask = input_ids != self.pad_token_id
|
| 220 |
# Convert to (batch, seq_q, seq_k) by outer product broadcasting
|
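The outer-product broadcast mentioned in that comment turns a (batch, seq) pad mask into (batch, seq_q, seq_k); a toy demonstration (standalone sketch; pad_token_id assumed to be 0):

import torch

input_ids = torch.tensor([[5, 7, 0]])  # last position is padding
pad_mask = input_ids != 0              # (batch, seq) -> [[True, True, False]]
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (batch, seq, seq)
print(attn_mask[0])
# tensor([[ True,  True, False],
#         [ True,  True, False],
#         [False, False, False]])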
src/models/factory.py
CHANGED
@@ -1,4 +1,14 @@
-"""Factory helpers to assemble multitask models
+"""Factory helpers to assemble multitask models.
+
+This module provides model construction and weight loading utilities:
+- ModelConfig: Dataclass for architecture hyperparameters
+- load_model_config: Load configuration from YAML
+- build_multitask_model: Construct full model with task heads
+- Weight loading: Transfer pretrained T5/FLAN-T5 or LLaMA weights
+
+Author: Oliver Perrin
+Date: 2025-10-23
+"""
 
 from __future__ import annotations
 
CHANGED
|
@@ -1,5 +1,11 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from typing import Literal, Optional
|
|
|
|
| 1 |
+
"""Position-wise Feed-Forward Network.
|
| 2 |
+
|
| 3 |
+
This module implements the FFN sublayer used in Transformer blocks:
|
| 4 |
+
- Standard FFN: Two linear layers with activation (GELU/ReLU)
|
| 5 |
+
- Gated FFN: SwiGLU (LLaMA-style) or Gated-GELU (T5/FLAN-T5 style)
|
| 6 |
+
|
| 7 |
+
Author: Oliver Perrin
|
| 8 |
+
Date: 2025-10-23
|
| 9 |
"""
|
| 10 |
|
| 11 |
from typing import Literal, Optional
|
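As a reference point for the gated variants named above, SwiGLU computes silu(x @ W_gate) * (x @ W_up) before the down-projection; a minimal sketch (dimensions illustrative, not this module's API):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    def __init__(self, d_model: int, ffn_dim: int) -> None:
        super().__init__()
        self.gate = nn.Linear(d_model, ffn_dim, bias=False)
        self.up = nn.Linear(d_model, ffn_dim, bias=False)
        self.down = nn.Linear(ffn_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated activation: the SiLU-activated gate modulates the up projection.
        return self.down(F.silu(self.gate(x)) * self.up(x))


print(SwiGLU(8, 32)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])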
src/models/heads.py
CHANGED
@@ -1,13 +1,13 @@
-"""
-Prediction heads for Transformer models.
+"""Prediction heads for Transformer models.
 
- - ClassificationHead:
- - TokenClassificationHead:
- - LMHead:
- - ProjectionHead:
-
+This module provides task-specific output heads:
+- ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
+- TokenClassificationHead: Per-token classification (NER, POS tagging)
+- LMHead: Language modeling logits with optional weight tying
+- ProjectionHead: MLP for representation learning / contrastive tasks
+
+Author: Oliver Perrin
+Date: 2025-10-23
 """
 
 from typing import Literal, Optional
@@ -117,12 +117,12 @@ class LMHead(nn.Module):
 
         if tie_embedding is not None:
             # Validate sizes
-            assert (
-            )
-            assert (
-            )
+            assert tie_embedding.num_embeddings == vocab_size, (
+                "vocab size mismatch for weight tying"
+            )
+            assert tie_embedding.embedding_dim == d_model, (
+                "embedding dim must match d_model for weight tying"
+            )
             # Tie weights: point the projection weight to the embedding weight Tensor
             # Remove the existing projection parameter in favor of the embedding weight
             # This keeps the same Parameter object, so updates affect both modules.
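Weight tying as described in those comments boils down to pointing the output projection at the embedding's Parameter; a minimal sketch of the mechanism, not the LMHead implementation itself:

import torch.nn as nn

vocab_size, d_model = 100, 16
embedding = nn.Embedding(vocab_size, d_model)
proj = nn.Linear(d_model, vocab_size, bias=False)
proj.weight = embedding.weight  # same Parameter object shared by both modules

assert proj.weight.data_ptr() == embedding.weight.data_ptr()
# A gradient step through either module updates the one shared tensor.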
src/models/multitask.py
CHANGED
@@ -1,18 +1,12 @@
-"""
-
-Design goals:
-- Keep composition simple and explicit (use named heads per task)
-- Support encoder-only tasks (classification, token classification) and
-  seq2seq tasks (encoder -> decoder -> LMHead)
-- Minimal dependencies on training loop; return logits and (optionally) loss
+"""Multitask model composition utilities.
+
+This module provides infrastructure for multi-task learning:
+- MultiTaskModel: Compose encoder/decoder with multiple task heads
+- Routing: forward(task_name, ...) dispatches to correct components
+- Loss computation: Built-in cross-entropy with ignore_index support
+
+Author: Oliver Perrin
+Date: 2025-10-23
 """
 
 from typing import Any, Dict, Optional
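The routing idea (forward(task_name, ...) picks a head) can be sketched in a few lines; this illustrates the pattern only, not the project's MultiTaskModel:

import torch
import torch.nn as nn


class TinyMultiTask(nn.Module):
    def __init__(self, d_model: int, num_emotions: int, num_topics: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(d_model, d_model)  # stand-in encoder
        self.heads = nn.ModuleDict({
            "emotion": nn.Linear(d_model, num_emotions),
            "topic": nn.Linear(d_model, num_topics),
        })

    def forward(self, task: str, x: torch.Tensor) -> torch.Tensor:
        # Shared encoder, task-specific head chosen by name.
        return self.heads[task](self.encoder(x))


m = TinyMultiTask(8, num_emotions=28, num_topics=10)
print(m("topic", torch.randn(2, 8)).shape)  # torch.Size([2, 10])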
src/models/positional_encoding.py
CHANGED
@@ -1,10 +1,12 @@
-# src/models/positional_encoding.py
-
 """
 Positional Encoding for Transformer models.
 
-
+Provides sinusoidal position embeddings that inject sequential order information
+into token representations. Required because self-attention is permutation-invariant
+and has no inherent notion of token position.
+
+Author: Oliver Perrin
+Date: December 2025
 """
 
 import math
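The sinusoidal scheme referenced here follows the standard PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); a compact sketch (standalone, assumes even d_model):

import math

import torch


def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    pos = torch.arange(max_len).unsqueeze(1).float()
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
    return pe


print(sinusoidal_pe(50, 16).shape)  # torch.Size([50, 16])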
src/training/metrics.py
CHANGED
@@ -1,4 +1,13 @@
-"""
+"""
+Training and evaluation metrics for LexiMind.
+
+Provides metric computation utilities for all task types: accuracy for topic
+classification, multi-label F1 for emotion detection, and ROUGE/BLEU for
+summarization quality assessment.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
 
 from __future__ import annotations
 
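Micro-averaged multi-label F1, the emotion metric named above, counts TP/FP/FN over all label slots at once; a minimal reference implementation (standalone sketch, thresholding at 0.5, not the module's multilabel_f1):

import torch


def multilabel_f1_micro(probs: torch.Tensor, targets: torch.Tensor, thresh: float = 0.5) -> float:
    preds = (probs >= thresh).float()
    tp = (preds * targets).sum()
    fp = (preds * (1 - targets)).sum()
    fn = ((1 - preds) * targets).sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return (2 * precision * recall / (precision + recall + 1e-8)).item()


probs = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.4]])
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
print(round(multilabel_f1_micro(probs, targets), 3))  # tp=3, fp=0, fn=1 -> 0.857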
src/training/trainer.py
CHANGED
|
@@ -1,38 +1,50 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
import shutil
|
| 6 |
import time
|
| 7 |
from collections import defaultdict
|
| 8 |
from dataclasses import dataclass
|
| 9 |
-
from typing import Callable, Dict,
|
| 10 |
|
| 11 |
import mlflow
|
| 12 |
import torch
|
| 13 |
import torch.nn.functional as F
|
| 14 |
from torch.utils.data import DataLoader
|
|
|
|
| 15 |
|
| 16 |
from ..data.tokenization import Tokenizer
|
| 17 |
from .metrics import accuracy, multilabel_f1, rouge_like
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
@dataclass
|
| 21 |
class TrainerConfig:
|
|
|
|
|
|
|
| 22 |
max_epochs: int = 1
|
| 23 |
gradient_clip_norm: float = 1.0
|
| 24 |
-
logging_interval: int = 50
|
| 25 |
task_weights: Dict[str, float] | None = None
|
| 26 |
validation_samples: int = 3
|
| 27 |
validation_max_length: int = 128
|
| 28 |
-
label_smoothing: float = 0.0
|
| 29 |
experiment_name: str = "LexiMind"
|
| 30 |
run_name: str | None = None
|
| 31 |
gradient_accumulation_steps: int = 1
|
| 32 |
|
| 33 |
|
|
|
|
| 34 |
class Trainer:
|
| 35 |
-
"""
|
| 36 |
|
| 37 |
def __init__(
|
| 38 |
self,
|
|
@@ -47,392 +59,315 @@ class Trainer:
|
|
| 47 |
self.config = config
|
| 48 |
self.device = device
|
| 49 |
self.tokenizer = tokenizer
|
|
|
|
|
|
|
| 50 |
self.emotion_loss = torch.nn.BCEWithLogitsLoss()
|
| 51 |
self.topic_loss = torch.nn.CrossEntropyLoss()
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
self.
|
| 55 |
-
self.
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# Initialize GradScaler for float16/bfloat16 training
|
| 60 |
-
# This scales gradients to prevent underflow during backward pass
|
| 61 |
-
# Note: bfloat16 generally doesn't need scaling, but we keep it for safety unless it causes NaNs
|
| 62 |
-
self.scaler = torch.GradScaler("cuda", enabled=(device.type == "cuda"))
|
| 63 |
-
|
| 64 |
-
# Initialize MLflow
|
| 65 |
mlflow.set_experiment(config.experiment_name)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def fit(
|
| 68 |
self,
|
| 69 |
train_loaders: Dict[str, DataLoader],
|
| 70 |
val_loaders: Dict[str, DataLoader] | None = None,
|
| 71 |
checkpoint_callback: Callable | None = None,
|
| 72 |
) -> Dict[str, Dict[str, float]]:
|
| 73 |
-
"""Train
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
train_loaders: Task-specific training dataloaders
|
| 77 |
-
val_loaders: Optional task-specific validation dataloaders
|
| 78 |
-
checkpoint_callback: Optional callback(epoch, model, history) to save checkpoints
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
Training history dictionary
|
| 82 |
-
"""
|
| 83 |
history: Dict[str, Dict[str, float]] = {}
|
| 84 |
-
|
| 85 |
-
start_time = time.perf_counter()
|
| 86 |
|
| 87 |
with mlflow.start_run(run_name=self.config.run_name):
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
}
|
| 97 |
)
|
| 98 |
|
| 99 |
-
for epoch in
|
| 100 |
epoch_start = time.perf_counter()
|
| 101 |
-
train_metrics = self._run_epoch(
|
| 102 |
-
train_loaders,
|
| 103 |
-
train=True,
|
| 104 |
-
epoch=epoch,
|
| 105 |
-
total_epochs=total_epochs,
|
| 106 |
-
epoch_start=epoch_start,
|
| 107 |
-
global_start=start_time,
|
| 108 |
-
)
|
| 109 |
-
history[f"train_epoch_{epoch}"] = train_metrics
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
|
|
|
|
| 116 |
if val_loaders:
|
| 117 |
val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
|
| 118 |
history[f"val_epoch_{epoch}"] = val_metrics
|
|
|
|
| 119 |
|
| 120 |
-
# Log validation metrics to MLflow
|
| 121 |
-
for k, v in val_metrics.items():
|
| 122 |
-
if k != "epoch":
|
| 123 |
-
mlflow.log_metric(f"val_{k}", v, step=epoch)
|
| 124 |
-
|
| 125 |
-
# Generate sample summaries for manual quality assessment
|
| 126 |
if "summarization" in val_loaders:
|
| 127 |
self._validate_generation(val_loaders["summarization"], epoch)
|
| 128 |
|
| 129 |
-
#
|
| 130 |
-
if checkpoint_callback
|
| 131 |
checkpoint_callback(epoch, self.model, history)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
|
|
|
|
|
|
| 137 |
return history
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
def _run_epoch(
|
| 140 |
self,
|
| 141 |
loaders: Dict[str, DataLoader],
|
| 142 |
*,
|
| 143 |
train: bool,
|
| 144 |
epoch: int,
|
| 145 |
-
total_epochs: int | None = None,
|
| 146 |
-
epoch_start: float | None = None,
|
| 147 |
-
global_start: float | None = None,
|
| 148 |
) -> Dict[str, float]:
|
| 149 |
-
|
|
|
|
| 150 |
self.model.train(train)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
}
|
| 155 |
max_batches = max(len(loader) for loader in loaders.values())
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
)
|
| 163 |
|
| 164 |
-
def emit_progress(step: int, final: bool = False) -> None:
|
| 165 |
-
if not progress_enabled:
|
| 166 |
-
return
|
| 167 |
-
total_epochs_value = total_epochs
|
| 168 |
-
epoch_start_value = epoch_start
|
| 169 |
-
global_start_value = global_start
|
| 170 |
-
assert total_epochs_value is not None
|
| 171 |
-
assert epoch_start_value is not None
|
| 172 |
-
assert global_start_value is not None
|
| 173 |
-
self._update_epoch_progress(
|
| 174 |
-
epoch=epoch,
|
| 175 |
-
total_epochs=total_epochs_value,
|
| 176 |
-
step=step,
|
| 177 |
-
total_steps=max_batches,
|
| 178 |
-
epoch_start=epoch_start_value,
|
| 179 |
-
global_start=global_start_value,
|
| 180 |
-
final=final,
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
emit_progress(0)
|
| 184 |
-
|
| 185 |
context = torch.enable_grad() if train else torch.no_grad()
|
| 186 |
with context:
|
| 187 |
-
for step in
|
| 188 |
-
|
| 189 |
-
if (
|
| 190 |
-
train
|
| 191 |
-
and self.device.type == "cuda"
|
| 192 |
-
and hasattr(torch.compiler, "cudagraph_mark_step_begin")
|
| 193 |
-
):
|
| 194 |
-
torch.compiler.cudagraph_mark_step_begin()
|
| 195 |
-
|
| 196 |
-
backward_performed = False
|
| 197 |
-
step_total_loss = 0.0
|
| 198 |
-
|
| 199 |
-
# Mixed Precision Context
|
| 200 |
-
# Using bfloat16 for my RTX 4070 (Ampere/Ada) - better stability than float16
|
| 201 |
-
# Disable scaler for bfloat16 to prevent NaNs
|
| 202 |
-
use_bfloat16 = self.device.type == "cuda" and torch.cuda.is_bf16_supported()
|
| 203 |
|
| 204 |
for task, loader in loaders.items():
|
| 205 |
-
batch = self.
|
| 206 |
if batch is None:
|
| 207 |
continue
|
| 208 |
|
| 209 |
-
with
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
):
|
| 214 |
-
loss, task_metrics = self._forward_task(task, batch, train)
|
| 215 |
|
|
|
|
| 216 |
if torch.isnan(loss):
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
f"Warning: NaN loss detected for task '{task}'. Skipping update for this task. (Consecutive NaNs: {self._nan_counter})"
|
| 221 |
-
)
|
| 222 |
-
if self._nan_counter > 10:
|
| 223 |
-
raise RuntimeError(
|
| 224 |
-
"Too many consecutive NaN losses. Training is diverging."
|
| 225 |
-
)
|
| 226 |
continue
|
| 227 |
-
|
| 228 |
-
if train:
|
| 229 |
-
self._nan_counter = 0
|
| 230 |
-
|
| 231 |
-
weight = self._task_weight(task)
|
| 232 |
-
# Scale loss by gradient accumulation steps
|
| 233 |
-
weighted_loss = (loss * weight) / self.gradient_accumulation_steps
|
| 234 |
-
step_total_loss += weighted_loss.item() * self.gradient_accumulation_steps
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
| 239 |
|
|
|
|
| 240 |
if train:
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
else:
|
| 248 |
-
self.scaler.scale(
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
if
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
self.model.parameters(), self.config.gradient_clip_norm
|
| 264 |
-
)
|
| 265 |
-
self.optimizer.step()
|
| 266 |
-
self.optimizer.zero_grad()
|
| 267 |
-
else:
|
| 268 |
-
self.scaler.unscale_(self.optimizer)
|
| 269 |
-
torch.nn.utils.clip_grad_norm_(
|
| 270 |
-
self.model.parameters(), self.config.gradient_clip_norm
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
# Step optimizer using scaler
|
| 274 |
-
self.scaler.step(self.optimizer)
|
| 275 |
-
self.scaler.update()
|
| 276 |
-
self.optimizer.zero_grad()
|
| 277 |
-
|
| 278 |
-
if (
|
| 279 |
-
train
|
| 280 |
-
and self.config.logging_interval
|
| 281 |
-
and (step + 1) % self.config.logging_interval == 0
|
| 282 |
-
):
|
| 283 |
-
if torch.cuda.is_available() and self.device.type == "cuda":
|
| 284 |
-
torch.cuda.empty_cache()
|
| 285 |
-
emit_progress(step + 1)
|
| 286 |
-
emit_progress(max_batches, final=True)
|
| 287 |
-
|
| 288 |
-
averaged = {
|
| 289 |
-
name: sum(values) / len(values)
|
| 290 |
-
for name, values in metrics_accumulator.items()
|
| 291 |
-
if values
|
| 292 |
-
}
|
| 293 |
averaged["epoch"] = float(epoch)
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
| 296 |
return averaged
|
| 297 |
| 298 | -     def …                                   (removed; rest of line not captured in this view)
| 299 | -     …
| 300 | -     …
| 301 | -     …
| 302 | -     …
| 303 |       ) -> Dict[str, torch.Tensor] | None:
| 304 |           try:
| 305 | -             batch = next(…
| 306 |           except StopIteration:
| 307 | -             …
| 308 |               try:
| 309 | -                 batch = next(…
| 310 |               except StopIteration:
| 311 |                   return None
| 312 |           return {
| 313 | -             …
| 314 | -             for …
| 315 |           }
| 316 |
| 317 |       def _forward_task(
| 318 | -         self, task: str, batch: Dict[str, torch.Tensor]
| 319 |       ) -> tuple[torch.Tensor, Dict[str, float]]:
| 320 |           if task == "summarization":
| 321 | -         …                                       (lines 321-372 removed; content not captured in this view)
| 373 |
| 374 |       def _decode_labels(self, labels: torch.Tensor) -> List[str]:
| 375 |           valid = labels.clone()
| 376 |           valid[valid == -100] = self.tokenizer.pad_token_id
| 377 |           return self.tokenizer.decode_batch(valid.tolist())
| 378 |
| 379 |       def _validate_generation(self, val_loader: DataLoader, epoch: int) -> None:
| 380 | -         """Generate …
| 381 |           self.model.eval()
| 382 | -         …                                       (lines 382-385 removed; content not captured in this view)
| 386 |
| 387 |           with torch.no_grad():
| 388 | -             for batch in val_loader:
| 389 | -                 if …
| 390 |                       break
| 391 |
| 392 |                   batch = {
| 393 |                       k: v.to(self.device) if isinstance(v, torch.Tensor) else v
| 394 |                       for k, v in batch.items()
| 395 |                   }
| 396 | -                 src_ids = batch["src_ids"]
| 397 |                   src_mask = batch.get("src_mask")
| 398 | -                 labels = batch["labels"]
| 399 | -
| 400 | -                 # Only process first item from batch
| 401 | -                 src_ids = src_ids[:1]
| 402 |                   if src_mask is not None:
| 403 |                       src_mask = src_mask[:1]
| 404 | -                 labels = labels[:1]
| 405 |
| 406 | -                 # Encode
| 407 | -                 …                               (lines 407-412 removed; content not captured in this view)
| 413 | -                 if samples_generated == 0:
| 414 | -                     print("\n[DEBUG] Encoder output stats:")
| 415 | -                     print(f" Shape: {memory.shape}")
| 416 | -                     print(f" Mean: {memory.mean().item():.6f}")
| 417 | -                     print(f" Std: {memory.std().item():.6f}")
| 418 | -                     print(f" Min: {memory.min().item():.6f}")
| 419 | -                     print(f" Max: {memory.max().item():.6f}")
| 420 | -                     print(f" Has NaN: {torch.isnan(memory).any().item()}")
| 421 | -                     print(f" Has Inf: {torch.isinf(memory).any().item()}")
| 422 | -
| 423 | -                     # Check first few positions
| 424 | -                     print(f" First position norm: {memory[0, 0].norm().item():.4f}")
| 425 | -                     print(f" Last position norm: {memory[0, -1].norm().item():.4f}")
| 426 | -
| 427 | -                 # Ban special tokens from generation
| 428 | -                 ban_token_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
| 429 | -                 unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
| 430 | -                 if isinstance(unk_id, int):
| 431 | -                     ban_token_ids.append(unk_id)
| 432 | -                 ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
| 433 | -
| 434 | -                 # Generate using naive method (full forward, O(N^2)) for debugging
| 435 | -                 generated = self.model.decoder.greedy_decode_naive(
| 436 |                       memory=memory,
| 437 |                       max_len=self.config.validation_max_length,
| 438 |                       start_token_id=self.tokenizer.bos_token_id,
@@ -441,139 +376,17 @@ class Trainer:
| 441 |                       memory_mask=src_mask,
| 442 |                   )
| 443 |
| 444 | -                 # Decode
| 445 | -                 …                               (lines 445-448 removed; content not captured in this view)
| 449 | -                 print(f"\nSample {samples_generated + 1}:")
| 450 | -                 print(
| 451 | -                     f"Raw token IDs: {generated[0][:20].tolist()}..."
| 452 | -                 )  # Debug: show first 20 tokens
| 453 | -                 print(
| 454 | -                     f"Source: {source_text[:200]}..."
| 455 | -                     if len(source_text) > 200
| 456 | -                     else f"Source: {source_text}"
| 457 | -                 )
| 458 | -                 print(f"Generated: {generated_text}")
| 459 | -                 print(
| 460 | -                     f"Reference: {reference_text[:200]}..."
| 461 | -                     if len(reference_text) > 200
| 462 | -                     else f"Reference: {reference_text}"
| 463 | -                 )
| 464 | -                 print("-" * 80)
| 465 |
| 466 | -
| 467 |
| 468 | -
| 469 |       self.model.train()
| 470 | -
| 471 | -     def _print_epoch_progress(
| 472 | -         self,
| 473 | -         epoch: int,
| 474 | -         total_epochs: int,
| 475 | -         epoch_duration: float,
| 476 | -         total_elapsed: float,
| 477 | -     ) -> None:
| 478 | -         progress = epoch / total_epochs
| 479 | -         percent = progress * 100
| 480 | -         remaining_epochs = total_epochs - epoch
| 481 | -         eta = (total_elapsed / epoch) * remaining_epochs if epoch > 0 else 0.0
| 482 | -         bar = self._format_progress_bar(progress)
| 483 | -         message = (
| 484 | -             f"[progress] {bar} {percent:5.1f}% | epoch {epoch}/{total_epochs} "
| 485 | -             f"| last {epoch_duration:6.2f}s | total {total_elapsed:6.2f}s | ETA {eta:6.2f}s"
| 486 | -         )
| 487 | -         print(message, flush=True)
| 488 | -
| 489 | -     @staticmethod
| 490 | -     def _format_progress_bar(progress: float, width: int = 20) -> str:
| 491 | -         clamped = max(0.0, min(1.0, progress))
| 492 | -         filled = int(round(clamped * width))
| 493 | -         bar = "#" * filled + "-" * (width - filled)
| 494 | -         return f"[{bar}]"
| 495 | -
| 496 | -     def _update_epoch_progress(
| 497 | -         self,
| 498 | -         *,
| 499 | -         epoch: int,
| 500 | -         total_epochs: int,
| 501 | -         step: int,
| 502 | -         total_steps: int,
| 503 | -         epoch_start: float,
| 504 | -         global_start: float,
| 505 | -         final: bool = False,
| 506 | -     ) -> None:
| 507 | -         if total_steps <= 0 or total_epochs <= 0:
| 508 | -             return
| 509 | -         bounded_step = max(0, min(step, total_steps))
| 510 | -         step_fraction = bounded_step / total_steps
| 511 | -         epochs_completed = (epoch - 1) + step_fraction
| 512 | -         overall_progress = epochs_completed / total_epochs
| 513 | -         percent = overall_progress * 100.0
| 514 | -         epoch_elapsed = time.perf_counter() - epoch_start
| 515 | -         total_elapsed = time.perf_counter() - global_start
| 516 | -         if epochs_completed > 0:
| 517 | -             remaining_epochs = max(total_epochs - epochs_completed, 0.0)
| 518 | -             total_eta = (
| 519 | -                 (total_elapsed / epochs_completed) * remaining_epochs if total_elapsed > 0 else 0.0
| 520 | -             )
| 521 | -         else:
| 522 | -             total_eta = 0.0
| 523 | -
| 524 | -         if step > 0:
| 525 | -             epoch_eta = (epoch_elapsed / step) * (total_steps - step)
| 526 | -         else:
| 527 | -             epoch_eta = 0.0
| 528 | -
| 529 | -         bar = self._format_progress_bar(overall_progress, width=self._progress_bar_width())
| 530 | -         message = (
| 531 | -             f"[progress] {bar} {percent:5.1f}% "
| 532 | -             f"e {epoch}/{total_epochs} "
| 533 | -             f"s {bounded_step}/{total_steps} "
| 534 | -             f"ep_eta {self._format_duration(epoch_eta)} "
| 535 | -             f"tot_eta {self._format_duration(total_eta)}"
| 536 | -         )
| 537 | -         display = self._truncate_to_terminal(message)
| 538 | -         padding = " " * max(self._progress_last_len - len(display), 0)
| 539 | -         print(f"\r{display}{padding}", end="", flush=True)
| 540 | -         if final:
| 541 | -             print()
| 542 | -             self._progress_last_len = 0
| 543 | -         else:
| 544 | -             self._progress_last_len = len(display)
| 545 | -
| 546 | -     def _truncate_to_terminal(self, text: str) -> str:
| 547 | -         columns = self._terminal_width()
| 548 | -         if columns <= 0:
| 549 | -             return text
| 550 | -         if len(text) >= columns:
| 551 | -             return text[: max(columns - 1, 1)]
| 552 | -         return text
| 553 | -
| 554 | -     def _progress_bar_width(self) -> int:
| 555 | -         columns = self._terminal_width()
| 556 | -         reserved = 60
| 557 | -         if columns <= reserved:
| 558 | -             return 10
| 559 | -         return max(10, min(30, columns - reserved))
| 560 | -
| 561 | -     @staticmethod
| 562 | -     def _terminal_width() -> int:
| 563 | -         try:
| 564 | -             return shutil.get_terminal_size(fallback=(120, 20)).columns
| 565 | -         except OSError:
| 566 | -             return 120
| 567 | -
| 568 | -     @staticmethod
| 569 | -     def _format_duration(seconds: float) -> str:
| 570 | -         seconds = max(0.0, seconds)
| 571 | -         if seconds >= 3600:
| 572 | -             hours = int(seconds // 3600)
| 573 | -             minutes = int((seconds % 3600) // 60)
| 574 | -             return f"{hours}h{minutes:02}m"
| 575 | -         if seconds >= 60:
| 576 | -             minutes = int(seconds // 60)
| 577 | -             secs = int(seconds % 60)
| 578 | -             return f"{minutes}m{secs:02}s"
| 579 | -         return f"{seconds:4.1f}s"
| 1  | + """
| 2  | + Multi-task Trainer for LexiMind.
| 3  | +
| 4  | + Handles training across summarization, emotion, and topic heads with mixed-precision,
| 5  | + gradient accumulation, and MLflow logging.
| 6  | +
| 7  | + Author: Oliver Perrin
| 8  | + Date: December 2025
| 9  | + """
| 10 |
| 11 |   from __future__ import annotations
| 12 |
| 13 |   import time
| 14 |   from collections import defaultdict
| 15 |   from dataclasses import dataclass
| 16 | + from typing import Any, Callable, Dict, List
| 17 |
| 18 |   import mlflow
| 19 |   import torch
| 20 |   import torch.nn.functional as F
| 21 |   from torch.utils.data import DataLoader
| 22 | + from tqdm import tqdm
| 23 |
| 24 |   from ..data.tokenization import Tokenizer
| 25 |   from .metrics import accuracy, multilabel_f1, rouge_like
| 26 |
| 27 | + # --------------- Configuration ---------------
| 28 | +
| 29 |
| 30 |   @dataclass
| 31 |   class TrainerConfig:
| 32 | +     """Training hyperparameters."""
| 33 | +
| 34 |       max_epochs: int = 1
| 35 |       gradient_clip_norm: float = 1.0
| 36 |       task_weights: Dict[str, float] | None = None
| 37 |       validation_samples: int = 3
| 38 |       validation_max_length: int = 128
| 39 | +     label_smoothing: float = 0.0
| 40 |       experiment_name: str = "LexiMind"
| 41 |       run_name: str | None = None
| 42 |       gradient_accumulation_steps: int = 1
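
For orientation, this is how the dataclass above is typically constructed - a minimal sketch, not project code; the import path is assumed, and the task-weight keys mirror the three loaders the trainer consumes:

    from src.training.trainer import TrainerConfig  # assumed module path

    config = TrainerConfig(
        max_epochs=3,
        gradient_accumulation_steps=4,
        label_smoothing=0.1,
        task_weights={"summarization": 1.0, "emotion": 1.0, "topic": 1.0},
    )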
| 43 |
| 44 |
| 45 | + # --------------- Trainer ---------------
| 46 |   class Trainer:
| 47 | +     """Multi-task trainer with AMP and gradient accumulation."""
| 48 |
| 49 |       def __init__(
| 50 |           self,
  …  |       (lines 51-58 unchanged, elided in this view)
| 59 |           self.config = config
| 60 |           self.device = device
| 61 |           self.tokenizer = tokenizer
| 62 | +
| 63 | +         # Task losses
| 64 |           self.emotion_loss = torch.nn.BCEWithLogitsLoss()
| 65 |           self.topic_loss = torch.nn.CrossEntropyLoss()
| 66 | +
| 67 | +         # AMP setup: bfloat16 for Ampere+ GPUs, float16 otherwise
| 68 | +         self.use_amp = device.type == "cuda"
| 69 | +         self.use_bfloat16 = self.use_amp and torch.cuda.is_bf16_supported()
| 70 | +         self.scaler = torch.GradScaler("cuda", enabled=(self.use_amp and not self.use_bfloat16))
| 71 | +
| 72 | +         self._nan_counter = 0
| 73 |           mlflow.set_experiment(config.experiment_name)
| 74 |
| 75 | +         # CUDA optimizations
| 76 | +         if device.type == "cuda":
| 77 | +             torch.backends.cuda.enable_flash_sdp(True)
| 78 | +             torch.backends.cuda.enable_mem_efficient_sdp(True)
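
The bfloat16/float16 split above is the load-bearing detail: bfloat16 shares float32's exponent range, so loss scaling is unnecessary and the GradScaler is enabled only on the float16 path. A standalone sketch of the same selection logic (illustrative, not the class itself):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    use_bf16 = use_amp and torch.cuda.is_bf16_supported()
    # GradScaler only matters for float16, which overflows easily
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and not use_bf16)
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

    with torch.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
        pass  # forward pass goes here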
| 79  | +
| 80  | + # --------------- Training Loop ---------------
| 81  | +
| 82  |   def fit(
| 83  |       self,
| 84  |       train_loaders: Dict[str, DataLoader],
| 85  |       val_loaders: Dict[str, DataLoader] | None = None,
| 86  |       checkpoint_callback: Callable | None = None,
| 87  |   ) -> Dict[str, Dict[str, float]]:
| 88  | +     """Train model across all tasks with progress tracking."""
| 89  |       history: Dict[str, Dict[str, float]] = {}
| 90  | +     total_start = time.perf_counter()
| 91  |
| 92  |       with mlflow.start_run(run_name=self.config.run_name):
| 93  | +         self._log_config()
| 94  | +
| 95  | +         # Epoch progress bar
| 96  | +         epoch_pbar = tqdm(
| 97  | +             range(1, self.config.max_epochs + 1),
| 98  | +             desc="Training",
| 99  | +             unit="epoch",
| 100 | +             position=0,
| 101 |           )
| 102 |
| 103 | +         for epoch in epoch_pbar:
| 104 |               epoch_start = time.perf_counter()
| 105 |
| 106 | +             # Train
| 107 | +             train_metrics = self._run_epoch(train_loaders, train=True, epoch=epoch)
| 108 | +             history[f"train_epoch_{epoch}"] = train_metrics
| 109 | +             self._log_metrics(train_metrics, "train", epoch)
| 110 |
| 111 | +             # Validate
| 112 |               if val_loaders:
| 113 |                   val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
| 114 |                   history[f"val_epoch_{epoch}"] = val_metrics
| 115 | +                 self._log_metrics(val_metrics, "val", epoch)
| 116 |
| 117 |                   if "summarization" in val_loaders:
| 118 |                       self._validate_generation(val_loaders["summarization"], epoch)
| 119 |
| 120 | +             # Checkpoint
| 121 | +             if checkpoint_callback:
| 122 |                   checkpoint_callback(epoch, self.model, history)
| 123 |
| 124 | +             # Update epoch progress bar with metrics
| 125 | +             epoch_time = time.perf_counter() - epoch_start
| 126 | +             total_time = time.perf_counter() - total_start
| 127 | +             desc = f"Epoch {epoch}/{self.config.max_epochs}"
| 128 | +             if "total_loss" in train_metrics:
| 129 | +                 desc += f" | loss={train_metrics['total_loss']:.3f}"
| 130 | +             epoch_pbar.set_description(desc)
| 131 | +             epoch_pbar.set_postfix(
| 132 | +                 {"time": f"{epoch_time:.1f}s", "total": f"{total_time:.1f}s"}
| 133 | +             )
| 134 |
| 135 | +     total_time = time.perf_counter() - total_start
| 136 | +     print(f"\n✓ Training complete in {total_time:.1f}s")
| 137 |       return history
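
A sketch of how fit is typically driven; build_loader and the dataset variables are hypothetical placeholders, not project APIs. The returned history is keyed by the f"train_epoch_{epoch}" / f"val_epoch_{epoch}" pattern used above:

    train_loaders = {
        "summarization": build_loader(summ_train),  # hypothetical helper
        "emotion": build_loader(emotion_train),
        "topic": build_loader(topic_train),
    }
    val_loaders = {"summarization": build_loader(summ_val)}

    history = trainer.fit(train_loaders, val_loaders)
    print(history["train_epoch_1"].get("total_loss"))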
| 138 |
| 139 | + def _log_config(self) -> None:
| 140 | +     """Log config to MLflow."""
| 141 | +     mlflow.log_params(
| 142 | +         {
| 143 | +             "max_epochs": self.config.max_epochs,
| 144 | +             "gradient_clip_norm": self.config.gradient_clip_norm,
| 145 | +             "label_smoothing": self.config.label_smoothing,
| 146 | +             "task_weights": str(self.config.task_weights),
| 147 | +         }
| 148 | +     )
| 149 | +
| 150 | + def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
| 151 | +     """Log metrics to MLflow."""
| 152 | +     for k, v in metrics.items():
| 153 | +         if k != "epoch":
| 154 | +             mlflow.log_metric(f"{prefix}_{k}", v, step=epoch)
| 155 | +
| 156 | + # --------------- Epoch Execution ---------------
| 157 | +
| 158 |   def _run_epoch(
| 159 |       self,
| 160 |       loaders: Dict[str, DataLoader],
| 161 |       *,
| 162 |       train: bool,
| 163 |       epoch: int,
| 164 |   ) -> Dict[str, float]:
| 165 | +     """Run one epoch with progress bar."""
| 166 | +     phase = "Train" if train else "Val"
| 167 |       self.model.train(train)
| 168 | +
| 169 | +     metrics: Dict[str, List[float]] = defaultdict(list)
| 170 | +     iterators = {task: iter(loader) for task, loader in loaders.items()}
| 171 |       max_batches = max(len(loader) for loader in loaders.values())
| 172 | +     accum_steps = self.config.gradient_accumulation_steps
| 173 | +
| 174 | +     # Batch progress bar (nested under epoch bar)
| 175 | +     pbar = tqdm(
| 176 | +         range(max_batches),
| 177 | +         desc=f"  {phase}",
| 178 | +         unit="batch",
| 179 | +         leave=False,
| 180 | +         position=1,
| 181 |       )
| 182 |
| 183 |       context = torch.enable_grad() if train else torch.no_grad()
| 184 |       with context:
| 185 | +         for step in pbar:
| 186 | +             step_loss = 0.0
| 187 |
| 188 |               for task, loader in loaders.items():
| 189 | +                 batch = self._get_batch(iterators, loader, task)
| 190 |                   if batch is None:
| 191 |                       continue
| 192 |
| 193 | +                 # Forward with AMP
| 194 | +                 amp_dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
| 195 | +                 with torch.autocast("cuda", dtype=amp_dtype, enabled=self.use_amp):
| 196 | +                     loss, task_metrics = self._forward_task(task, batch)
| 197 |
| 198 | +                 # NaN check
| 199 |                   if torch.isnan(loss):
| 200 | +                     self._nan_counter += 1
| 201 | +                     if self._nan_counter > 10:
| 202 | +                         raise RuntimeError("Training diverging - too many NaN losses")
| 203 |                       continue
| 204 | +                 self._nan_counter = 0
| 205 |
| 206 | +                 # Record metrics
| 207 | +                 metrics[f"{task}_loss"].append(loss.item())
| 208 | +                 for name, val in task_metrics.items():
| 209 | +                     metrics[f"{task}_{name}"].append(val)
| 210 |
| 211 | +                 # Backward
| 212 |                   if train:
| 213 | +                     weight = (self.config.task_weights or {}).get(task, 1.0)
| 214 | +                     scaled = (loss * weight) / accum_steps
| 215 | +                     step_loss += scaled.item() * accum_steps
| 216 | +
| 217 | +                     if self.use_bfloat16:
| 218 | +                         scaled.backward()
| 219 |                       else:
| 220 | +                         self.scaler.scale(scaled).backward()
| 221 | +
| 222 | +             # Optimizer step
| 223 | +             if train and (step + 1) % accum_steps == 0:
| 224 | +                 self._optimizer_step()
| 225 | +
| 226 | +             if step_loss > 0:
| 227 | +                 metrics["total_loss"].append(step_loss)
| 228 | +
| 229 | +             # Update progress bar
| 230 | +             if metrics["total_loss"]:
| 231 | +                 pbar.set_postfix({"loss": f"{metrics['total_loss'][-1]:.3f}"})
| 232 | +
| 233 | +     # Average and print summary
| 234 | +     averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
| 235 |       averaged["epoch"] = float(epoch)
| 236 | +
| 237 | +     summary = f"[{phase.lower()}] epoch {epoch}: "
| 238 | +     summary += ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch")
| 239 | +     tqdm.write(summary)
| 240 | +
| 241 |       return averaged
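
The division by accum_steps above keeps the effective update equal to one step on the mean loss over the accumulation window. A standalone check of that arithmetic, with a toy model and random data, illustration only:

    import torch

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_steps = 4
    x = torch.randn(accum_steps, 8, 4)
    y = torch.randn(accum_steps, 8, 1)

    opt.zero_grad()
    for i in range(accum_steps):
        loss = torch.nn.functional.mse_loss(model(x[i]), y[i])
        (loss / accum_steps).backward()  # same division as `scaled` in the hunk
    opt.step()  # one optimizer step per accumulation window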
| 242 |
| 243 | + def _optimizer_step(self) -> None:
| 244 | +     """Optimizer step with gradient clipping."""
| 245 | +     if self.use_bfloat16:
| 246 | +         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
| 247 | +         self.optimizer.step()
| 248 | +     else:
| 249 | +         self.scaler.unscale_(self.optimizer)
| 250 | +         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
| 251 | +         self.scaler.step(self.optimizer)
| 252 | +         self.scaler.update()
| 253 | +     self.optimizer.zero_grad()
| 254 | +
| 255 | + def _get_batch(
| 256 | +     self, iterators: Dict, loader: DataLoader, task: str
| 257 |   ) -> Dict[str, torch.Tensor] | None:
| 258 | +     """Get next batch, cycling iterator if exhausted."""
| 259 |       try:
| 260 | +         batch = next(iterators[task])
| 261 |       except StopIteration:
| 262 | +         iterators[task] = iter(loader)
| 263 |           try:
| 264 | +             batch = next(iterators[task])
| 265 |           except StopIteration:
| 266 |               return None
| 267 |       return {
| 268 | +         k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
| 269 | +         for k, v in batch.items()
| 270 |       }
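
A minimal sketch of the cycling pattern in _get_batch: shorter task loaders restart so every step can draw one batch per task. Toy lists stand in for DataLoaders here:

    loaders = {"emotion": [1, 2], "topic": [10, 20, 30]}
    iterators = {task: iter(data) for task, data in loaders.items()}

    def get_batch(task):
        try:
            return next(iterators[task])
        except StopIteration:
            iterators[task] = iter(loaders[task])  # restart exhausted iterator
            return next(iterators[task])

    print([get_batch("emotion") for _ in range(3)])  # [1, 2, 1] - cycles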
| 271 |
| 272 | + # --------------- Task Forward Passes ---------------
| 273 | +
| 274 |   def _forward_task(
| 275 | +     self, task: str, batch: Dict[str, torch.Tensor]
| 276 |   ) -> tuple[torch.Tensor, Dict[str, float]]:
| 277 | +     """Route to task-specific forward pass."""
| 278 |       if task == "summarization":
| 279 | +         return self._forward_summarization(batch)
| 280 | +     elif task == "emotion":
| 281 | +         return self._forward_emotion(batch)
| 282 | +     elif task == "topic":
| 283 | +         return self._forward_topic(batch)
| 284 | +     raise ValueError(f"Unknown task: {task}")
| 285 | +
| 286 | + def _forward_summarization(
| 287 | +     self, batch: Dict[str, torch.Tensor]
| 288 | + ) -> tuple[torch.Tensor, Dict[str, float]]:
| 289 | +     """Seq2seq forward for summarization."""
| 290 | +     inputs = {"src_ids": batch["src_ids"], "tgt_ids": batch["tgt_ids"]}
| 291 | +     if "src_mask" in batch:
| 292 | +         inputs["src_mask"] = batch["src_mask"]
| 293 | +
| 294 | +     logits = self.model.forward("summarization", inputs)
| 295 | +     loss = F.cross_entropy(
| 296 | +         logits.view(-1, logits.size(-1)),
| 297 | +         batch["labels"].view(-1),
| 298 | +         ignore_index=-100,
| 299 | +         label_smoothing=self.config.label_smoothing,
| 300 | +     )
| 301 | +
| 302 | +     # Quick ROUGE estimate
| 303 | +     preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
| 304 | +     refs = self._decode_labels(batch["labels"])
| 305 | +     return loss, {"rouge_like": rouge_like(preds, refs)}
| 306 | +
| 307 | + def _forward_emotion(
| 308 | +     self, batch: Dict[str, torch.Tensor]
| 309 | + ) -> tuple[torch.Tensor, Dict[str, float]]:
| 310 | +     """Multi-label emotion classification."""
| 311 | +     inputs = {"input_ids": batch["input_ids"]}
| 312 | +     if "attention_mask" in batch:
| 313 | +         inputs["attention_mask"] = batch["attention_mask"]
| 314 | +
| 315 | +     logits = self.model.forward("emotion", inputs)
| 316 | +     loss = self.emotion_loss(logits, batch["labels"].float())
| 317 | +     preds = (torch.sigmoid(logits) > 0.5).int()
| 318 | +     return loss, {"f1": multilabel_f1(preds, batch["labels"].int())}
| 319 | +
| 320 | + def _forward_topic(
| 321 | +     self, batch: Dict[str, torch.Tensor]
| 322 | + ) -> tuple[torch.Tensor, Dict[str, float]]:
| 323 | +     """Single-label topic classification."""
| 324 | +     inputs = {"input_ids": batch["input_ids"]}
| 325 | +     if "attention_mask" in batch:
| 326 | +         inputs["attention_mask"] = batch["attention_mask"]
| 327 | +
| 328 | +     logits = self.model.forward("topic", inputs)
| 329 | +     loss = self.topic_loss(logits, batch["labels"])
| 330 | +     preds = logits.argmax(dim=-1)
| 331 | +     return loss, {"accuracy": accuracy(preds.tolist(), batch["labels"].tolist())}
| 332 |
| 333 |   def _decode_labels(self, labels: torch.Tensor) -> List[str]:
| 334 | +     """Decode labels, replacing -100 with pad token."""
| 335 |       valid = labels.clone()
| 336 |       valid[valid == -100] = self.tokenizer.pad_token_id
| 337 |       return self.tokenizer.decode_batch(valid.tolist())
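
What _decode_labels does before detokenizing, shown in isolation: masked (-100) target positions are mapped back to the pad id so the decoder never sees an invalid token id:

    import torch

    pad_token_id = 0
    labels = torch.tensor([[42, 7, -100, -100]])
    valid = labels.clone()
    valid[valid == -100] = pad_token_id
    print(valid.tolist())  # [[42, 7, 0, 0]] - now safe to decode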
| 338 |
| 339 | + # --------------- Validation Generation ---------------
| 340 | +
| 341 |   def _validate_generation(self, val_loader: DataLoader, epoch: int) -> None:
| 342 | +     """Generate sample summaries for quality check."""
| 343 |       self.model.eval()
| 344 | +     n = self.config.validation_samples
| 345 | +
| 346 | +     tqdm.write(f"\n{'=' * 50}")
| 347 | +     tqdm.write(f"[Validation Samples - Epoch {epoch}]")
| 348 | +     tqdm.write(f"{'=' * 50}")
| 349 |
| 350 |       with torch.no_grad():
| 351 | +         for i, batch in enumerate(val_loader):
| 352 | +             if i >= n:
| 353 |                   break
| 354 |
| 355 |               batch = {
| 356 |                   k: v.to(self.device) if isinstance(v, torch.Tensor) else v
| 357 |                   for k, v in batch.items()
| 358 |               }
| 359 | +             src_ids = batch["src_ids"][:1]
| 360 |               src_mask = batch.get("src_mask")
| 361 |               if src_mask is not None:
| 362 |                   src_mask = src_mask[:1]
| 363 |
| 364 | +             # Encode and generate
| 365 | +             enc_mask = (
| 366 | +                 src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
| 367 | +             )
| 368 | +             model: Any = self.model
| 369 | +             memory = model.encoder(src_ids, mask=enc_mask)
| 370 | +             generated = model.decoder.greedy_decode_naive(
| 371 |                   memory=memory,
| 372 |                   max_len=self.config.validation_max_length,
| 373 |                   start_token_id=self.tokenizer.bos_token_id,
  …   |               (lines 374-375 unchanged, elided in this view)
| 376 |                   memory_mask=src_mask,
| 377 |               )
| 378 |
| 379 | +             # Decode and display
| 380 | +             src = self.tokenizer.decode(src_ids[0].tolist())
| 381 | +             out = self.tokenizer.decode(generated[0].tolist())
| 382 | +             ref = self._decode_labels(batch["labels"][:1])[0]
| 383 |
| 384 | +             tqdm.write(f"\nSample {i + 1}:")
| 385 | +             tqdm.write(f" Source: {src[:120]}..." if len(src) > 120 else f" Source: {src}")
| 386 | +             tqdm.write(f" Generated: {out}")
| 387 | +             tqdm.write(
| 388 | +                 f" Reference: {ref[:120]}..." if len(ref) > 120 else f" Reference: {ref}"
| 389 | +             )
| 390 |
| 391 | +     tqdm.write(f"{'=' * 50}\n")
| 392 |       self.model.train()
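
The enc_mask construction above, in isolation: a (B, S) padding mask is broadcast into a (B, S, S) pairwise attention mask, where position (i, j) is valid only if both tokens are real:

    import torch

    src_mask = torch.tensor([[True, True, False]])            # (B=1, S=3)
    enc_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)  # (1, 3, 3)
    print(enc_mask.int())
    # [[[1, 1, 0],
    #   [1, 1, 0],
    #   [0, 0, 0]]]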
src/training/utils.py
CHANGED

@@ -1,4 +1,12 @@
| 1  | - """…                                          (old one-line docstring; content not captured in this view)
| 2  |
| 3  |   from __future__ import annotations
| 4  |

| 1  | + """
| 2  | + Training utilities for LexiMind.
| 3  | +
| 4  | + Provides reproducibility helpers including seed management for stdlib, PyTorch,
| 5  | + and NumPy random number generators with thread-safe spawning support.
| 6  | +
| 7  | + Author: Oliver Perrin
| 8  | + Date: December 2025
| 9  | + """
| 10 |
| 11 |   from __future__ import annotations
| 12 |
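A typical shape for the seeding helper this docstring describes - an assumption about the implementation, not the module's actual code:

    import random

    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        """Seed stdlib, NumPy, and PyTorch RNGs for reproducible runs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # seeds CPU and all CUDA generators
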
src/utils/config.py
CHANGED

@@ -1,4 +1,11 @@
| 1  | - """…                                          (old docstring; content not captured in this view)
| 2  |
| 3  |   from dataclasses import dataclass
| 4  |   from pathlib import Path

| 1  | + """
| 2  | + Configuration utilities for LexiMind.
| 3  | +
| 4  | + Provides YAML configuration loading with validation.
| 5  | +
| 6  | + Author: Oliver Perrin
| 7  | + Date: December 2025
| 8  | + """
| 9  |
| 10 |   from dataclasses import dataclass
| 11 |   from pathlib import Path
src/utils/io.py
CHANGED

@@ -1,4 +1,11 @@
| 1  | - """…                                          (old docstring; content not captured in this view)
| 2  |
| 3  |   from pathlib import Path
| 4  |

| 1  | + """
| 2  | + Checkpoint I/O utilities for LexiMind.
| 3  | +
| 4  | + Handles model state serialization with support for torch.compile artifacts.
| 5  | +
| 6  | + Author: Oliver Perrin
| 7  | + Date: December 2025
| 8  | + """
| 9  |
| 10 |   from pathlib import Path
| 11 |
src/utils/labels.py
CHANGED

@@ -1,4 +1,12 @@
| 1  | - """…                                          (old docstring; content not captured in this view)
| 2  |
| 3  |   from __future__ import annotations
| 4  |

| 1  | + """
| 2  | + Label metadata utilities for LexiMind.
| 3  | +
| 4  | + Manages persistence and loading of emotion and topic label vocabularies
| 5  | + for multitask inference.
| 6  | +
| 7  | + Author: Oliver Perrin
| 8  | + Date: December 2025
| 9  | + """
| 10 |
| 11 |   from __future__ import annotations
| 12 |
src/utils/logging.py
CHANGED

@@ -1,4 +1,11 @@
| 1  | - """…                                          (old docstring; content not captured in this view)
| 2  |
| 3  |   import logging
| 4  |

| 1  | + """
| 2  | + Logging utilities for LexiMind.
| 3  | +
| 4  | + Provides centralized logging configuration and logger factory.
| 5  | +
| 6  | + Author: Oliver Perrin
| 7  | + Date: December 2025
| 8  | + """
| 9  |
| 10 |   import logging
| 11 |
src/utils/random.py
CHANGED

@@ -1,4 +1,11 @@
| 1  | - """…                                          (old docstring; content not captured in this view)
| 2  |
| 3  |   import random
| 4  |

| 1  | + """
| 2  | + Randomness utilities for LexiMind.
| 3  | +
| 4  | + Provides seed management for reproducibility.
| 5  | +
| 6  | + Author: Oliver Perrin
| 7  | + Date: December 2025
| 8  | + """
| 9  |
| 10 |   import random
| 11 |
tests/test_training/test_trainer.py
CHANGED

@@ -17,7 +17,7 @@ class TestTrainer(unittest.TestCase):
| 17 |       self.model = MagicMock()
| 18 |       self.model.to.return_value = self.model  # Ensure .to() returns the same mock
| 19 |       self.optimizer = MagicMock(spec=torch.optim.Optimizer)
| 20 | -     self.config = TrainerConfig(max_epochs=1…  (old line truncated in this view)
| 21 |       self.device = torch.device("cpu")
| 22 |       self.tokenizer = MagicMock()
| 23 |       self.tokenizer.pad_token_id = 0

| 17 |       self.model = MagicMock()
| 18 |       self.model.to.return_value = self.model  # Ensure .to() returns the same mock
| 19 |       self.optimizer = MagicMock(spec=torch.optim.Optimizer)
| 20 | +     self.config = TrainerConfig(max_epochs=1)
| 21 |       self.device = torch.device("cpu")
| 22 |       self.tokenizer = MagicMock()
| 23 |       self.tokenizer.pad_token_id = 0
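
A quick, illustrative sanity check that the fixed constructor call parses and picks up the dataclass defaults from the trainer diff above (the import path is assumed):

    from src.training.trainer import TrainerConfig  # assumed module path

    config = TrainerConfig(max_epochs=1)
    assert config.gradient_accumulation_steps == 1
    assert config.label_smoothing == 0.0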