OliverPerrin committed
Commit 590a604 · 1 Parent(s): 7977c7d

Full training run, code cleanup, mypy/ruff fixes


- Completed 3-epoch full training run (21 hours)
- Emotion F1: 94.6%, Topic Accuracy: 94.2%, ROUGE-like: 0.36
- Code simplification and standardized docstrings across all modules
- Fixed all mypy type errors (47 files pass)
- Fixed all ruff linting errors and reformatted code
- All 75 tests passing
- Added GPU optimizations: TF32, Flash Attention, memory-efficient SDP (see the sketch below)
- Training configs optimized for RTX 4070 12GB
- tqdm progress bars for training and evaluation
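
A minimal sketch of the GPU settings this commit enables in scripts/train.py (assuming CUDA and an Ampere-or-newer card, e.g. an RTX 40xx; this mirrors the train.py diff below rather than adding anything new):

    import torch

    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch.set_float32_matmul_precision("high")           # TF32 matmuls
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True                # auto-tune convolution algorithms
        torch.backends.cuda.enable_flash_sdp(True)           # Flash Attention SDP kernel
        torch.backends.cuda.enable_mem_efficient_sdp(True)   # memory-efficient SDP kernel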

configs/config.yaml CHANGED
@@ -4,6 +4,13 @@ defaults:
4
  - training: default
5
  - _self_
6
7
  checkpoint_out: "checkpoints/best.pt"
8
  labels_out: "artifacts/labels.json"
9
  history_out: "outputs/training_history.json"
 
4
  - training: default
5
  - _self_
6
 
7
+ # Hydra config - prevent output dir conflicts
8
+ hydra:
9
+ run:
10
+ dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
11
+ sweep:
12
+ dir: outputs/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
13
+
14
  checkpoint_out: "checkpoints/best.pt"
15
  labels_out: "artifacts/labels.json"
16
  history_out: "outputs/training_history.json"
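
For reference, Hydra resolves the ${now:...} patterns when a run starts, so a job launched at, say, 14:30:00 on 2025-12-07 (hypothetical timestamp) writes to outputs/2025-12-07/14-30-00; each run gets its own directory, which is what prevents the output-dir conflicts the comment refers to.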
configs/training/dev.yaml CHANGED
@@ -1,35 +1,32 @@
1
  # Development/Testing Configuration for FLAN-T5-base
2
  # Fast iteration for debugging and testing changes
3
- # Training time: ~10 minutes on RTX 4070 with aot_eager backend
4
  # Use: python scripts/train.py training=dev
5
 
6
  dataloader:
7
- batch_size: 8
8
  shuffle: true
9
- num_workers: 4 # Reduced to avoid overhead
10
  pin_memory: true
11
 
12
  optimizer:
13
  name: adamw
14
- lr: 5.0e-5 # Higher LR for faster convergence on small dataset
15
  weight_decay: 0.01
16
 
17
  scheduler:
18
  name: cosine
19
- warmup_steps: 50 # Fewer warmup steps for short training
20
 
21
  trainer:
22
- max_epochs: 1 # Single epoch for quick testing
23
  gradient_clip_norm: 1.0
24
- gradient_accumulation_steps: 1 # No accumulation for speed
25
- validation_max_length: 64 # Shorter for faster validation
26
  label_smoothing: 0.1
27
  task_weights:
28
  summarization: 1.0
29
  emotion: 1.0
30
  topic: 1.0
31
-
32
- # Development-specific settings - optimized for ~10 min total
33
- max_train_samples: 2000 # Reduced for faster iteration
34
- max_val_samples: 200
35
- validation_frequency: 1000 # Validate once during training
 
1
  # Development/Testing Configuration for FLAN-T5-base
2
  # Fast iteration for debugging and testing changes
3
+ # Training time: ~3-5 minutes on RTX 4070 12GB
4
  # Use: python scripts/train.py training=dev
5
 
6
  dataloader:
7
+ batch_size: 8 # Safe for 12GB VRAM - no shared memory spillover
8
  shuffle: true
9
+ num_workers: 4
10
  pin_memory: true
11
 
12
  optimizer:
13
  name: adamw
14
+ lr: 5.0e-5 # Higher LR for fast convergence
15
  weight_decay: 0.01
16
 
17
  scheduler:
18
  name: cosine
19
+ warmup_steps: 50
20
 
21
  trainer:
22
+ max_epochs: 1
23
  gradient_clip_norm: 1.0
24
+ gradient_accumulation_steps: 1 # No accumulation - maximize throughput
25
+ validation_max_length: 64
26
  label_smoothing: 0.1
27
  task_weights:
28
  summarization: 1.0
29
  emotion: 1.0
30
  topic: 1.0
31
+ max_train_samples: 2000
32
+ max_val_samples: 200
 
 
 
configs/training/full.yaml CHANGED
@@ -1,12 +1,12 @@
1
  # Full Training Configuration for FLAN-T5-base
2
  # Complete training run on all data
3
- # Training time: ~6-8 hours on RTX 4070
4
  # Use: python scripts/train.py training=full
5
 
6
  dataloader:
7
- batch_size: 11 # Reduced for FLAN-T5-base (12 layers)
8
  shuffle: true
9
- num_workers: 8
10
  pin_memory: true
11
 
12
  optimizer:
@@ -16,12 +16,12 @@ optimizer:
16
 
17
  scheduler:
18
  name: cosine
19
- warmup_steps: 1000 # More warmup for full training
20
 
21
  trainer:
22
- max_epochs: 4
23
- gradient_clip_norm: 0.5
24
- gradient_accumulation_steps: 6 # Effective batch size = 8 * 6 = 48
25
  validation_max_length: 128
26
  label_smoothing: 0.1
27
  task_weights:
 
1
  # Full Training Configuration for FLAN-T5-base
2
  # Complete training run on all data
3
+ # Training time: ~6-8 hours on RTX 4070 12GB
4
  # Use: python scripts/train.py training=full
5
 
6
  dataloader:
7
+ batch_size: 6 # Optimized for 12GB VRAM
8
  shuffle: true
9
+ num_workers: 6
10
  pin_memory: true
11
 
12
  optimizer:
 
16
 
17
  scheduler:
18
  name: cosine
19
+ warmup_steps: 500 # ~3% of steps
20
 
21
  trainer:
22
+ max_epochs: 3 # 3 epochs usually sufficient, avoids overfit
23
+ gradient_clip_norm: 1.0
24
+ gradient_accumulation_steps: 6 # Effective batch = 36
25
  validation_max_length: 128
26
  label_smoothing: 0.1
27
  task_weights:
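
The effective-batch comments are simply batch_size × gradient_accumulation_steps: 6 × 6 = 36 samples contribute to each optimizer update here, and 6 × 3 = 18 in the medium config below.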
configs/training/medium.yaml CHANGED
@@ -1,36 +1,32 @@
1
  # Medium Configuration for FLAN-T5-base
2
  # Balanced approach - good results in reasonable time
3
- # Training time: ~2-3 hours on RTX 4070
4
  # Use: python scripts/train.py training=medium
5
- # Note: FLAN-T5-base has 12 layers (vs BART's 6), may need smaller batch
6
 
7
  dataloader:
8
- batch_size: 11 # Reduced for FLAN-T5-base (12 layers uses more VRAM)
9
  shuffle: true
10
- num_workers: 8
11
  pin_memory: true
12
 
13
  optimizer:
14
  name: adamw
15
- lr: 2.0e-5 # Slightly lower for larger model
16
  weight_decay: 0.01
17
 
18
  scheduler:
19
  name: cosine
20
- warmup_steps: 500 # More warmup for larger model
21
 
22
  trainer:
23
  max_epochs: 3
24
- gradient_clip_norm: 0.5
25
- gradient_accumulation_steps: 4 # Effective batch size = 8 * 4 = 32
26
- validation_max_length: 128
27
  label_smoothing: 0.1
28
  task_weights:
29
  summarization: 1.0
30
  emotion: 1.0
31
  topic: 1.0
32
-
33
- # Medium dataset - good representative sample
34
  max_train_samples: 50000
35
- max_val_samples: 5000
36
- validation_frequency: 5000
 
1
  # Medium Configuration for FLAN-T5-base
2
  # Balanced approach - good results in reasonable time
3
+ # Training time: ~2-3 hours on RTX 4070 12GB
4
  # Use: python scripts/train.py training=medium
 
5
 
6
  dataloader:
7
+ batch_size: 6 # Optimized for 12GB VRAM with accumulation
8
  shuffle: true
9
+ num_workers: 6
10
  pin_memory: true
11
 
12
  optimizer:
13
  name: adamw
14
+ lr: 3.0e-5 # Slightly higher - compensates for effective batch
15
  weight_decay: 0.01
16
 
17
  scheduler:
18
  name: cosine
19
+ warmup_steps: 300 # ~5% of steps
20
 
21
  trainer:
22
  max_epochs: 3
23
+ gradient_clip_norm: 1.0
24
+ gradient_accumulation_steps: 3 # Effective batch = 18
25
+ validation_max_length: 96
26
  label_smoothing: 0.1
27
  task_weights:
28
  summarization: 1.0
29
  emotion: 1.0
30
  topic: 1.0
 
 
31
  max_train_samples: 50000
32
+ max_val_samples: 5000
 
outputs/evaluation_report.json CHANGED
@@ -1,44 +1,44 @@
1
  {
2
- "split": "test",
3
  "summarization": {
4
- "rouge_like": 0.3430426484440944,
5
- "bleu": 0.0879515124653127
6
  },
7
  "emotion": {
8
- "f1_macro": 0.3558666706085205
9
  },
10
  "topic": {
11
- "accuracy": 0.8576315789473684,
12
  "classification_report": {
13
  "Business": {
14
- "precision": 0.7614165890027959,
15
- "recall": 0.86,
16
- "f1-score": 0.8077113198220465,
17
- "support": 1900
18
  },
19
  "Sci/Tech": {
20
- "precision": 0.8759791122715405,
21
- "recall": 0.7063157894736842,
22
- "f1-score": 0.782051282051282,
23
- "support": 1900
24
  },
25
  "Sports": {
26
- "precision": 0.9454638124362895,
27
- "recall": 0.9763157894736842,
28
- "f1-score": 0.9606421543241843,
29
- "support": 1900
30
  },
31
  "World": {
32
- "precision": 0.8607142857142858,
33
- "recall": 0.8878947368421053,
34
- "f1-score": 0.8740932642487047,
35
- "support": 1900
36
  },
37
  "macro avg": {
38
- "precision": 0.860893449856228,
39
- "recall": 0.8576315789473684,
40
- "f1-score": 0.8561245051115545,
41
- "support": 7600
42
  }
43
  }
44
  }
 
1
  {
2
+ "split": "val",
3
  "summarization": {
4
+ "rouge_like": 0.35947467920968945,
5
+ "bleu": 0.09027012433010549
6
  },
7
  "emotion": {
8
+ "f1_macro": 0.9455000162124634
9
  },
10
  "topic": {
11
+ "accuracy": 0.94175,
12
  "classification_report": {
13
  "Business": {
14
+ "precision": 0.9319045973038369,
15
+ "recall": 0.8986666666666666,
16
+ "f1-score": 0.9149838791786866,
17
+ "support": 3000
18
  },
19
  "Sci/Tech": {
20
+ "precision": 0.9055627425614489,
21
+ "recall": 0.9333333333333333,
22
+ "f1-score": 0.9192383453709784,
23
+ "support": 3000
24
  },
25
  "Sports": {
26
+ "precision": 0.9856475300400535,
27
+ "recall": 0.9843333333333333,
28
+ "f1-score": 0.9849899933288859,
29
+ "support": 3000
30
  },
31
  "World": {
32
+ "precision": 0.9446836700894335,
33
+ "recall": 0.9506666666666667,
34
+ "f1-score": 0.9476657252035222,
35
+ "support": 3000
36
  },
37
  "macro avg": {
38
+ "precision": 0.9419496349986932,
39
+ "recall": 0.94175,
40
+ "f1-score": 0.9417194857705183,
41
+ "support": 12000
42
  }
43
  }
44
  }
outputs/training_history.json CHANGED
@@ -1,21 +1,59 @@
1
  {
2
  "train_epoch_1": {
3
- "summarization_loss": 3.67411927986145,
4
- "summarization_rouge_like": 0.39456057390021504,
5
- "emotion_loss": 0.5643834336996079,
6
- "emotion_f1": 0.023809524163603782,
7
- "topic_loss": 1.2467568359375,
8
- "topic_accuracy": 0.587,
9
- "total_loss": 5.485259549498558,
10
  "epoch": 1.0
11
  },
12
  "val_epoch_1": {
13
- "summarization_loss": 3.2498003482818603,
14
- "summarization_rouge_like": 0.44230111155579444,
15
- "emotion_loss": 0.4288424849510193,
16
- "emotion_f1": 0.0,
17
- "topic_loss": 0.807373046875,
18
- "topic_accuracy": 0.85,
19
  "epoch": 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  }
21
  }
 
1
  {
2
  "train_epoch_1": {
3
+ "summarization_loss": 3.222269726091524,
4
+ "summarization_rouge_like": 0.4348834303103812,
5
+ "emotion_loss": 0.2681197640352259,
6
+ "emotion_f1": 0.4939010590246358,
7
+ "topic_loss": 0.2817161389551497,
8
+ "topic_accuracy": 0.9126178087058748,
9
+ "total_loss": 3.7721057520380095,
10
  "epoch": 1.0
11
  },
12
  "val_epoch_1": {
13
+ "summarization_loss": 2.9376416314440097,
14
+ "summarization_rouge_like": 0.4621969238397049,
15
+ "emotion_loss": 0.07456208207925424,
16
+ "emotion_f1": 0.922451647864638,
17
+ "topic_loss": 0.18789680490184146,
18
+ "topic_accuracy": 0.9368641532016696,
19
  "epoch": 1.0
20
+ },
21
+ "train_epoch_2": {
22
+ "summarization_loss": 3.0815064049717713,
23
+ "summarization_rouge_like": 0.44604443152864864,
24
+ "emotion_loss": 0.04770229796717623,
25
+ "emotion_f1": 0.9407868445694336,
26
+ "topic_loss": 0.1507136240392336,
27
+ "topic_accuracy": 0.9498742677227413,
28
+ "total_loss": 3.279922429068798,
29
+ "epoch": 2.0
30
+ },
31
+ "val_epoch_2": {
32
+ "summarization_loss": 2.8898715693603942,
33
+ "summarization_rouge_like": 0.4654528613816311,
34
+ "emotion_loss": 0.05001389549380918,
35
+ "emotion_f1": 0.9344953305524384,
36
+ "topic_loss": 0.1755385091801308,
37
+ "topic_accuracy": 0.9435966487133395,
38
+ "epoch": 2.0
39
+ },
40
+ "train_epoch_3": {
41
+ "summarization_loss": 3.0340622767404044,
42
+ "summarization_rouge_like": 0.4502876682264882,
43
+ "emotion_loss": 0.025708710505635942,
44
+ "emotion_f1": 0.9647584015837614,
45
+ "topic_loss": 0.11707986947991166,
46
+ "topic_accuracy": 0.9614479064357344,
47
+ "total_loss": 3.176850952874497,
48
+ "epoch": 3.0
49
+ },
50
+ "val_epoch_3": {
51
+ "summarization_loss": 2.865455434181104,
52
+ "summarization_rouge_like": 0.46790124713702563,
53
+ "emotion_loss": 0.05574661032417156,
54
+ "emotion_f1": 0.940105742034193,
55
+ "topic_loss": 0.19245651335709887,
56
+ "topic_accuracy": 0.942998204667858,
57
+ "epoch": 3.0
58
  }
59
  }
scripts/download_data.py CHANGED
@@ -1,4 +1,12 @@
1
- """Download datasets used by LexiMind."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Dataset download script for LexiMind.
3
+
4
+ Downloads training datasets from various sources including HuggingFace Hub,
5
+ Kaggle, and Project Gutenberg. Handles automatic conversion to JSONL format.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
scripts/eval_rouge.py CHANGED
@@ -1,4 +1,12 @@
1
- """Utility script to evaluate LexiMind summaries with ROUGE."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ ROUGE evaluation script for LexiMind.
3
+
4
+ Computes ROUGE-1, ROUGE-2, and ROUGE-L scores on summarization outputs
5
+ with support for batched inference and customizable metrics.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
scripts/evaluate.py CHANGED
@@ -1,6 +1,11 @@
1
  """
2
- Evaluate the multitask model on processed validation/test splits.
3
- This is used for getting definitive scores on my test set after training is complete.
 
 
 
 
 
4
  """
5
 
6
  from __future__ import annotations
@@ -8,9 +13,12 @@ from __future__ import annotations
8
  import argparse
9
  import json
10
  import sys
 
11
  from pathlib import Path
12
- from typing import Any, List, cast
13
 
 
 
14
  import torch
15
  from sklearn.preprocessing import MultiLabelBinarizer
16
  from tqdm import tqdm
@@ -19,14 +27,7 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
19
  if str(PROJECT_ROOT) not in sys.path:
20
  sys.path.insert(0, str(PROJECT_ROOT))
21
 
22
- import matplotlib.pyplot as plt
23
- import seaborn as sns
24
-
25
- from src.data.dataset import (
26
- load_emotion_jsonl,
27
- load_summarization_jsonl,
28
- load_topic_jsonl,
29
- )
30
  from src.inference.factory import create_inference_pipeline
31
  from src.training.metrics import (
32
  accuracy,
@@ -38,80 +39,67 @@ from src.training.metrics import (
38
  )
39
  from src.utils.config import load_yaml
40
 
41
- SPLIT_ALIASES = {
42
- "train": ("train",),
43
- "val": ("val", "validation"),
44
- "test": ("test",),
45
- }
46
 
 
47
 
48
- def _read_split(root: Path, split: str, loader) -> List[Any]:
49
- aliases = SPLIT_ALIASES.get(split, (split,))
50
- for alias in aliases:
 
51
  for ext in ("jsonl", "json"):
52
- candidate = root / f"{alias}.{ext}"
53
- if candidate.exists():
54
- return cast(List[Any], loader(str(candidate)))
55
- raise FileNotFoundError(f"Missing {split} split under {root}")
56
 
57
 
58
- def parse_args() -> argparse.Namespace:
59
- parser = argparse.ArgumentParser(description="Evaluate the LexiMind multitask model")
60
- parser.add_argument(
61
- "--split",
62
- default="val",
63
- choices=["train", "val", "test"],
64
- help="Dataset split to evaluate.",
65
- )
66
- parser.add_argument(
67
- "--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint."
68
- )
69
- parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON.")
70
- parser.add_argument(
71
- "--data-config", default="configs/data/datasets.yaml", help="Data configuration YAML."
72
- )
73
- parser.add_argument(
74
- "--model-config", default="configs/model/base.yaml", help="Model architecture YAML."
75
- )
76
- parser.add_argument(
77
- "--device",
78
- default="cuda" if torch.cuda.is_available() else "cpu",
79
- help="Device for evaluation.",
80
- )
81
- parser.add_argument(
82
- "--batch-size",
83
- type=int,
84
- default=16,
85
- help="Batch size for generation/classification during evaluation.",
86
- )
87
- parser.add_argument(
88
- "--output-dir", default="outputs", help="Directory to save evaluation artifacts."
89
- )
90
- return parser.parse_args()
91
 
92
 
93
- def chunks(items: List, size: int):
94
- for start in range(0, len(items), size):
95
- yield items[start : start + size]
96
 
97
 
98
- def plot_confusion_matrix(cm, labels, output_path):
 
99
  plt.figure(figsize=(10, 8))
100
  sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
101
  plt.xlabel("Predicted")
102
  plt.ylabel("True")
103
  plt.title("Topic Classification Confusion Matrix")
104
  plt.tight_layout()
105
- plt.savefig(output_path)
106
  plt.close()
107
 
108
109
  def main() -> None:
110
  args = parse_args()
111
- data_cfg = load_yaml(args.data_config).data
 
112
  output_dir = Path(args.output_dir)
113
  output_dir.mkdir(parents=True, exist_ok=True)
114
 
 
 
115
  pipeline, metadata = create_inference_pipeline(
116
  checkpoint_path=args.checkpoint,
117
  labels_path=args.labels,
@@ -120,100 +108,94 @@ def main() -> None:
120
  device=args.device,
121
  )
122
 
123
- summarization_dir = Path(data_cfg["processed"]["summarization"])
124
- emotion_dir = Path(data_cfg["processed"]["emotion"])
125
- topic_dir = Path(data_cfg["processed"]["topic"])
126
-
127
- summary_examples = _read_split(summarization_dir, args.split, load_summarization_jsonl)
128
- emotion_examples = _read_split(emotion_dir, args.split, load_emotion_jsonl)
129
- topic_examples = _read_split(topic_dir, args.split, load_topic_jsonl)
130
-
131
- emotion_binarizer = MultiLabelBinarizer(classes=metadata.emotion)
132
- # Ensure scikit-learn initializes the attributes using metadata ordering.
133
- emotion_binarizer.fit([[label] for label in metadata.emotion])
134
-
135
- # Summarization
136
- print("Evaluating Summarization...")
137
- summaries_pred = []
138
- summaries_ref = []
139
- total_batches = (len(summary_examples) + args.batch_size - 1) // args.batch_size
140
- for batch in tqdm(
141
- chunks(summary_examples, args.batch_size),
142
- total=total_batches,
143
- desc="Summarization",
144
- unit="batch",
145
- ):
146
- inputs = [example.source for example in batch]
147
- summaries_pred.extend(pipeline.summarize(inputs))
148
- summaries_ref.extend([example.summary for example in batch])
149
-
150
- rouge_score = rouge_like(summaries_pred, summaries_ref)
151
- bleu_score = calculate_bleu(summaries_pred, summaries_ref)
152
-
153
- # Emotion
154
- print("Evaluating Emotion Classification...")
155
- emotion_preds_tensor = []
156
- emotion_target_tensor = []
157
- label_to_index = {label: idx for idx, label in enumerate(metadata.emotion)}
158
- total_batches = (len(emotion_examples) + args.batch_size - 1) // args.batch_size
159
-
160
- # Lower threshold to 0.3 to catch weak signals, or use argmax if appropriate
161
- # For now, we'll stick to thresholding but lower it.
162
- inference_threshold = 0.3
163
-
164
- for batch in tqdm(
165
- chunks(emotion_examples, args.batch_size), total=total_batches, desc="Emotion", unit="batch"
166
- ):
167
- inputs = [example.text for example in batch]
168
- predictions = pipeline.predict_emotions(inputs, threshold=inference_threshold)
169
- target_matrix = emotion_binarizer.transform([list(example.emotions) for example in batch])
170
- for pred, target_row in zip(predictions, target_matrix, strict=False):
171
- vector = torch.zeros(len(metadata.emotion), dtype=torch.float32)
172
- for label in pred.labels:
173
- idx = label_to_index.get(label)
174
- if idx is not None:
175
- vector[idx] = 1.0
176
- emotion_preds_tensor.append(vector)
177
- emotion_target_tensor.append(torch.tensor(target_row, dtype=torch.float32))
178
-
179
- emotion_f1 = multilabel_f1(
180
- torch.stack(emotion_preds_tensor), torch.stack(emotion_target_tensor)
 
 
 
181
  )
 
 
182
 
183
- # Topic
184
- print("Evaluating Topic Classification...")
185
- topic_preds = []
186
- topic_targets = []
187
- total_batches = (len(topic_examples) + args.batch_size - 1) // args.batch_size
188
- for batch in tqdm(
189
- chunks(topic_examples, args.batch_size), total=total_batches, desc="Topic", unit="batch"
190
- ):
191
- inputs = [example.text for example in batch]
192
- topic_predictions = pipeline.predict_topics(inputs)
193
- topic_preds.extend([pred.label for pred in topic_predictions])
194
- topic_targets.extend([example.topic for example in batch])
195
-
196
- topic_accuracy = accuracy(topic_preds, topic_targets)
197
- topic_report = classification_report_dict(topic_preds, topic_targets, labels=metadata.topic)
198
- topic_cm = get_confusion_matrix(topic_preds, topic_targets, labels=metadata.topic)
199
-
200
- # Save Confusion Matrix
201
  cm_path = output_dir / "topic_confusion_matrix.png"
202
  plot_confusion_matrix(topic_cm, metadata.topic, cm_path)
203
- print(f"Confusion matrix saved to {cm_path}")
 
 
204
 
205
  results = {
206
  "split": args.split,
207
- "summarization": {"rouge_like": rouge_score, "bleu": bleu_score},
208
  "emotion": {"f1_macro": emotion_f1},
209
- "topic": {"accuracy": topic_accuracy, "classification_report": topic_report},
210
  }
211
 
212
  report_path = output_dir / "evaluation_report.json"
213
- with open(report_path, "w", encoding="utf-8") as f:
214
  json.dump(results, f, indent=2)
215
 
216
- print(f"Evaluation complete. Report saved to {report_path}")
 
 
 
 
217
  print(json.dumps(results, indent=2))
218
 
219
 
 
1
  """
2
+ Evaluation script for LexiMind.
3
+
4
+ Computes ROUGE/BLEU for summarization, multi-label F1 for emotion,
5
+ and accuracy with confusion matrix for topic classification.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
  """
10
 
11
  from __future__ import annotations
 
13
  import argparse
14
  import json
15
  import sys
16
+ import time
17
  from pathlib import Path
18
+ from typing import Any, Callable, List
19
 
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
  import torch
23
  from sklearn.preprocessing import MultiLabelBinarizer
24
  from tqdm import tqdm
 
27
  if str(PROJECT_ROOT) not in sys.path:
28
  sys.path.insert(0, str(PROJECT_ROOT))
29
 
30
+ from src.data.dataset import load_emotion_jsonl, load_summarization_jsonl, load_topic_jsonl
31
  from src.inference.factory import create_inference_pipeline
32
  from src.training.metrics import (
33
  accuracy,
 
39
  )
40
  from src.utils.config import load_yaml
41
 
42
+ # --------------- Data Loading ---------------
 
 
 
 
43
 
44
+ SPLIT_ALIASES = {"train": ("train",), "val": ("val", "validation"), "test": ("test",)}
45
 
46
+
47
+ def load_split(root: Path, split: str, loader: Callable[[str], List[Any]]) -> List[Any]:
48
+ """Load a dataset split, checking aliases."""
49
+ for alias in SPLIT_ALIASES.get(split, (split,)):
50
  for ext in ("jsonl", "json"):
51
+ path = root / f"{alias}.{ext}"
52
+ if path.exists():
53
+ return list(loader(str(path)))
54
+ raise FileNotFoundError(f"Missing {split} split in {root}")
55
 
56
 
57
+ def chunks(items: List, size: int):
58
+ """Yield batches of items."""
59
+ for i in range(0, len(items), size):
60
+ yield items[i : i + size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
+ # --------------- Visualization ---------------
 
 
64
 
65
 
66
+ def plot_confusion_matrix(cm, labels, path: Path) -> None:
67
+ """Save confusion matrix heatmap."""
68
  plt.figure(figsize=(10, 8))
69
  sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
70
  plt.xlabel("Predicted")
71
  plt.ylabel("True")
72
  plt.title("Topic Classification Confusion Matrix")
73
  plt.tight_layout()
74
+ plt.savefig(path)
75
  plt.close()
76
 
77
 
78
+ # --------------- Main ---------------
79
+
80
+
81
+ def parse_args() -> argparse.Namespace:
82
+ p = argparse.ArgumentParser(description="Evaluate LexiMind")
83
+ p.add_argument("--split", default="val", choices=["train", "val", "test"])
84
+ p.add_argument("--checkpoint", default="checkpoints/best.pt")
85
+ p.add_argument("--labels", default="artifacts/labels.json")
86
+ p.add_argument("--data-config", default="configs/data/datasets.yaml")
87
+ p.add_argument("--model-config", default="configs/model/base.yaml")
88
+ p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
89
+ p.add_argument("--batch-size", type=int, default=148) # Larger batch for inference (no grads)
90
+ p.add_argument("--output-dir", default="outputs")
91
+ return p.parse_args()
92
+
93
+
94
  def main() -> None:
95
  args = parse_args()
96
+ start_time = time.perf_counter()
97
+
98
  output_dir = Path(args.output_dir)
99
  output_dir.mkdir(parents=True, exist_ok=True)
100
 
101
+ # Load pipeline
102
+ print("Loading model...")
103
  pipeline, metadata = create_inference_pipeline(
104
  checkpoint_path=args.checkpoint,
105
  labels_path=args.labels,
 
108
  device=args.device,
109
  )
110
 
111
+ # Load data
112
+ data_cfg = load_yaml(args.data_config).data
113
+ summ_data = load_split(
114
+ Path(data_cfg["processed"]["summarization"]), args.split, load_summarization_jsonl
115
+ )
116
+ emot_data = load_split(Path(data_cfg["processed"]["emotion"]), args.split, load_emotion_jsonl)
117
+ topic_data = load_split(Path(data_cfg["processed"]["topic"]), args.split, load_topic_jsonl)
118
+
119
+ print(f"\nEvaluating on {args.split} split:")
120
+ print(f" Summarization: {len(summ_data)} samples")
121
+ print(f" Emotion: {len(emot_data)} samples")
122
+ print(f" Topic: {len(topic_data)} samples")
123
+
124
+ # --------------- Summarization ---------------
125
+
126
+ print("\nSummarization...")
127
+ preds, refs = [], []
128
+ for batch in tqdm(list(chunks(summ_data, args.batch_size)), desc="Summarization", unit="batch"):
129
+ preds.extend(pipeline.summarize([ex.source for ex in batch]))
130
+ refs.extend([ex.summary for ex in batch])
131
+
132
+ rouge = rouge_like(preds, refs)
133
+ bleu = calculate_bleu(preds, refs)
134
+ print(f" ROUGE-like: {rouge:.4f}, BLEU: {bleu:.4f}")
135
+
136
+ # --------------- Emotion ---------------
137
+
138
+ print("\nEmotion Classification...")
139
+ binarizer = MultiLabelBinarizer(classes=metadata.emotion)
140
+ binarizer.fit([[label] for label in metadata.emotion])
141
+ label_idx = {label: i for i, label in enumerate(metadata.emotion)}
142
+
143
+ pred_vecs, target_vecs = [], []
144
+ for batch in tqdm(list(chunks(emot_data, args.batch_size)), desc="Emotion", unit="batch"):
145
+ emotion_results = pipeline.predict_emotions([ex.text for ex in batch], threshold=0.3)
146
+ targets = binarizer.transform([list(ex.emotions) for ex in batch])
147
+
148
+ for pred, target in zip(emotion_results, targets, strict=False):
149
+ vec = torch.zeros(len(metadata.emotion))
150
+ for lbl in pred.labels:
151
+ if lbl in label_idx:
152
+ vec[label_idx[lbl]] = 1.0
153
+ pred_vecs.append(vec)
154
+ target_vecs.append(torch.tensor(target, dtype=torch.float32))
155
+
156
+ emotion_f1 = multilabel_f1(torch.stack(pred_vecs), torch.stack(target_vecs))
157
+ print(f" F1 (macro): {emotion_f1:.4f}")
158
+
159
+ # --------------- Topic ---------------
160
+
161
+ print("\nTopic Classification...")
162
+ topic_pred_labels: List[str] = []
163
+ topic_true_labels: List[str] = []
164
+ for batch in tqdm(list(chunks(topic_data, args.batch_size)), desc="Topic", unit="batch"):
165
+ topic_results = pipeline.predict_topics([ex.text for ex in batch])
166
+ topic_pred_labels.extend([r.label for r in topic_results])
167
+ topic_true_labels.extend([ex.topic for ex in batch])
168
+
169
+ topic_acc = accuracy(topic_pred_labels, topic_true_labels)
170
+ topic_report = classification_report_dict(
171
+ topic_pred_labels, topic_true_labels, labels=metadata.topic
172
  )
173
+ topic_cm = get_confusion_matrix(topic_pred_labels, topic_true_labels, labels=metadata.topic)
174
+ print(f" Accuracy: {topic_acc:.4f}")
175
 
176
+ # Save confusion matrix
177
  cm_path = output_dir / "topic_confusion_matrix.png"
178
  plot_confusion_matrix(topic_cm, metadata.topic, cm_path)
179
+ print(f" Confusion matrix saved: {cm_path}")
180
+
181
+ # --------------- Save Results ---------------
182
 
183
  results = {
184
  "split": args.split,
185
+ "summarization": {"rouge_like": rouge, "bleu": bleu},
186
  "emotion": {"f1_macro": emotion_f1},
187
+ "topic": {"accuracy": topic_acc, "classification_report": topic_report},
188
  }
189
 
190
  report_path = output_dir / "evaluation_report.json"
191
+ with open(report_path, "w") as f:
192
  json.dump(results, f, indent=2)
193
 
194
+ total_time = time.perf_counter() - start_time
195
+ print(f"\n{'=' * 50}")
196
+ print(f"Evaluation complete in {total_time:.1f}s")
197
+ print(f"Report saved: {report_path}")
198
+ print(f"{'=' * 50}")
199
  print(json.dumps(results, indent=2))
200
 
201
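
With the argparse defaults above, a standalone evaluation run (train.py no longer launches one automatically after this commit) would look something like:

    python scripts/evaluate.py --split test --checkpoint checkpoints/best.pt --labels artifacts/labels.json --output-dir outputs

Only --split test differs from the defaults in parse_args(); the checkpoint, labels, and output-dir values shown are already the defaults.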
 
scripts/export_model.py CHANGED
@@ -1,4 +1,12 @@
1
- """Rebuild and export the trained multitask model for downstream use."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Model export script for LexiMind.
3
+
4
+ Rebuilds the multitask model from configuration and exports trained weights
5
+ for deployment or distribution.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
scripts/export_tokenizer.py CHANGED
@@ -1,4 +1,12 @@
1
- """Export the FLAN-T5 tokenizer to the artifacts directory for reproducible inference."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Tokenizer export script for LexiMind.
3
+
4
+ Saves the FLAN-T5 tokenizer to the artifacts directory for reproducible
5
+ inference without requiring network access.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
scripts/inference.py CHANGED
@@ -1,4 +1,12 @@
1
- """Run inference with the multitask model."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Inference script for the LexiMind multitask model.
3
+
4
+ Command-line interface for running summarization, emotion detection, and topic
5
+ classification on arbitrary text inputs.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
scripts/preprocess_data.py CHANGED
@@ -1,4 +1,13 @@
1
- """Preprocess raw datasets into JSONL splits for LexiMind training."""
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Data preprocessing script for LexiMind.
3
+
4
+ Transforms raw datasets into standardized JSONL splits for training. Handles
5
+ summarization, emotion classification, topic classification, and book paragraph
6
+ extraction with text cleaning.
7
+
8
+ Author: Oliver Perrin
9
+ Date: December 2025
10
+ """
11
 
12
  from __future__ import annotations
13
 
scripts/train.py CHANGED
@@ -1,13 +1,20 @@
1
- """End-to-end training entrypoint for the LexiMind multitask model."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  import json
6
- import platform
7
  import sys
8
- import warnings
9
  from pathlib import Path
10
- from typing import Any, Dict, Sequence, Tuple, cast
11
 
12
  import hydra
13
  import torch
@@ -37,8 +44,7 @@ from src.training.utils import set_seed
37
  from src.utils.io import save_state
38
  from src.utils.labels import LabelMetadata, save_label_metadata
39
 
40
- SplitExamples = Dict[str, list]
41
-
42
 
43
  SPLIT_ALIASES: Dict[str, Sequence[str]] = {
44
  "train": ("train",),
@@ -47,286 +53,214 @@ SPLIT_ALIASES: Dict[str, Sequence[str]] = {
47
  }
48
 
49
 
50
- def _read_examples(data_dir: Path, loader) -> SplitExamples:
51
- splits: SplitExamples = {}
52
- for canonical, aliases in SPLIT_ALIASES.items():
53
- found = False
54
  for alias in aliases:
55
- for extension in ("jsonl", "json"):
56
- candidate = data_dir / f"{alias}.{extension}"
57
- if candidate.exists():
58
- splits[canonical] = loader(str(candidate))
59
- found = True
60
  break
61
- if found:
62
  break
63
- if not found:
64
- raise FileNotFoundError(f"Missing {canonical} split under {data_dir}")
65
  return splits
66
 
67
 
68
- def _limit_samples(splits: SplitExamples, trainer_cfg: DictConfig) -> None:
69
- """Limit the number of samples in train/val splits if configured."""
70
- max_train = trainer_cfg.get("max_train_samples")
71
- max_val = trainer_cfg.get("max_val_samples")
72
-
73
- if max_train is not None and "train" in splits:
74
- original_len = len(splits["train"])
75
- limit = int(max_train)
76
- if original_len > limit:
77
- splits["train"] = splits["train"][:limit]
78
- print(f"Limited 'train' split from {original_len} to {limit} samples")
79
 
80
- if max_val is not None and "val" in splits:
81
- original_len = len(splits["val"])
82
- limit = int(max_val)
83
- if original_len > limit:
84
- splits["val"] = splits["val"][:limit]
85
- print(f"Limited 'val' split from {original_len} to {limit} samples")
86
 
 
87
 
88
- def compile_model_safe(model: torch.nn.Module) -> Tuple[Any, str]:
89
- """
90
- Safely compile model with best available backend.
91
 
92
- Returns:
93
- Compiled model and backend name used
94
- """
95
- system = platform.system()
 
 
 
 
96
 
97
- # NOTE: The 'inductor' backend causes NaN gradients during backward pass with
98
- # bfloat16 autocast on the decoder (seq2seq tasks). This is a known issue.
99
- # Use 'aot_eager' which provides graph optimization without inductor's codegen.
100
- # See: debug_compile_config.py and test_compile_modes.py for investigation.
101
 
102
- # Try aot_eager first - it's stable and provides good speedup
103
- try:
104
- print("Attempting to compile with 'aot_eager' backend...")
105
- compiled_model = torch.compile(model, backend="aot_eager")
106
- print("✓ Successfully compiled with 'aot_eager' backend")
107
- return cast(torch.nn.Module, compiled_model), "aot_eager"
108
- except Exception as e:
109
- warnings.warn(f"aot_eager backend failed: {e}", stacklevel=2)
110
-
111
- # Fallback: Try other backends (inductor may work for encoder-only tasks)
112
- backends_to_try = ["eager"]
113
- if system != "Windows":
114
- # On Linux, inductor might work for some configurations
115
- backends_to_try = ["eager", "inductor"]
116
-
117
- for backend in backends_to_try:
118
- try:
119
- print(f"Attempting to compile with '{backend}' backend...")
120
- compiled_model = torch.compile(model, backend=backend)
121
- # Trigger a dummy run or just return? torch.compile is lazy.
122
- # I assume it works if the call succeeds, runtime errors handled later.
123
- print(f"✓ Successfully compiled with '{backend}' backend")
124
- return cast(torch.nn.Module, compiled_model), backend
125
- except Exception as e:
126
- print(f"✗ '{backend}' backend failed: {e}")
127
- continue
128
-
129
- # No compilation worked, return original model
130
- warnings.warn("All torch.compile backends failed, using uncompiled model", stacklevel=2)
131
- return model, "none"
132
 
133
 
134
  @hydra.main(version_base=None, config_path="../configs", config_name="config")
135
  def main(cfg: DictConfig) -> None:
 
136
  print(OmegaConf.to_yaml(cfg))
137
  set_seed(cfg.seed)
138
 
139
- # Enable TF32 for Ampere/Ada GPUs (RTX 30xx/40xx)
140
- # This provides significant speedup on RTX 4070
141
  if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
142
- print("Enabling TF32 for Ampere/Ada GPU...")
143
  torch.set_float32_matmul_precision("high")
144
  torch.backends.cuda.matmul.allow_tf32 = True
145
  torch.backends.cudnn.allow_tf32 = True
146
- torch.backends.cudnn.benchmark = True # Auto-tunes convolution algorithms
147
 
148
- # Access configs directly from Hydra cfg object
149
  data_cfg = cfg.data
150
- training_cfg = cfg.training
151
 
152
- # Instantiate ModelConfig directly from cfg.model
153
- model_cfg = ModelConfig(
154
- d_model=cfg.model.d_model,
155
- num_encoder_layers=cfg.model.num_encoder_layers,
156
- num_decoder_layers=cfg.model.num_decoder_layers,
157
- num_attention_heads=cfg.model.num_attention_heads,
158
- ffn_dim=cfg.model.ffn_dim,
159
- dropout=cfg.model.dropout,
160
- use_pretrained=cfg.model.use_pretrained,
161
- pretrained_model_name=cfg.model.pretrained_model_name,
162
- activation=getattr(cfg.model, "activation", "gelu"),
163
- use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
164
- )
165
 
166
- summarization_dir = Path(data_cfg.processed.summarization)
167
- emotion_dir = Path(data_cfg.processed.emotion)
168
- topic_dir = Path(data_cfg.processed.topic)
169
-
170
- summarization_splits = _read_examples(summarization_dir, load_summarization_jsonl)
171
- emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
172
- topic_splits = _read_examples(topic_dir, load_topic_jsonl)
173
-
174
- # Apply sample limits if configured (e.g. for dev/medium runs)
175
- trainer_cfg = training_cfg.get("trainer", {})
176
- print("\nApplying dataset limits...")
177
- _limit_samples(summarization_splits, trainer_cfg)
178
- _limit_samples(emotion_splits, trainer_cfg)
179
- _limit_samples(topic_splits, trainer_cfg)
180
- print("Dataset limits applied.\n")
181
-
182
- tokenizer_section = data_cfg.get("tokenizer", {})
183
- tokenizer_config = TokenizerConfig(
184
- pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
185
- max_length=int(tokenizer_section.get("max_length", 512)),
186
- lower=bool(tokenizer_section.get("lower", False)),
187
- )
188
- tokenizer = Tokenizer(tokenizer_config)
189
 
190
- summarization_train = SummarizationDataset(summarization_splits["train"])
191
- summarization_val = SummarizationDataset(summarization_splits["val"])
192
 
193
- emotion_train = EmotionDataset(emotion_splits["train"])
194
- emotion_val = EmotionDataset(emotion_splits["val"], binarizer=emotion_train.binarizer)
 
 
 
 
 
 
195
 
 
 
 
 
196
  topic_train = TopicDataset(topic_splits["train"])
197
  topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
198
 
199
- dataloader_args = training_cfg.get("dataloader", {})
200
- batch_size = int(dataloader_args.get("batch_size", 8))
201
- shuffle = bool(dataloader_args.get("shuffle", True))
202
- # Optimization: Use multiple workers and pinned memory for faster data transfer
203
- num_workers = int(dataloader_args.get("num_workers", 4))
204
- pin_memory = bool(dataloader_args.get("pin_memory", True))
205
- max_length = tokenizer.config.max_length
206
 
207
  train_loaders = {
208
  "summarization": build_summarization_dataloader(
209
- summarization_train,
210
  tokenizer,
 
 
 
211
  batch_size=batch_size,
212
- shuffle=shuffle,
213
- max_source_length=max_length,
214
- max_target_length=max_length,
215
  num_workers=num_workers,
216
  pin_memory=pin_memory,
217
  ),
218
  "emotion": build_emotion_dataloader(
219
- emotion_train,
220
  tokenizer,
 
 
221
  batch_size=batch_size,
222
- shuffle=shuffle,
223
- max_length=max_length,
224
  num_workers=num_workers,
225
  pin_memory=pin_memory,
226
  ),
227
  "topic": build_topic_dataloader(
228
  topic_train,
229
  tokenizer,
 
 
230
  batch_size=batch_size,
231
- shuffle=shuffle,
232
- max_length=max_length,
233
  num_workers=num_workers,
234
  pin_memory=pin_memory,
235
  ),
236
  }
237
-
238
  val_loaders = {
239
  "summarization": build_summarization_dataloader(
240
- summarization_val,
241
  tokenizer,
242
- batch_size=batch_size,
243
  shuffle=False,
244
- max_source_length=max_length,
245
- max_target_length=max_length,
 
246
  num_workers=num_workers,
247
  pin_memory=pin_memory,
248
  ),
249
  "emotion": build_emotion_dataloader(
250
- emotion_val,
251
  tokenizer,
252
- batch_size=batch_size,
253
  shuffle=False,
254
- max_length=max_length,
 
255
  num_workers=num_workers,
256
  pin_memory=pin_memory,
257
  ),
258
  "topic": build_topic_dataloader(
259
  topic_val,
260
  tokenizer,
261
- batch_size=batch_size,
262
  shuffle=False,
263
- max_length=max_length,
 
264
  num_workers=num_workers,
265
  pin_memory=pin_memory,
266
  ),
267
  }
268
 
 
 
 
269
  device = torch.device(cfg.device)
270
  model = build_multitask_model(
271
  tokenizer,
272
- num_emotions=len(emotion_train.emotion_classes),
273
  num_topics=len(topic_train.topic_classes),
274
  config=model_cfg,
275
  ).to(device)
276
 
277
- optimizer_cfg = training_cfg.get("optimizer", {})
278
- lr = float(optimizer_cfg.get("lr", 3.0e-5))
279
- # Add weight decay for regularization to prevent overfitting
280
- weight_decay = float(optimizer_cfg.get("weight_decay", 0.01))
281
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
282
-
283
- # Optimize model execution graph with torch.compile (PyTorch 2.0+)
284
- # This fuses kernels and reduces overhead for faster training
285
- # Note: We only compile encoder/decoder for training, not the step() method used in generation
286
- # Compile encoder and decoder separately to avoid control flow issues in MultiTaskModel.forward
287
- # Compiling the top-level model causes excessive recompilation due to task switching
288
- use_compile = True # torch.compile for faster training
289
-
290
- if use_compile and model.encoder is not None:
291
- model.encoder, backend_used = compile_model_safe(model.encoder)
292
- else:
293
- backend_used = "disabled"
294
- if use_compile and model.decoder is not None:
295
- # Compile decoder.forward but keep step/greedy_decode uncompiled for generation
296
- model.decoder, _ = compile_model_safe(model.decoder)
297
-
298
- # Compile heads
299
- if use_compile:
300
- for name, head in model.heads.items():
301
- compiled_head, _ = compile_model_safe(head)
302
- model.heads[name] = compiled_head
303
- # Update the registered module as well to ensure parameters are tracked correctly
304
- setattr(model, f"head_{name}", compiled_head)
305
-
306
- print(f"Using compilation backend: {backend_used}")
307
-
308
- # Verify weights loaded correctly (check for NaNs/Infs)
309
- print("\n=== Weight Loading Verification ===")
310
- has_issues = False
311
- for name, param in model.named_parameters():
312
- if torch.isnan(param).any():
313
- print(f"WARNING: NaN in {name}")
314
- has_issues = True
315
- if torch.isinf(param).any():
316
- print(f"WARNING: Inf in {name}")
317
- has_issues = True
318
- if not has_issues:
319
- print("✓ No NaNs or Infs found in model parameters.")
320
- print("=== Verification Complete ===\n")
321
-
322
- trainer_cfg = training_cfg.get("trainer", {})
323
  trainer = Trainer(
324
  model=model,
325
  optimizer=optimizer,
326
  config=TrainerConfig(
327
  max_epochs=int(trainer_cfg.get("max_epochs", 1)),
328
  gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
329
- logging_interval=int(trainer_cfg.get("logging_interval", 50)),
330
  task_weights=trainer_cfg.get("task_weights"),
331
  label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
332
  gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
@@ -335,61 +269,43 @@ def main(cfg: DictConfig) -> None:
335
  tokenizer=tokenizer,
336
  )
337
 
338
- # Save checkpoint after every epoch to avoid losing good early checkpoints
339
- # Previous training showed overfitting at epoch 5 but good results at epoch 3
340
- def save_epoch_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
341
- epoch_path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
342
- epoch_path.parent.mkdir(parents=True, exist_ok=True)
343
- save_state(model, str(epoch_path))
344
- print(f"Checkpoint saved: {epoch_path}")
345
 
346
- history = trainer.fit(train_loaders, val_loaders, checkpoint_callback=save_epoch_checkpoint)
 
347
 
348
- checkpoint_path = Path(cfg.checkpoint_out)
349
- checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
350
- save_state(model, str(checkpoint_path))
351
 
 
 
 
 
 
 
352
  labels_path = Path(cfg.labels_out)
353
  save_label_metadata(
354
- LabelMetadata(
355
- emotion=emotion_train.emotion_classes,
356
- topic=topic_train.topic_classes,
357
- ),
358
  labels_path,
359
  )
360
 
 
361
  history_path = Path(cfg.history_out)
362
  history_path.parent.mkdir(parents=True, exist_ok=True)
363
- with history_path.open("w", encoding="utf-8") as handle:
364
- json.dump(history, handle, indent=2)
365
-
366
- print(f"Training complete. Checkpoint saved to {checkpoint_path}")
367
- print(f"Label metadata saved to {labels_path}")
368
- print(f"History saved to {history_path}")
369
-
370
- # Run evaluation pipeline
371
- print("\nRunning evaluation pipeline...")
372
- import subprocess
373
-
374
- try:
375
- subprocess.run(
376
- [
377
- sys.executable,
378
- "scripts/evaluate.py",
379
- "--split",
380
- "test", # Evaluate on test set
381
- "--checkpoint",
382
- str(checkpoint_path),
383
- "--labels",
384
- str(labels_path),
385
- "--output-dir",
386
- "outputs",
387
- ],
388
- check=True,
389
- )
390
- print("Evaluation pipeline completed successfully.")
391
- except subprocess.CalledProcessError as e:
392
- print(f"Evaluation pipeline failed with error: {e}")
393
 
394
 
395
  if __name__ == "__main__":
 
1
+ """
2
+ Training script for LexiMind.
3
+
4
+ Orchestrates dataset loading, model construction, torch.compile optimization,
5
+ and multi-task training with checkpoint management.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
13
  import json
 
14
  import sys
15
+ import time
16
  from pathlib import Path
17
+ from typing import Any, Dict, Sequence
18
 
19
  import hydra
20
  import torch
 
44
  from src.utils.io import save_state
45
  from src.utils.labels import LabelMetadata, save_label_metadata
46
 
47
+ # --------------- Data Loading ---------------
 
48
 
49
  SPLIT_ALIASES: Dict[str, Sequence[str]] = {
50
  "train": ("train",),
 
53
  }
54
 
55
 
56
+ def load_splits(data_dir: Path, loader) -> Dict[str, list]:
57
+ """Load train/val/test splits from data directory."""
58
+ splits = {}
59
+ for name, aliases in SPLIT_ALIASES.items():
60
  for alias in aliases:
61
+ for ext in ("jsonl", "json"):
62
+ path = data_dir / f"{alias}.{ext}"
63
+ if path.exists():
64
+ splits[name] = loader(str(path))
 
65
  break
66
+ if name in splits:
67
  break
68
+ if name not in splits:
69
+ raise FileNotFoundError(f"Missing {name} split in {data_dir}")
70
  return splits
71
 
72
 
73
+ def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
74
+ """Apply sample limits for dev/debug runs."""
75
+ for split, key in [("train", "max_train_samples"), ("val", "max_val_samples")]:
76
+ limit = cfg.get(key)
77
+ if limit and split in splits and len(splits[split]) > limit:
78
+ splits[split] = splits[split][: int(limit)]
79
+ print(f" {split}: limited to {limit} samples")
 
 
 
 
80
 
 
 
 
 
 
 
81
 
82
+ # --------------- Model Compilation ---------------
83
 
 
 
 
84
 
85
+ def compile_model(model: torch.nn.Module) -> Any:
86
+ """Compile model with aot_eager backend (stable, avoids inductor NaN issues)."""
87
+ try:
88
+ compiled = torch.compile(model, backend="aot_eager")
89
+ print("✓ Compiled with aot_eager")
90
+ return compiled
91
+ except Exception:
92
+ return model
93
 
 
 
 
 
94
 
95
+ # --------------- Main ---------------
96
 
97
 
98
  @hydra.main(version_base=None, config_path="../configs", config_name="config")
99
  def main(cfg: DictConfig) -> None:
100
+ start_time = time.perf_counter()
101
  print(OmegaConf.to_yaml(cfg))
102
  set_seed(cfg.seed)
103
 
104
+ # Enable TF32 for Ampere+ GPUs (RTX 30xx/40xx) - ~2x matmul speedup
 
105
  if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
106
+ print(" TF32 enabled for Ampere GPU")
107
  torch.set_float32_matmul_precision("high")
108
  torch.backends.cuda.matmul.allow_tf32 = True
109
  torch.backends.cudnn.allow_tf32 = True
110
+ torch.backends.cudnn.benchmark = True # Auto-tune convolutions
111
+ torch.backends.cuda.enable_flash_sdp(True) # Flash attention if available
112
+ torch.backends.cuda.enable_mem_efficient_sdp(True) # Memory-efficient attention
113
+
114
+ # Disable debug APIs for max speed
115
+ torch.autograd.set_detect_anomaly(False)
116
+ torch.autograd.profiler.profile(False)
117
+ torch.autograd.profiler.emit_nvtx(False)
118
+
119
+ # --------------- Load Data ---------------
120
 
 
121
  data_cfg = cfg.data
122
+ trainer_cfg = cfg.training.get("trainer", {})
123
 
124
+ print("\nLoading datasets...")
125
+ summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
126
+ emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
127
+ topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
128
 
129
+ # Apply dev/debug sample limits
130
+ for splits in [summ_splits, emot_splits, topic_splits]:
131
+ limit_samples(splits, trainer_cfg)
132
 
133
+ # --------------- Tokenizer & Datasets ---------------
 
134
 
135
+ tok_cfg = data_cfg.get("tokenizer", {})
136
+ tokenizer = Tokenizer(
137
+ TokenizerConfig(
138
+ pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
139
+ max_length=int(tok_cfg.get("max_length", 512)),
140
+ lower=bool(tok_cfg.get("lower", False)),
141
+ )
142
+ )
143
 
144
+ summ_train = SummarizationDataset(summ_splits["train"])
145
+ summ_val = SummarizationDataset(summ_splits["val"])
146
+ emot_train = EmotionDataset(emot_splits["train"])
147
+ emot_val = EmotionDataset(emot_splits["val"], binarizer=emot_train.binarizer)
148
  topic_train = TopicDataset(topic_splits["train"])
149
  topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
150
 
151
+ # --------------- DataLoaders ---------------
152
+
153
+ dl_cfg = cfg.training.get("dataloader", {})
154
+ batch_size = int(dl_cfg.get("batch_size", 8))
155
+ num_workers = int(dl_cfg.get("num_workers", 4))
156
+ pin_memory = bool(dl_cfg.get("pin_memory", True))
157
+ max_len = tokenizer.config.max_length
158
 
159
  train_loaders = {
160
  "summarization": build_summarization_dataloader(
161
+ summ_train,
162
  tokenizer,
163
+ shuffle=True,
164
+ max_source_length=max_len,
165
+ max_target_length=max_len,
166
  batch_size=batch_size,
 
 
 
167
  num_workers=num_workers,
168
  pin_memory=pin_memory,
169
  ),
170
  "emotion": build_emotion_dataloader(
171
+ emot_train,
172
  tokenizer,
173
+ shuffle=True,
174
+ max_length=max_len,
175
  batch_size=batch_size,
 
 
176
  num_workers=num_workers,
177
  pin_memory=pin_memory,
178
  ),
179
  "topic": build_topic_dataloader(
180
  topic_train,
181
  tokenizer,
182
+ shuffle=True,
183
+ max_length=max_len,
184
  batch_size=batch_size,
 
 
185
  num_workers=num_workers,
186
  pin_memory=pin_memory,
187
  ),
188
  }
 
189
  val_loaders = {
190
  "summarization": build_summarization_dataloader(
191
+ summ_val,
192
  tokenizer,
 
193
  shuffle=False,
194
+ max_source_length=max_len,
195
+ max_target_length=max_len,
196
+ batch_size=batch_size,
197
  num_workers=num_workers,
198
  pin_memory=pin_memory,
199
  ),
200
  "emotion": build_emotion_dataloader(
201
+ emot_val,
202
  tokenizer,
 
203
  shuffle=False,
204
+ max_length=max_len,
205
+ batch_size=batch_size,
206
  num_workers=num_workers,
207
  pin_memory=pin_memory,
208
  ),
209
  "topic": build_topic_dataloader(
210
  topic_val,
211
  tokenizer,
 
212
  shuffle=False,
213
+ max_length=max_len,
214
+ batch_size=batch_size,
215
  num_workers=num_workers,
216
  pin_memory=pin_memory,
217
  ),
218
  }
219
 
220
+ # --------------- Model ---------------
221
+
222
+ print("\nBuilding model...")
223
  device = torch.device(cfg.device)
224
+ model_cfg = ModelConfig(
225
+ d_model=cfg.model.d_model,
226
+ num_encoder_layers=cfg.model.num_encoder_layers,
227
+ num_decoder_layers=cfg.model.num_decoder_layers,
228
+ num_attention_heads=cfg.model.num_attention_heads,
229
+ ffn_dim=cfg.model.ffn_dim,
230
+ dropout=cfg.model.dropout,
231
+ use_pretrained=cfg.model.use_pretrained,
232
+ pretrained_model_name=cfg.model.pretrained_model_name,
233
+ activation=getattr(cfg.model, "activation", "gelu"),
234
+ use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
235
+ )
236
  model = build_multitask_model(
237
  tokenizer,
238
+ num_emotions=len(emot_train.emotion_classes),
239
  num_topics=len(topic_train.topic_classes),
240
  config=model_cfg,
241
  ).to(device)
242
 
243
+ # Compile encoder/decoder for faster training (skip heads - small overhead)
244
+ if model.encoder is not None:
245
+ model.encoder = compile_model(model.encoder)
246
+ if model.decoder is not None:
247
+ model.decoder = compile_model(model.decoder)
248
+
249
+ # --------------- Optimizer & Trainer ---------------
250
+
251
+ opt_cfg = cfg.training.get("optimizer", {})
252
+ optimizer = torch.optim.AdamW(
253
+ model.parameters(),
254
+ lr=float(opt_cfg.get("lr", 3e-5)),
255
+ weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
256
+ )
257
+
258
  trainer = Trainer(
259
  model=model,
260
  optimizer=optimizer,
261
  config=TrainerConfig(
262
  max_epochs=int(trainer_cfg.get("max_epochs", 1)),
263
  gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
 
264
  task_weights=trainer_cfg.get("task_weights"),
265
  label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
266
  gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
 
269
  tokenizer=tokenizer,
270
  )
271
 
272
+ # --------------- Train ---------------
273
+
274
+ def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
275
+ path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
276
+ path.parent.mkdir(parents=True, exist_ok=True)
277
+ save_state(model, str(path))
 
278
 
279
+ print("\nStarting training...")
280
+ history = trainer.fit(train_loaders, val_loaders, checkpoint_callback=save_checkpoint)
281
 
282
+ # --------------- Save Outputs ---------------
 
 
283
 
284
+ # Best checkpoint
285
+ ckpt_path = Path(cfg.checkpoint_out)
286
+ ckpt_path.parent.mkdir(parents=True, exist_ok=True)
287
+ save_state(model, str(ckpt_path))
288
+
289
+ # Labels
290
  labels_path = Path(cfg.labels_out)
291
  save_label_metadata(
292
+ LabelMetadata(emotion=emot_train.emotion_classes, topic=topic_train.topic_classes),
 
 
 
293
  labels_path,
294
  )
295
 
296
+ # History
297
  history_path = Path(cfg.history_out)
298
  history_path.parent.mkdir(parents=True, exist_ok=True)
299
+ with history_path.open("w") as f:
300
+ json.dump(history, f, indent=2)
301
+
302
+ total_time = time.perf_counter() - start_time
303
+ print(f"\n{'=' * 50}")
304
+ print(f"Training complete in {total_time:.1f}s")
305
+ print(f" Checkpoint: {ckpt_path}")
306
+ print(f" Labels: {labels_path}")
307
+ print(f" History: {history_path}")
308
+ print(f"{'=' * 50}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
 
311
  if __name__ == "__main__":
src/api/app.py CHANGED
@@ -1,4 +1,11 @@
1
- """FastAPI application entrypoint."""
 
 
 
 
 
 
 
2
 
3
  from fastapi import FastAPI
4
 
 
1
+ """
2
+ FastAPI application factory for LexiMind.
3
+
4
+ Creates and configures the REST API application.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from fastapi import FastAPI
11
 
src/api/dependencies.py CHANGED
@@ -1,4 +1,11 @@
1
- """Dependency providers for the FastAPI application."""
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ FastAPI dependency providers for LexiMind.
3
+
4
+ Manages lazy initialization and caching of the inference pipeline.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from __future__ import annotations
11
 
src/api/routes.py CHANGED
@@ -1,4 +1,12 @@
1
- """API routes."""
 
 
 
 
 
 
 
 
2
 
3
  from typing import cast
4
 
 
1
+ """
2
+ API routes for LexiMind.
3
+
4
+ Defines REST endpoints for text analysis including summarization,
5
+ emotion detection, and topic classification.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from typing import cast
12
 
src/api/schemas.py CHANGED
@@ -1,4 +1,11 @@
1
- """API schemas."""
 
 
 
 
 
 
 
2
 
3
  from pydantic import BaseModel
4
 
 
1
+ """
2
+ Pydantic schemas for LexiMind API.
3
+
4
+ Defines request and response models for the REST API.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from pydantic import BaseModel
11
 
src/data/dataloader.py CHANGED
@@ -1,8 +1,15 @@
1
- """Task-aware DataLoader builders for the LexiMind multitask suite."""
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
- from typing import List
6
 
7
  import torch
8
  from torch.utils.data import DataLoader
@@ -17,9 +24,11 @@ from .dataset import (
17
  )
18
  from .tokenization import Tokenizer
19
 
 
 
20
 
21
  class SummarizationCollator:
22
- """Prepare encoder-decoder batches for abstractive summarization."""
23
 
24
  def __init__(
25
  self,
@@ -32,36 +41,24 @@ class SummarizationCollator:
32
  self.max_source_length = max_source_length
33
  self.max_target_length = max_target_length
34
 
35
- def __call__(self, batch: List[SummarizationExample]) -> dict[str, torch.Tensor]:
36
- sources = [example.source for example in batch]
37
- targets = [example.summary for example in batch]
38
 
39
- source_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
40
- target_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
41
 
42
- # target_enc["input_ids"] is [BOS, A, B, EOS, PAD...]
43
- # We want:
44
- # tgt_ids (decoder input): [BOS, A, B, EOS] (drop last PAD or EOS if full)
45
- # labels (target): [A, B, EOS, PAD] (drop first BOS)
46
 
47
- ids = target_enc["input_ids"]
48
- mask = target_enc["attention_mask"]
49
-
50
- # Slice to create shifted inputs/targets
51
- # tgt_ids: everything except the last token
52
  tgt_ids = ids[:, :-1]
53
-
54
- # labels: everything except the first token (BOS)
55
  labels = ids[:, 1:].clone()
56
-
57
- # Adjust mask for labels to ignore padding
58
- # The mask corresponds to the original ids. We slice it to match labels.
59
- labels_mask = mask[:, 1:]
60
- labels[labels_mask == 0] = -100
61
 
62
  return {
63
- "src_ids": source_enc["input_ids"],
64
- "src_mask": source_enc["attention_mask"],
65
  "tgt_ids": tgt_ids,
66
  "labels": labels,
67
  }
@@ -77,11 +74,13 @@ class EmotionCollator:
77
  self.binarizer = dataset.binarizer
78
  self.max_length = max_length
79
 
80
- def __call__(self, batch: List[EmotionExample]) -> dict[str, torch.Tensor]:
81
- texts = [example.text for example in batch]
82
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
83
- label_array = self.binarizer.transform([example.emotions for example in batch])
84
- labels = torch.as_tensor(label_array, dtype=torch.float32)
 
 
85
  return {
86
  "input_ids": encoded["input_ids"],
87
  "attention_mask": encoded["attention_mask"],
@@ -90,7 +89,7 @@ class EmotionCollator:
90
 
91
 
92
  class TopicCollator:
93
- """Prepare batches for topic classification using the projection head."""
94
 
95
  def __init__(
96
  self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
@@ -99,11 +98,12 @@ class TopicCollator:
99
  self.encoder = dataset.encoder
100
  self.max_length = max_length
101
 
102
- def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]:
103
- texts = [example.text for example in batch]
104
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
105
  labels = torch.as_tensor(
106
- self.encoder.transform([example.topic for example in batch]), dtype=torch.long
 
107
  )
108
  return {
109
  "input_ids": encoded["input_ids"],
@@ -112,6 +112,9 @@ class TopicCollator:
112
  }
113
 
114
 
 
 
 
115
  def build_summarization_dataloader(
116
  dataset: SummarizationDataset,
117
  tokenizer: Tokenizer,
@@ -123,6 +126,7 @@ def build_summarization_dataloader(
123
  num_workers: int = 0,
124
  pin_memory: bool = False,
125
  ) -> DataLoader:
 
126
  collator = SummarizationCollator(
127
  tokenizer,
128
  max_source_length=max_source_length,
@@ -135,6 +139,7 @@ def build_summarization_dataloader(
135
  collate_fn=collator,
136
  num_workers=num_workers,
137
  pin_memory=pin_memory,
 
138
  )
139
 
140
 
@@ -148,6 +153,7 @@ def build_emotion_dataloader(
148
  num_workers: int = 0,
149
  pin_memory: bool = False,
150
  ) -> DataLoader:
 
151
  collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
152
  return DataLoader(
153
  dataset,
@@ -156,6 +162,7 @@ def build_emotion_dataloader(
156
  collate_fn=collator,
157
  num_workers=num_workers,
158
  pin_memory=pin_memory,
 
159
  )
160
 
161
 
@@ -169,6 +176,7 @@ def build_topic_dataloader(
169
  num_workers: int = 0,
170
  pin_memory: bool = False,
171
  ) -> DataLoader:
 
172
  collator = TopicCollator(tokenizer, dataset, max_length=max_length)
173
  return DataLoader(
174
  dataset,
@@ -177,4 +185,5 @@ def build_topic_dataloader(
177
  collate_fn=collator,
178
  num_workers=num_workers,
179
  pin_memory=pin_memory,
 
180
  )
 
1
+ """
2
+ DataLoader builders for LexiMind.
3
+
4
+ Task-specific collators and factory functions for summarization, emotion, and topic.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from __future__ import annotations
11
 
12
+ from typing import Dict, List
13
 
14
  import torch
15
  from torch.utils.data import DataLoader
 
24
  )
25
  from .tokenization import Tokenizer
26
 
27
+ # --------------- Collators ---------------
28
+
29
 
30
  class SummarizationCollator:
31
+ """Prepare encoder-decoder batches for seq2seq summarization."""
32
 
33
  def __init__(
34
  self,
 
41
  self.max_source_length = max_source_length
42
  self.max_target_length = max_target_length
43
 
44
+ def __call__(self, batch: List[SummarizationExample]) -> Dict[str, torch.Tensor]:
45
+ sources = [ex.source for ex in batch]
46
+ targets = [ex.summary for ex in batch]
47
 
48
+ src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
49
+ tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
50
 
51
+ # Shift targets: tgt_ids = [BOS, A, B], labels = [A, B, EOS]
52
+ ids = tgt_enc["input_ids"]
53
+ mask = tgt_enc["attention_mask"]
 
54
 
 
 
 
 
 
55
  tgt_ids = ids[:, :-1]
 
 
56
  labels = ids[:, 1:].clone()
57
+ labels[mask[:, 1:] == 0] = -100 # Mask padding in loss
 
 
 
 
58
 
59
  return {
60
+ "src_ids": src_enc["input_ids"],
61
+ "src_mask": src_enc["attention_mask"],
62
  "tgt_ids": tgt_ids,
63
  "labels": labels,
64
  }
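The shift above is standard teacher forcing. A standalone toy sketch (independent of the collator) showing how one padded target row becomes decoder inputs and loss labels, with padding mapped to -100 so cross-entropy ignores it:

import torch

# Toy target encoding: [BOS, A, B, EOS, PAD] plus its attention mask.
ids = torch.tensor([[1, 42, 43, 2, 0]])   # hypothetical ids: BOS=1, EOS=2, PAD=0
mask = torch.tensor([[1, 1, 1, 1, 0]])

tgt_ids = ids[:, :-1]            # [BOS, A, B, EOS] -> decoder input
labels = ids[:, 1:].clone()      # [A, B, EOS, PAD] -> loss target
labels[mask[:, 1:] == 0] = -100  # PAD positions ignored by cross-entropy

print(tgt_ids.tolist())  # [[1, 42, 43, 2]]
print(labels.tolist())   # [[42, 43, 2, -100]]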
 
74
  self.binarizer = dataset.binarizer
75
  self.max_length = max_length
76
 
77
+ def __call__(self, batch: List[EmotionExample]) -> Dict[str, torch.Tensor]:
78
+ texts = [ex.text for ex in batch]
79
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
80
+ labels = torch.as_tensor(
81
+ self.binarizer.transform([ex.emotions for ex in batch]),
82
+ dtype=torch.float32,
83
+ )
84
  return {
85
  "input_ids": encoded["input_ids"],
86
  "attention_mask": encoded["attention_mask"],
 
89
 
90
 
91
  class TopicCollator:
92
+ """Prepare batches for single-label topic classification."""
93
 
94
  def __init__(
95
  self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
 
98
  self.encoder = dataset.encoder
99
  self.max_length = max_length
100
 
101
+ def __call__(self, batch: List[TopicExample]) -> Dict[str, torch.Tensor]:
102
+ texts = [ex.text for ex in batch]
103
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
104
  labels = torch.as_tensor(
105
+ self.encoder.transform([ex.topic for ex in batch]),
106
+ dtype=torch.long,
107
  )
108
  return {
109
  "input_ids": encoded["input_ids"],
 
112
  }
113
 
114
 
115
+ # --------------- Factory Functions ---------------
116
+
117
+
118
  def build_summarization_dataloader(
119
  dataset: SummarizationDataset,
120
  tokenizer: Tokenizer,
 
126
  num_workers: int = 0,
127
  pin_memory: bool = False,
128
  ) -> DataLoader:
129
+ """Create dataloader for summarization task."""
130
  collator = SummarizationCollator(
131
  tokenizer,
132
  max_source_length=max_source_length,
 
139
  collate_fn=collator,
140
  num_workers=num_workers,
141
  pin_memory=pin_memory,
142
+ persistent_workers=num_workers > 0, # Keep workers alive between epochs
143
  )
144
 
145
 
 
153
  num_workers: int = 0,
154
  pin_memory: bool = False,
155
  ) -> DataLoader:
156
+ """Create dataloader for emotion classification task."""
157
  collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
158
  return DataLoader(
159
  dataset,
 
162
  collate_fn=collator,
163
  num_workers=num_workers,
164
  pin_memory=pin_memory,
165
+ persistent_workers=num_workers > 0,
166
  )
167
 
168
 
 
176
  num_workers: int = 0,
177
  pin_memory: bool = False,
178
  ) -> DataLoader:
179
+ """Create dataloader for topic classification task."""
180
  collator = TopicCollator(tokenizer, dataset, max_length=max_length)
181
  return DataLoader(
182
  dataset,
 
185
  collate_fn=collator,
186
  num_workers=num_workers,
187
  pin_memory=pin_memory,
188
+ persistent_workers=num_workers > 0,
189
  )
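All three builders pass persistent_workers=num_workers > 0. A generic sketch of the same pattern on a plain PyTorch DataLoader, independent of the LexiMind collators; the dataset and worker count are placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100).unsqueeze(1))
num_workers = 4

loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    # Only valid when num_workers > 0; keeps worker processes alive between
    # epochs instead of re-forking them each time the loader is iterated.
    persistent_workers=num_workers > 0,
)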
src/data/dataset.py CHANGED
@@ -1,4 +1,13 @@
1
- """Dataset definitions for the LexiMind multitask training pipeline."""
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Dataset definitions for the LexiMind multitask training pipeline.
3
+
4
+ Defines PyTorch Dataset classes and data loading utilities for summarization,
5
+ emotion classification, and topic classification tasks. Supports both JSON
6
+ array and JSONL file formats.
7
+
8
+ Author: Oliver Perrin
9
+ Date: December 2025
10
+ """
11
 
12
  from __future__ import annotations
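The docstring above mentions both JSON array and JSONL support. A sketch of a loader that accepts either layout, plus the record shapes implied by the example classes the collators consume; the exact on-disk schema is an assumption, not the canonical format.

import json
from pathlib import Path

def load_records(path: Path) -> list[dict]:
    """Sketch: accept either a JSON array file or JSONL (one object per line)."""
    text = path.read_text().strip()
    if text.startswith("["):
        return json.loads(text)  # JSON array
    return [json.loads(line) for line in text.splitlines() if line.strip()]

# Hypothetical record shapes, mirroring the example classes used by the collators:
# {"source": "...", "summary": "..."}        -> SummarizationExample
# {"text": "...", "emotions": ["joy", ...]}  -> EmotionExample
# {"text": "...", "topic": "sports"}         -> TopicExample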
13
 
src/data/preprocessing.py CHANGED
@@ -1,52 +1,64 @@
1
- """Text preprocessing utilities built around Hugging Face tokenizers."""
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass, replace
6
- from typing import Iterable, List, Sequence
7
 
8
  import torch
9
- from sklearn.base import BaseEstimator, TransformerMixin
10
 
11
  from .tokenization import Tokenizer, TokenizerConfig
12
 
 
13
 
14
- class BasicTextCleaner(BaseEstimator, TransformerMixin):
15
- """Minimal text cleaner following scikit-learn conventions."""
16
 
17
- def __init__(self, lowercase: bool = True, strip: bool = True) -> None:
 
 
 
18
  self.lowercase = lowercase
19
- self.strip = strip
20
 
21
- def fit(self, texts: Iterable[str], y: Iterable[str] | None = None):
22
- return self
 
 
 
 
 
 
 
 
23
 
24
- def transform(self, texts: Iterable[str]) -> List[str]:
25
- return [self._clean_text(text) for text in texts]
 
 
26
 
27
- def _clean_text(self, text: str) -> str:
28
- item = text.strip() if self.strip else text
29
- if self.lowercase:
30
- item = item.lower()
31
- return " ".join(item.split())
32
 
33
 
34
  @dataclass
35
  class Batch:
36
- """Bundle of tensors returned by the text preprocessor."""
37
 
38
  input_ids: torch.Tensor
39
  attention_mask: torch.Tensor
40
  lengths: List[int]
41
 
42
 
43
- class TextPreprocessor:
44
- """Coordinate lightweight text cleaning and tokenization.
45
 
46
- When supplying an already-initialized tokenizer instance, its configuration is left
47
- untouched. If a differing ``max_length`` is requested, a ``ValueError`` is raised to
48
- avoid mutating shared tokenizer state.
49
- """
50
 
51
  def __init__(
52
  self,
@@ -56,19 +68,10 @@ class TextPreprocessor:
56
  tokenizer_name: str = "google/flan-t5-base",
57
  max_length: int | None = None,
58
  lowercase: bool = True,
59
- remove_stopwords: bool = False,
60
- sklearn_transformer: TransformerMixin | None = None,
61
  ) -> None:
62
- self.cleaner = BasicTextCleaner(lowercase=lowercase, strip=True)
63
- self.lowercase = lowercase
64
- if remove_stopwords:
65
- raise ValueError(
66
- "Stop-word removal is not supported because it conflicts with subword tokenizers; "
67
- "clean the text externally before initializing TextPreprocessor."
68
- )
69
- self._stop_words = None
70
- self._sklearn_transformer = sklearn_transformer
71
 
 
72
  if tokenizer is None:
73
  cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name)
74
  if max_length is not None:
@@ -78,52 +81,33 @@ class TextPreprocessor:
78
  self.tokenizer = tokenizer
79
  if max_length is not None and max_length != tokenizer.config.max_length:
80
  raise ValueError(
81
- "Provided tokenizer config.max_length does not match requested max_length; "
82
- "initialise the tokenizer with desired settings before passing it in."
83
  )
84
 
85
  self.max_length = max_length or self.tokenizer.config.max_length
86
 
87
  def clean_text(self, text: str) -> str:
88
- item = self.cleaner.transform([text])[0]
89
- return self._normalize_tokens(item)
90
-
91
- def _normalize_tokens(self, text: str) -> str:
92
- """Apply token-level normalization and optional stop-word filtering."""
93
- # Note: Pre-tokenization word-splitting is incompatible with subword tokenizers.
94
- # Stop-word filtering should be done post-tokenization or not at all for transformers.
95
- return text
96
-
97
- def _apply_sklearn_transform(self, texts: List[str]) -> List[str]:
98
- if self._sklearn_transformer is None:
99
- return texts
100
-
101
- transform = getattr(self._sklearn_transformer, "transform", None)
102
- if transform is None:
103
- raise AttributeError("Provided sklearn transformer must implement a 'transform' method")
104
- transformed = transform(texts)
105
- if isinstance(transformed, list):
106
- return transformed # assume downstream type is already list[str]
107
- if hasattr(transformed, "tolist"):
108
- transformed = transformed.tolist()
109
-
110
- result = list(transformed)
111
- if not all(isinstance(item, str) for item in result):
112
- result = [str(item) for item in result]
113
- return result
114
-
115
- def _prepare_texts(self, texts: Sequence[str]) -> List[str]:
116
- cleaned = self.cleaner.transform(texts)
117
- normalized = [self._normalize_tokens(text) for text in cleaned]
118
- return self._apply_sklearn_transform(normalized)
119
 
120
  def batch_encode(self, texts: Sequence[str]) -> Batch:
121
- cleaned = self._prepare_texts(texts)
 
122
  encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length)
123
- input_ids: torch.Tensor = encoded["input_ids"]
124
- attention_mask: torch.Tensor = encoded["attention_mask"].to(dtype=torch.bool)
 
125
  lengths = attention_mask.sum(dim=1).tolist()
 
126
  return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
127
 
128
  def __call__(self, texts: Sequence[str]) -> Batch:
 
129
  return self.batch_encode(texts)
 
 
 
 
 
 
 
1
+ """
2
+ Text preprocessing for LexiMind.
3
+
4
+ Lightweight text cleaning and tokenization pipeline for model input preparation.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from __future__ import annotations
11
 
12
  from dataclasses import dataclass, replace
13
+ from typing import List, Sequence
14
 
15
  import torch
 
16
 
17
  from .tokenization import Tokenizer, TokenizerConfig
18
 
19
+ # --------------- Text Cleaning ---------------
20
 
 
 
21
 
22
+ class TextCleaner:
23
+ """Basic text normalization."""
24
+
25
+ def __init__(self, lowercase: bool = True) -> None:
26
  self.lowercase = lowercase
 
27
 
28
+ def clean(self, text: str) -> str:
29
+ """Strip, normalize whitespace, optionally lowercase."""
30
+ text = text.strip()
31
+ if self.lowercase:
32
+ text = text.lower()
33
+ return " ".join(text.split())
34
+
35
+ def clean_batch(self, texts: Sequence[str]) -> List[str]:
36
+ """Clean multiple texts."""
37
+ return [self.clean(t) for t in texts]
38
 
39
+ # Backwards compatibility alias
40
+ def transform(self, texts: Sequence[str]) -> List[str]:
41
+ """Alias for clean_batch (sklearn-style interface)."""
42
+ return self.clean_batch(texts)
43
 
44
+
45
+ # --------------- Batch Output ---------------
 
 
 
46
 
47
 
48
  @dataclass
49
  class Batch:
50
+ """Tokenized batch ready for model consumption."""
51
 
52
  input_ids: torch.Tensor
53
  attention_mask: torch.Tensor
54
  lengths: List[int]
55
 
56
 
57
+ # --------------- Preprocessor ---------------
 
58
 
59
+
60
+ class TextPreprocessor:
61
+ """Combines text cleaning with tokenization."""
 
62
 
63
  def __init__(
64
  self,
 
68
  tokenizer_name: str = "google/flan-t5-base",
69
  max_length: int | None = None,
70
  lowercase: bool = True,
 
 
71
  ) -> None:
72
+ self.cleaner = TextCleaner(lowercase=lowercase)
 
 
 
 
 
 
 
 
73
 
74
+ # Initialize or validate tokenizer
75
  if tokenizer is None:
76
  cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name)
77
  if max_length is not None:
 
81
  self.tokenizer = tokenizer
82
  if max_length is not None and max_length != tokenizer.config.max_length:
83
  raise ValueError(
84
+ "max_length conflicts with tokenizer config - "
85
+ "initialize tokenizer with desired settings"
86
  )
87
 
88
  self.max_length = max_length or self.tokenizer.config.max_length
89
 
90
  def clean_text(self, text: str) -> str:
91
+ """Clean a single text."""
92
+ return self.cleaner.clean(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def batch_encode(self, texts: Sequence[str]) -> Batch:
95
+ """Clean and tokenize texts into a batch."""
96
+ cleaned = self.cleaner.clean_batch(texts)
97
  encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length)
98
+
99
+ input_ids = encoded["input_ids"]
100
+ attention_mask = encoded["attention_mask"].to(dtype=torch.bool)
101
  lengths = attention_mask.sum(dim=1).tolist()
102
+
103
  return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
104
 
105
  def __call__(self, texts: Sequence[str]) -> Batch:
106
+ """Alias for batch_encode."""
107
  return self.batch_encode(texts)
108
+
109
+
110
+ # --------------- Backwards Compatibility ---------------
111
+
112
+ # Keep old name for any imports
113
+ BasicTextCleaner = TextCleaner
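Quick usage of the cleaner above, based only on the methods shown (strip, whitespace collapse, optional lowercasing); the import path is assumed from the repo layout.

from src.data.preprocessing import TextCleaner  # path assumed; adjust to your packaging

cleaner = TextCleaner(lowercase=True)

print(cleaner.clean("  Hello   WORLD \n"))           # "hello world"
print(cleaner.clean_batch(["  A  b ", "C\td"]))      # ["a b", "c d"]
print(cleaner.transform(["Same as clean_batch"]))    # sklearn-style alias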
src/data/tokenization.py CHANGED
@@ -1,4 +1,13 @@
1
- """Tokenizer wrapper around HuggingFace models used across LexiMind."""
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Tokenizer facade for LexiMind.
3
+
4
+ Wraps HuggingFace tokenizers with a simplified interface that handles
5
+ special token management, batch encoding, and T5-specific conventions
6
+ for decoder input preparation.
7
+
8
+ Author: Oliver Perrin
9
+ Date: December 2025
10
+ """
11
 
12
  from __future__ import annotations
13
 
src/inference/factory.py CHANGED
@@ -1,4 +1,12 @@
1
- """Helpers to assemble an inference pipeline from saved artifacts."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Inference pipeline factory for LexiMind.
3
+
4
+ Assembles a complete inference pipeline from saved checkpoints, tokenizer
5
+ artifacts, and label metadata. Handles model loading and configuration.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
src/inference/pipeline.py CHANGED
@@ -1,9 +1,17 @@
1
- """Inference helpers for multitask LexiMind models."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass, fields, replace
6
- from typing import Any, Iterable, List, Sequence, cast
7
 
8
  import torch
9
  import torch.nn.functional as F
@@ -11,10 +19,12 @@ import torch.nn.functional as F
11
  from ..data.preprocessing import Batch, TextPreprocessor
12
  from ..data.tokenization import Tokenizer
13
 
 
 
14
 
15
  @dataclass
16
  class InferenceConfig:
17
- """Configuration knobs for the inference pipeline."""
18
 
19
  summary_max_length: int = 128
20
  emotion_threshold: float = 0.5
@@ -33,8 +43,11 @@ class TopicPrediction:
33
  confidence: float
34
 
35
 
 
 
 
36
  class InferencePipeline:
37
- """Run summarization, emotion, and topic heads through a unified interface."""
38
 
39
  def __init__(
40
  self,
@@ -50,50 +63,49 @@ class InferencePipeline:
50
  self.model = model
51
  self.tokenizer = tokenizer
52
  self.config = config or InferenceConfig()
53
- chosen_device = device or self.config.device
54
- if chosen_device is None:
55
- first_param = next(model.parameters(), None)
56
- chosen_device = first_param.device if first_param is not None else "cpu"
57
- self.device = torch.device(chosen_device)
 
 
 
58
  self.model.to(self.device)
59
  self.model.eval()
60
 
61
  self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
62
- self.emotion_labels = list(emotion_labels) if emotion_labels is not None else None
63
- self.topic_labels = list(topic_labels) if topic_labels is not None else None
 
 
64
 
65
  def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
 
66
  if not texts:
67
  return []
68
- batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
 
69
  src_ids = batch.input_ids
70
  src_mask = batch.attention_mask
71
  max_len = max_length or self.config.summary_max_length
72
 
73
- if not hasattr(self.model, "encoder") or not hasattr(self.model, "decoder"):
74
- raise RuntimeError(
75
- "Model must expose encoder and decoder attributes for summarization."
76
- )
77
-
78
- # Cast to Any to allow access to dynamic attributes encoder and decoder
79
  model = cast(Any, self.model)
 
 
80
 
81
  with torch.inference_mode():
82
- encoder_mask = (
 
83
  src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
84
  )
85
- memory = model.encoder(src_ids, mask=encoder_mask)
86
- min_len = 10
87
 
88
- # Ban BOS, PAD, UNK from being generated
89
- ban_token_ids = [
90
- self.tokenizer.bos_token_id,
91
- self.tokenizer.pad_token_id,
92
- ]
93
- unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
94
- if isinstance(unk_id, int):
95
- ban_token_ids.append(unk_id)
96
- ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
97
 
98
  generated = model.decoder.greedy_decode(
99
  memory=memory,
@@ -101,16 +113,15 @@ class InferencePipeline:
101
  start_token_id=self.tokenizer.bos_token_id,
102
  end_token_id=self.tokenizer.eos_token_id,
103
  device=self.device,
104
- min_len=min_len,
105
- ban_token_ids=ban_token_ids,
106
  no_repeat_ngram_size=3,
107
  memory_mask=src_mask,
108
  )
109
 
110
- decoded_list = self.tokenizer.decode_batch(generated.tolist())
111
- final_summaries = decoded_list
112
 
113
- return final_summaries
114
 
115
  def predict_emotions(
116
  self,
@@ -118,78 +129,91 @@ class InferencePipeline:
118
  *,
119
  threshold: float | None = None,
120
  ) -> List[EmotionPrediction]:
 
121
  if not texts:
122
  return []
123
- if self.emotion_labels is None or not self.emotion_labels:
124
- raise RuntimeError("emotion_labels must be provided to decode emotion predictions")
125
 
126
- batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
127
- model_inputs = self._batch_to_model_inputs(batch)
128
- decision_threshold = threshold or self.config.emotion_threshold
129
 
130
  with torch.inference_mode():
131
- logits = self.model.forward("emotion", model_inputs)
132
  probs = torch.sigmoid(logits)
133
 
134
- predictions: List[EmotionPrediction] = []
135
  for row in probs.cpu():
136
  pairs = [
137
  (label, score)
138
  for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
139
- if score >= decision_threshold
140
  ]
141
- labels = [label for label, _ in pairs]
142
- scores = [score for _, score in pairs]
143
- predictions.append(EmotionPrediction(labels=labels, scores=scores))
144
- return predictions
 
 
 
 
 
145
 
146
  def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
 
147
  if not texts:
148
  return []
149
- if self.topic_labels is None or not self.topic_labels:
150
- raise RuntimeError("topic_labels must be provided to decode topic predictions")
151
 
152
- batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
153
- model_inputs = self._batch_to_model_inputs(batch)
154
 
155
  with torch.inference_mode():
156
- logits = self.model.forward("topic", model_inputs)
157
  probs = F.softmax(logits, dim=-1)
158
 
159
- results: List[TopicPrediction] = []
160
  for row in probs.cpu():
161
- scores = row.tolist()
162
- best_index = int(row.argmax().item())
163
  results.append(
164
- TopicPrediction(label=self.topic_labels[best_index], confidence=scores[best_index])
 
 
 
165
  )
166
  return results
167
 
168
- def batch_predict(self, texts: Iterable[str]) -> dict[str, object]:
 
 
 
 
 
 
169
  text_list = list(texts)
170
- if self.emotion_labels is None or not self.emotion_labels:
171
- raise RuntimeError("emotion_labels must be provided for batch predictions")
172
- if self.topic_labels is None or not self.topic_labels:
173
- raise RuntimeError("topic_labels must be provided for batch predictions")
174
  return {
175
  "summaries": self.summarize(text_list),
176
  "emotion": self.predict_emotions(text_list),
177
  "topic": self.predict_topics(text_list),
178
  }
179
 
180
- def _batch_to_device(self, batch: Batch) -> Batch:
181
- tensor_updates: dict[str, torch.Tensor] = {}
182
- for item in fields(batch):
183
- value = getattr(batch, item.name)
184
- if torch.is_tensor(value):
185
- tensor_updates[item.name] = value.to(self.device)
186
- if not tensor_updates:
187
- return batch
188
- return replace(batch, **tensor_updates)
 
189
 
190
  @staticmethod
191
- def _batch_to_model_inputs(batch: Batch) -> dict[str, torch.Tensor]:
192
- inputs: dict[str, torch.Tensor] = {"input_ids": batch.input_ids}
 
193
  if batch.attention_mask is not None:
194
  inputs["attention_mask"] = batch.attention_mask
195
  return inputs
 
1
+ """
2
+ Inference pipeline for LexiMind.
3
+
4
+ Unified interface for summarization, emotion detection, and topic classification
5
+ with batched processing and device management.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
13
  from dataclasses import dataclass, fields, replace
14
+ from typing import Any, Dict, List, Sequence, cast
15
 
16
  import torch
17
  import torch.nn.functional as F
 
19
  from ..data.preprocessing import Batch, TextPreprocessor
20
  from ..data.tokenization import Tokenizer
21
 
22
+ # --------------- Configuration ---------------
23
+
24
 
25
  @dataclass
26
  class InferenceConfig:
27
+ """Pipeline settings."""
28
 
29
  summary_max_length: int = 128
30
  emotion_threshold: float = 0.5
 
43
  confidence: float
44
 
45
 
46
+ # --------------- Pipeline ---------------
47
+
48
+
49
  class InferencePipeline:
50
+ """Multi-task inference with batched processing."""
51
 
52
  def __init__(
53
  self,
 
63
  self.model = model
64
  self.tokenizer = tokenizer
65
  self.config = config or InferenceConfig()
66
+
67
+ # Resolve device
68
+ chosen = device or self.config.device
69
+ if chosen is None:
70
+ param = next(model.parameters(), None)
71
+ chosen = param.device if param is not None else "cpu"
72
+ self.device = torch.device(chosen)
73
+
74
  self.model.to(self.device)
75
  self.model.eval()
76
 
77
  self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
78
+ self.emotion_labels = list(emotion_labels) if emotion_labels else None
79
+ self.topic_labels = list(topic_labels) if topic_labels else None
80
+
81
+ # --------------- Summarization ---------------
82
 
83
  def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
84
+ """Generate summaries for input texts."""
85
  if not texts:
86
  return []
87
+
88
+ batch = self._to_device(self.preprocessor.batch_encode(texts))
89
  src_ids = batch.input_ids
90
  src_mask = batch.attention_mask
91
  max_len = max_length or self.config.summary_max_length
92
 
 
 
 
 
 
 
93
  model = cast(Any, self.model)
94
+ if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
95
+ raise RuntimeError("Model must have encoder and decoder for summarization")
96
 
97
  with torch.inference_mode():
98
+ # Encode
99
+ enc_mask = (
100
  src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
101
  )
102
+ memory = model.encoder(src_ids, mask=enc_mask)
 
103
 
104
+ # Decode with constraints to improve quality
105
+ ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
106
+ unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
107
+ if isinstance(unk, int):
108
+ ban_ids.append(unk)
 
 
 
 
109
 
110
  generated = model.decoder.greedy_decode(
111
  memory=memory,
 
113
  start_token_id=self.tokenizer.bos_token_id,
114
  end_token_id=self.tokenizer.eos_token_id,
115
  device=self.device,
116
+ min_len=10,
117
+ ban_token_ids=[i for i in ban_ids if i is not None],
118
  no_repeat_ngram_size=3,
119
  memory_mask=src_mask,
120
  )
121
 
122
+ return self.tokenizer.decode_batch(generated.tolist())
 
123
 
124
+ # --------------- Emotion ---------------
125
 
126
  def predict_emotions(
127
  self,
 
129
  *,
130
  threshold: float | None = None,
131
  ) -> List[EmotionPrediction]:
132
+ """Predict emotions for input texts."""
133
  if not texts:
134
  return []
135
+ if not self.emotion_labels:
136
+ raise RuntimeError("emotion_labels required for emotion prediction")
137
 
138
+ batch = self._to_device(self.preprocessor.batch_encode(texts))
139
+ inputs = self._model_inputs(batch)
140
+ thresh = threshold or self.config.emotion_threshold
141
 
142
  with torch.inference_mode():
143
+ logits = self.model.forward("emotion", inputs)
144
  probs = torch.sigmoid(logits)
145
 
146
+ results = []
147
  for row in probs.cpu():
148
  pairs = [
149
  (label, score)
150
  for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
151
+ if score >= thresh
152
  ]
153
+ results.append(
154
+ EmotionPrediction(
155
+ labels=[label for label, _ in pairs],
156
+ scores=[score for _, score in pairs],
157
+ )
158
+ )
159
+ return results
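The emotion head is multi-label: a sigmoid per class, keeping every label at or above the threshold. A toy sketch of that decode step with a made-up label set and raw logits:

import torch

emotion_labels = ["joy", "sadness", "anger", "surprise"]  # illustrative label set
logits = torch.tensor([[2.1, -3.0, -0.8, 1.4]])           # one example's raw head output
threshold = 0.5

probs = torch.sigmoid(logits)[0]
kept = [(lbl, p.item()) for lbl, p in zip(emotion_labels, probs) if p >= threshold]
print(kept)  # [("joy", 0.891...), ("surprise", 0.802...)] -- the other two fall below 0.5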
160
+
161
+ # --------------- Topic ---------------
162
 
163
  def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
164
+ """Predict topic for input texts."""
165
  if not texts:
166
  return []
167
+ if not self.topic_labels:
168
+ raise RuntimeError("topic_labels required for topic prediction")
169
 
170
+ batch = self._to_device(self.preprocessor.batch_encode(texts))
171
+ inputs = self._model_inputs(batch)
172
 
173
  with torch.inference_mode():
174
+ logits = self.model.forward("topic", inputs)
175
  probs = F.softmax(logits, dim=-1)
176
 
177
+ results = []
178
  for row in probs.cpu():
179
+ idx = int(row.argmax().item())
 
180
  results.append(
181
+ TopicPrediction(
182
+ label=self.topic_labels[idx],
183
+ confidence=row[idx].item(),
184
+ )
185
  )
186
  return results
187
 
188
+ # --------------- Batch Prediction ---------------
189
+
190
+ def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
191
+ """Run all three tasks on input texts."""
192
+ if not self.emotion_labels or not self.topic_labels:
193
+ raise RuntimeError("Both emotion_labels and topic_labels required")
194
+
195
  text_list = list(texts)
 
 
 
 
196
  return {
197
  "summaries": self.summarize(text_list),
198
  "emotion": self.predict_emotions(text_list),
199
  "topic": self.predict_topics(text_list),
200
  }
201
 
202
+ # --------------- Helpers ---------------
203
+
204
+ def _to_device(self, batch: Batch) -> Batch:
205
+ """Move batch tensors to device with non_blocking for speed."""
206
+ updates = {}
207
+ for f in fields(batch):
208
+ val = getattr(batch, f.name)
209
+ if torch.is_tensor(val):
210
+ updates[f.name] = val.to(self.device, non_blocking=True)
211
+ return replace(batch, **updates) if updates else batch
212
 
213
  @staticmethod
214
+ def _model_inputs(batch: Batch) -> Dict[str, torch.Tensor]:
215
+ """Extract model inputs from batch."""
216
+ inputs = {"input_ids": batch.input_ids}
217
  if batch.attention_mask is not None:
218
  inputs["attention_mask"] = batch.attention_mask
219
  return inputs
src/inference/postprocessing.py CHANGED
@@ -1,4 +1,11 @@
1
- """Output cleaning helpers."""
 
 
 
 
 
 
 
2
 
3
  from typing import List
4
 
 
1
+ """
2
+ Output postprocessing utilities for LexiMind.
3
+
4
+ Provides text cleaning helpers for model outputs.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from typing import List
11
 
src/models/decoder.py CHANGED
@@ -1,16 +1,17 @@
1
- """
2
- Transformer Decoder (Pre-LN) - implementation.
3
-
4
- Implements:
5
- - create_causal_mask
6
- - TransformerDecoderLayer
7
- - TransformerDecoder (stack + naive greedy decoding)
8
-
9
- Conventions:
10
- - Masks are boolean: True = allowed, False = masked.
11
- - MultiHeadAttention expects masks broadcastable to (B, num_heads, T_q, T_k).
12
- - This decoder uses Pre-LN (RMSNorm before each sublayer).
13
- - RMSNorm is just simpler than LayerNorm and more computationally efficient, it's become the modern convention. These reasons are why I used it here.
 
14
  """
15
 
16
  from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
1
+ """Transformer Decoder implementation (Pre-LN).
2
+
3
+ This module implements the decoder component of the Transformer architecture:
4
+ - create_causal_mask: Generate causal attention masks
5
+ - TransformerDecoderLayer: Single decoder block with self-attn + cross-attn + FFN
6
+ - TransformerDecoder: Full stack with embeddings, positional encoding, and generation
7
+
8
+ Design notes:
9
+ - Pre-LN with RMSNorm for training stability
10
+ - Masks are boolean: True = attend, False = mask
11
+ - Supports T5-style relative position bias
12
+
13
+ Author: Oliver Perrin
14
+ Date: 2025-10-23
15
  """
16
 
17
  from typing import Any, Dict, List, Literal, Optional, Tuple, Union
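The docstring fixes the mask convention (boolean, True = attend) and mentions causal masks. A minimal sketch of a causal mask under that convention; the project's create_causal_mask signature is not shown here, so this illustrates the convention rather than the exact helper.

import torch

def causal_mask(size: int) -> torch.Tensor:
    # (size, size) boolean mask, True where attention is allowed:
    # position i may attend to positions j <= i (lower triangle).
    return torch.tril(torch.ones(size, size, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])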
src/models/encoder.py CHANGED
@@ -1,17 +1,16 @@
1
- """
2
- Transformer encoder implementation (Pre-LN).
3
-
4
- Contains:
5
- - TransformerEncoderLayer: one encoder block (self-attention + FFN with residuals + LayerNorm (RMSNorm - modern convention))
6
- - TransformerEncoder: embedding + positional encoding + stack of encoder layers
7
-
8
- Design choices:
9
- - Pre-LN (RMSNorm before each sublayer) for stable training.
10
- - The FeedForward module is position-wise and does NOT include residuals or normalization.
11
- - MultiHeadAttention handles mask broadcasting from (B, S, S) -> (B, 1, S, S) internally.
12
- - The encoder accepts either token ids (LongTensor) or precomputed embeddings (FloatTensor).
13
- If you pass token ids, provide vocab_size when constructing the encoder and optionally pad_token_id.
14
- - Optionally collect attention weights by passing collect_attn=True to forward().
15
  """
16
 
17
  from typing import List, Literal, Optional, Tuple, Union
@@ -213,9 +212,9 @@ class TransformerEncoder(nn.Module):
213
  Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
214
  True indicates valid positions; False indicates masked (pad).
215
  """
216
- assert (
217
- self.pad_token_id is not None
218
- ), "pad_token_id must be set to build padding mask from ids."
219
  # mask shape: (batch, seq) where True = token kept (non-pad)
220
  pad_mask = input_ids != self.pad_token_id
221
  # Convert to (batch, seq_q, seq_k) by outer product broadcasting
 
1
+ """Transformer Encoder implementation (Pre-LN).
2
+
3
+ This module implements the encoder component of the Transformer architecture:
4
+ - TransformerEncoderLayer: Single encoder block with self-attention + FFN
5
+ - TransformerEncoder: Full stack with embeddings and positional encoding
6
+
7
+ Design notes:
8
+ - Pre-LN with RMSNorm for training stability
9
+ - Masks are boolean: True = attend, False = mask
10
+ - Supports T5-style relative position bias
11
+
12
+ Author: Oliver Perrin
13
+ Date: 2025-10-23
 
14
  """
15
 
16
  from typing import List, Literal, Optional, Tuple, Union
 
212
  Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
213
  True indicates valid positions; False indicates masked (pad).
214
  """
215
+ assert self.pad_token_id is not None, (
216
+ "pad_token_id must be set to build padding mask from ids."
217
+ )
218
  # mask shape: (batch, seq) where True = token kept (non-pad)
219
  pad_mask = input_ids != self.pad_token_id
220
  # Convert to (batch, seq_q, seq_k) by outer product broadcasting
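The outer-product broadcast mentioned above, as a standalone sketch: a (batch, seq) keep-mask expanded to the (batch, seq_q, seq_k) shape the attention layers expect.

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 9, 7, 0, 0]])  # last two positions are padding

pad_mask = input_ids != pad_token_id                        # (batch, seq), True = real token
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)   # (batch, seq_q, seq_k)

print(pad_mask.shape, attn_mask.shape)  # torch.Size([1, 5]) torch.Size([1, 5, 5])
print(attn_mask[0, 0])                  # tensor([ True,  True,  True, False, False])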
src/models/factory.py CHANGED
@@ -1,4 +1,14 @@
1
- """Factory helpers to assemble multitask models for inference/training."""
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """Factory helpers to assemble multitask models.
2
+
3
+ This module provides model construction and weight loading utilities:
4
+ - ModelConfig: Dataclass for architecture hyperparameters
5
+ - load_model_config: Load configuration from YAML
6
+ - build_multitask_model: Construct full model with task heads
7
+ - Weight loading: Transfer pretrained T5/FLAN-T5 or LLaMA weights
8
+
9
+ Author: Oliver Perrin
10
+ Date: 2025-10-23
11
+ """
12
 
13
  from __future__ import annotations
14
 
src/models/feedforward.py CHANGED
@@ -1,5 +1,11 @@
1
- """
2
- Position-wise Feed-Forward Network.
 
 
 
 
 
 
3
  """
4
 
5
  from typing import Literal, Optional
 
1
+ """Position-wise Feed-Forward Network.
2
+
3
+ This module implements the FFN sublayer used in Transformer blocks:
4
+ - Standard FFN: Two linear layers with activation (GELU/ReLU)
5
+ - Gated FFN: SwiGLU (LLaMA-style) or Gated-GELU (T5/FLAN-T5 style)
6
+
7
+ Author: Oliver Perrin
8
+ Date: 2025-10-23
9
  """
10
 
11
  from typing import Literal, Optional
src/models/heads.py CHANGED
@@ -1,13 +1,13 @@
1
- """
2
- Prediction heads for Transformer models.
3
 
4
- Includes:
5
- - ClassificationHead: sequence-level classification with simple pooling (mean/cls/max).
6
- - TokenClassificationHead: per-token classification (e.g., NER).
7
- - LMHead: language-modeling head mapping hidden states to vocabulary logits. Optional weight tying to an Embedding.
8
- - ProjectionHead: small projection MLP for representation learning / contrastive heads.
9
 
10
- Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
 
11
  """
12
 
13
  from typing import Literal, Optional
@@ -117,12 +117,12 @@ class LMHead(nn.Module):
117
 
118
  if tie_embedding is not None:
119
  # Validate sizes
120
- assert (
121
- tie_embedding.num_embeddings == vocab_size
122
- ), "vocab size mismatch for weight tying"
123
- assert (
124
- tie_embedding.embedding_dim == d_model
125
- ), "embedding dim must match d_model for weight tying"
126
  # Tie weights: point the projection weight to the embedding weight Tensor
127
  # Remove the existing projection parameter in favor of the embedding weight
128
  # This keeps the same Parameter object, so updates affect both modules.
 
1
+ """Prediction heads for Transformer models.
 
2
 
3
+ This module provides task-specific output heads:
4
+ - ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
5
+ - TokenClassificationHead: Per-token classification (NER, POS tagging)
6
+ - LMHead: Language modeling logits with optional weight tying
7
+ - ProjectionHead: MLP for representation learning / contrastive tasks
8
 
9
+ Author: Oliver Perrin
10
+ Date: 2025-10-23
11
  """
12
 
13
  from typing import Literal, Optional
 
117
 
118
  if tie_embedding is not None:
119
  # Validate sizes
120
+ assert tie_embedding.num_embeddings == vocab_size, (
121
+ "vocab size mismatch for weight tying"
122
+ )
123
+ assert tie_embedding.embedding_dim == d_model, (
124
+ "embedding dim must match d_model for weight tying"
125
+ )
126
  # Tie weights: point the projection weight to the embedding weight Tensor
127
  # Remove the existing projection parameter in favor of the embedding weight
128
  # This keeps the same Parameter object, so updates affect both modules.
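Weight tying in miniature: the LMHead points its output projection at the input embedding, guarded by the same size checks as the asserts above. A self-contained sketch with plain nn modules, not the project's LMHead itself.

import torch.nn as nn

vocab_size, d_model = 32128, 768
embedding = nn.Embedding(vocab_size, d_model)
lm_proj = nn.Linear(d_model, vocab_size, bias=False)

# Same checks as the asserts above.
assert embedding.num_embeddings == lm_proj.out_features
assert embedding.embedding_dim == lm_proj.in_features

# One tensor, two views: gradient updates through either module move the same weights.
lm_proj.weight = embedding.weight
assert lm_proj.weight is embedding.weight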
src/models/multitask.py CHANGED
@@ -1,18 +1,12 @@
1
- """
2
- Multitask model composition utilities.
3
-
4
- Provides:
5
- - MultiTaskModel: lightweight wrapper to compose an encoder and/or decoder with
6
- multiple task heads (classification, token classification, LM head, etc.)
7
- - add_head / remove_head helpers
8
- - forward(task_name, ...) that routes inputs to the correct sub-modules
9
- - compute_loss helper that uses common losses and ignore_index support
10
-
11
- Design goals:
12
- - Keep composition simple and explicit (use named heads per task)
13
- - Support encoder-only tasks (classification, token classification) and
14
- seq2seq tasks (encoder -> decoder -> LMHead)
15
- - Minimal dependencies on training loop; return logits and (optionally) loss
16
  """
17
 
18
  from typing import Any, Dict, Optional
 
1
+ """Multitask model composition utilities.
2
+
3
+ This module provides infrastructure for multi-task learning:
4
+ - MultiTaskModel: Compose encoder/decoder with multiple task heads
5
+ - Routing: forward(task_name, ...) dispatches to correct components
6
+ - Loss computation: Built-in cross-entropy with ignore_index support
7
+
8
+ Author: Oliver Perrin
9
+ Date: 2025-10-23
 
 
 
 
 
 
10
  """
11
 
12
  from typing import Any, Dict, Optional
src/models/positional_encoding.py CHANGED
@@ -1,10 +1,12 @@
1
- # src/models/positional_encoding.py
2
-
3
  """
4
  Positional Encoding for Transformer models.
5
 
6
- Injects information about the position of tokens in a sequence, since
7
- self-attention has no inherent notion of token order.
 
 
 
 
8
  """
9
 
10
  import math
 
 
 
1
  """
2
  Positional Encoding for Transformer models.
3
 
4
+ Provides sinusoidal position embeddings that inject sequential order information
5
+ into token representations. Required because self-attention is permutation-invariant
6
+ and has no inherent notion of token position.
7
+
8
+ Author: Oliver Perrin
9
+ Date: December 2025
10
  """
11
 
12
  import math
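The sinusoidal encoding the docstring describes follows the standard recipe (sin on even dimensions, cos on odd, with geometrically spaced frequencies). A sketch of that table; dimensions are placeholders and the project's class interface is not assumed.

import math
import torch

def sinusoidal_positions(max_len: int, d_model: int) -> torch.Tensor:
    """Standard sinusoidal table: even dims use sin, odd dims use cos."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )  # (d_model / 2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # added to token embeddings before the first encoder/decoder layer

print(sinusoidal_positions(4, 8).shape)  # torch.Size([4, 8])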
src/training/metrics.py CHANGED
@@ -1,4 +1,13 @@
1
- """Metric helpers used during training and evaluation."""
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Training and evaluation metrics for LexiMind.
3
+
4
+ Provides metric computation utilities for all task types: accuracy for topic
5
+ classification, multi-label F1 for emotion detection, and a ROUGE-like overlap
6
+ score for summarization quality assessment.
7
+
8
+ Author: Oliver Perrin
9
+ Date: December 2025
10
+ """
11
 
12
  from __future__ import annotations
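To make the emotion metric concrete, a toy micro-averaged multi-label F1 over a 0/1 prediction matrix; this is the generic definition, not necessarily the exact multilabel_f1 implementation in this module.

import torch

def micro_f1(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Micro-averaged F1 over a (batch, num_labels) 0/1 prediction matrix."""
    tp = (preds * labels).sum().item()
    fp = (preds * (1 - labels)).sum().item()
    fn = ((1 - preds) * labels).sum().item()
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

preds = torch.tensor([[1, 0, 1], [0, 1, 0]])
labels = torch.tensor([[1, 0, 0], [0, 1, 1]])
print(micro_f1(preds, labels))  # 2 TP, 1 FP, 1 FN -> P = R = 2/3, F1 ~= 0.667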
13
 
src/training/trainer.py CHANGED
@@ -1,38 +1,50 @@
1
- """Multi-task trainer coordinating summarization, emotion, and topic heads."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
- import shutil
6
  import time
7
  from collections import defaultdict
8
  from dataclasses import dataclass
9
- from typing import Callable, Dict, Iterator, List
10
 
11
  import mlflow
12
  import torch
13
  import torch.nn.functional as F
14
  from torch.utils.data import DataLoader
 
15
 
16
  from ..data.tokenization import Tokenizer
17
  from .metrics import accuracy, multilabel_f1, rouge_like
18
 
 
 
19
 
20
  @dataclass
21
  class TrainerConfig:
 
 
22
  max_epochs: int = 1
23
  gradient_clip_norm: float = 1.0
24
- logging_interval: int = 50
25
  task_weights: Dict[str, float] | None = None
26
  validation_samples: int = 3
27
  validation_max_length: int = 128
28
- label_smoothing: float = 0.0 # Label smoothing for regularization (e.g., 0.1)
29
  experiment_name: str = "LexiMind"
30
  run_name: str | None = None
31
  gradient_accumulation_steps: int = 1
32
 
33
 
 
34
  class Trainer:
35
- """Coordinates multi-task optimisation across task-specific dataloaders."""
36
 
37
  def __init__(
38
  self,
@@ -47,392 +59,315 @@ class Trainer:
47
  self.config = config
48
  self.device = device
49
  self.tokenizer = tokenizer
 
 
50
  self.emotion_loss = torch.nn.BCEWithLogitsLoss()
51
  self.topic_loss = torch.nn.CrossEntropyLoss()
52
- # Apply label smoothing to summarization task if configured
53
- self.label_smoothing = config.label_smoothing
54
- self._progress_last_len = 0
55
- self.gradient_accumulation_steps = max(1, config.gradient_accumulation_steps)
56
- self._nan_counter = 0 # Track consecutive NaNs
57
-
58
- # Mixed Precision Training
59
- # Initialize GradScaler for float16/bfloat16 training
60
- # This scales gradients to prevent underflow during backward pass
61
- # Note: bfloat16 generally doesn't need scaling, but we keep it for safety unless it causes NaNs
62
- self.scaler = torch.GradScaler("cuda", enabled=(device.type == "cuda"))
63
-
64
- # Initialize MLflow
65
  mlflow.set_experiment(config.experiment_name)
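The comments being removed above describe the autocast/GradScaler pattern: bfloat16 skips gradient scaling (its exponent range avoids underflow), while float16 needs the scaler. A generic sketch of that decision using the torch.amp API (PyTorch >= 2.3 naming assumed), not the trainer's exact loop.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

# GradScaler only matters for float16; bfloat16 has enough exponent range.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and not use_bf16))

def training_step(model, optimizer, batch, compute_loss):
    with torch.autocast("cuda", dtype=amp_dtype, enabled=(device.type == "cuda")):
        loss = compute_loss(model, batch)
    if use_bf16 or device.type != "cuda":
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    else:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    optimizer.zero_grad()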
66
 
 
 
 
 
 
 
 
67
  def fit(
68
  self,
69
  train_loaders: Dict[str, DataLoader],
70
  val_loaders: Dict[str, DataLoader] | None = None,
71
  checkpoint_callback: Callable | None = None,
72
  ) -> Dict[str, Dict[str, float]]:
73
- """Train the model.
74
-
75
- Args:
76
- train_loaders: Task-specific training dataloaders
77
- val_loaders: Optional task-specific validation dataloaders
78
- checkpoint_callback: Optional callback(epoch, model, history) to save checkpoints
79
-
80
- Returns:
81
- Training history dictionary
82
- """
83
  history: Dict[str, Dict[str, float]] = {}
84
- total_epochs = max(1, self.config.max_epochs)
85
- start_time = time.perf_counter()
86
 
87
  with mlflow.start_run(run_name=self.config.run_name):
88
- # Log configuration
89
- mlflow.log_params(
90
- {
91
- "max_epochs": self.config.max_epochs,
92
- "gradient_clip_norm": self.config.gradient_clip_norm,
93
- "label_smoothing": self.config.label_smoothing,
94
- "task_weights": str(self.config.task_weights),
95
- "device": str(self.device),
96
- }
97
  )
98
 
99
- for epoch in range(1, total_epochs + 1):
100
  epoch_start = time.perf_counter()
101
- train_metrics = self._run_epoch(
102
- train_loaders,
103
- train=True,
104
- epoch=epoch,
105
- total_epochs=total_epochs,
106
- epoch_start=epoch_start,
107
- global_start=start_time,
108
- )
109
- history[f"train_epoch_{epoch}"] = train_metrics
110
 
111
- # Log training metrics to MLflow
112
- for k, v in train_metrics.items():
113
- if k != "epoch":
114
- mlflow.log_metric(f"train_{k}", v, step=epoch)
115
 
 
116
  if val_loaders:
117
  val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
118
  history[f"val_epoch_{epoch}"] = val_metrics
 
119
 
120
- # Log validation metrics to MLflow
121
- for k, v in val_metrics.items():
122
- if k != "epoch":
123
- mlflow.log_metric(f"val_{k}", v, step=epoch)
124
-
125
- # Generate sample summaries for manual quality assessment
126
  if "summarization" in val_loaders:
127
  self._validate_generation(val_loaders["summarization"], epoch)
128
 
129
- # Save checkpoint after each epoch
130
- if checkpoint_callback is not None:
131
  checkpoint_callback(epoch, self.model, history)
132
 
133
- epoch_duration = time.perf_counter() - epoch_start
134
- total_elapsed = time.perf_counter() - start_time
135
- self._print_epoch_progress(epoch, total_epochs, epoch_duration, total_elapsed)
 
 
 
 
 
 
 
136
 
 
 
137
  return history
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def _run_epoch(
140
  self,
141
  loaders: Dict[str, DataLoader],
142
  *,
143
  train: bool,
144
  epoch: int,
145
- total_epochs: int | None = None,
146
- epoch_start: float | None = None,
147
- global_start: float | None = None,
148
  ) -> Dict[str, float]:
149
- phase = "train" if train else "eval"
 
150
  self.model.train(train)
151
- metrics_accumulator: Dict[str, list[float]] = defaultdict(list)
152
- iterator_map: Dict[str, Iterator[Dict[str, torch.Tensor]]] = {
153
- task: iter(loader) for task, loader in loaders.items()
154
- }
155
  max_batches = max(len(loader) for loader in loaders.values())
156
- progress_enabled = (
157
- train
158
- and max_batches > 0
159
- and total_epochs is not None
160
- and epoch_start is not None
161
- and global_start is not None
 
 
 
162
  )
163
 
164
- def emit_progress(step: int, final: bool = False) -> None:
165
- if not progress_enabled:
166
- return
167
- total_epochs_value = total_epochs
168
- epoch_start_value = epoch_start
169
- global_start_value = global_start
170
- assert total_epochs_value is not None
171
- assert epoch_start_value is not None
172
- assert global_start_value is not None
173
- self._update_epoch_progress(
174
- epoch=epoch,
175
- total_epochs=total_epochs_value,
176
- step=step,
177
- total_steps=max_batches,
178
- epoch_start=epoch_start_value,
179
- global_start=global_start_value,
180
- final=final,
181
- )
182
-
183
- emit_progress(0)
184
-
185
  context = torch.enable_grad() if train else torch.no_grad()
186
  with context:
187
- for step in range(max_batches):
188
- # Mark step begin for CUDA Graphs (inductor) to handle memory reuse correctly
189
- if (
190
- train
191
- and self.device.type == "cuda"
192
- and hasattr(torch.compiler, "cudagraph_mark_step_begin")
193
- ):
194
- torch.compiler.cudagraph_mark_step_begin()
195
-
196
- backward_performed = False
197
- step_total_loss = 0.0
198
-
199
- # Mixed Precision Context
200
- # Using bfloat16 for my RTX 4070 (Ampere/Ada) - better stability than float16
201
- # Disable scaler for bfloat16 to prevent NaNs
202
- use_bfloat16 = self.device.type == "cuda" and torch.cuda.is_bf16_supported()
203
 
204
  for task, loader in loaders.items():
205
- batch = self._next_batch(iterator_map, loader, task)
206
  if batch is None:
207
  continue
208
 
209
- with torch.autocast(
210
- "cuda",
211
- dtype=torch.bfloat16 if use_bfloat16 else torch.float16,
212
- enabled=(self.device.type == "cuda"),
213
- ):
214
- loss, task_metrics = self._forward_task(task, batch, train)
215
 
 
216
  if torch.isnan(loss):
217
- if train:
218
- self._nan_counter += 1
219
- print(
220
- f"Warning: NaN loss detected for task '{task}'. Skipping update for this task. (Consecutive NaNs: {self._nan_counter})"
221
- )
222
- if self._nan_counter > 10:
223
- raise RuntimeError(
224
- "Too many consecutive NaN losses. Training is diverging."
225
- )
226
  continue
227
- else:
228
- if train:
229
- self._nan_counter = 0
230
-
231
- weight = self._task_weight(task)
232
- # Scale loss by gradient accumulation steps
233
- weighted_loss = (loss * weight) / self.gradient_accumulation_steps
234
- step_total_loss += weighted_loss.item() * self.gradient_accumulation_steps
235
 
236
- metrics_accumulator[f"{task}_loss"].append(loss.item())
237
- for metric_name, metric_value in task_metrics.items():
238
- metrics_accumulator[f"{task}_{metric_name}"].append(metric_value)
 
239
 
 
240
  if train:
241
- # Scale loss before backward to prevent underflow
242
- # We accumulate gradients from all tasks before stepping the optimizer
243
- # This effectively minimizes the weighted sum of losses: L_total = w1*L1 + w2*L2 + ...
244
- if use_bfloat16:
245
- # bfloat16 doesn't need scaling and it can cause NaNs
246
- weighted_loss.backward()
247
  else:
248
- self.scaler.scale(weighted_loss).backward()
249
- backward_performed = True
250
-
251
- if backward_performed:
252
- metrics_accumulator["total_loss"].append(step_total_loss)
253
-
254
- # Perform optimizer step only after accumulating enough gradients
255
- if (
256
- train
257
- and backward_performed
258
- and (step + 1) % self.gradient_accumulation_steps == 0
259
- ):
260
- # Unscale gradients before clipping
261
- if use_bfloat16:
262
- torch.nn.utils.clip_grad_norm_(
263
- self.model.parameters(), self.config.gradient_clip_norm
264
- )
265
- self.optimizer.step()
266
- self.optimizer.zero_grad()
267
- else:
268
- self.scaler.unscale_(self.optimizer)
269
- torch.nn.utils.clip_grad_norm_(
270
- self.model.parameters(), self.config.gradient_clip_norm
271
- )
272
-
273
- # Step optimizer using scaler
274
- self.scaler.step(self.optimizer)
275
- self.scaler.update()
276
- self.optimizer.zero_grad()
277
-
278
- if (
279
- train
280
- and self.config.logging_interval
281
- and (step + 1) % self.config.logging_interval == 0
282
- ):
283
- if torch.cuda.is_available() and self.device.type == "cuda":
284
- torch.cuda.empty_cache()
285
- emit_progress(step + 1)
286
- emit_progress(max_batches, final=True)
287
-
288
- averaged = {
289
- name: sum(values) / len(values)
290
- for name, values in metrics_accumulator.items()
291
- if values
292
- }
293
  averaged["epoch"] = float(epoch)
294
- metric_str = ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch")
295
- print(f"[{phase}] epoch {epoch}: {metric_str}")
 
 
 
296
  return averaged
297
 
298
- def _next_batch(
299
- self,
300
- iterator_map: Dict[str, Iterator[Dict[str, torch.Tensor]]],
301
- loader: DataLoader,
302
- task: str,
 
 
 
 
 
 
 
 
 
303
  ) -> Dict[str, torch.Tensor] | None:
 
304
  try:
305
- batch = next(iterator_map[task])
306
  except StopIteration:
307
- iterator_map[task] = iter(loader)
308
  try:
309
- batch = next(iterator_map[task])
310
  except StopIteration:
311
  return None
312
  return {
313
- key: value.to(self.device) if isinstance(value, torch.Tensor) else value
314
- for key, value in batch.items()
315
  }
316
 
 
 
317
  def _forward_task(
318
- self, task: str, batch: Dict[str, torch.Tensor], train: bool
319
  ) -> tuple[torch.Tensor, Dict[str, float]]:
 
320
  if task == "summarization":
321
- summarization_inputs = {
322
- "src_ids": batch["src_ids"],
323
- "tgt_ids": batch["tgt_ids"],
324
- }
325
- if "src_mask" in batch:
326
- summarization_inputs["src_mask"] = batch["src_mask"]
327
- logits = self.model.forward("summarization", summarization_inputs)
328
- vocab_size = logits.size(-1)
329
- # Apply label smoothing for regularization - prevents overconfident predictions
330
- loss = F.cross_entropy(
331
- logits.view(-1, vocab_size),
332
- batch["labels"].view(-1),
333
- ignore_index=-100,
334
- label_smoothing=self.label_smoothing,
335
- )
336
- summaries = self._decode_predictions(logits)
337
- references = self._decode_labels(batch["labels"])
338
- rouge = rouge_like(summaries, references)
339
- return loss, {"rouge_like": rouge}
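The summarization branch above computes token-level cross-entropy with label smoothing and ignore_index=-100 over padded positions. A toy standalone computation showing both knobs together:

import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(2, 4, vocab_size)           # (batch, seq, vocab)
labels = torch.tensor([[3, 7, 2, -100],          # -100 marks padding, ignored by the loss
                       [5, 1, -100, -100]])

loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    labels.view(-1),
    ignore_index=-100,
    label_smoothing=0.1,  # spreads a little probability mass off the gold token
)
print(loss.item())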
340
-
341
- if task == "emotion":
342
- emotion_inputs = {"input_ids": batch["input_ids"]}
343
- if "attention_mask" in batch:
344
- emotion_inputs["attention_mask"] = batch["attention_mask"]
345
- logits = self.model.forward("emotion", emotion_inputs)
346
- loss = self.emotion_loss(logits, batch["labels"].float())
347
- probs = torch.sigmoid(logits)
348
- preds = (probs > 0.5).int()
349
- labels = batch["labels"].int()
350
- f1 = multilabel_f1(preds, labels)
351
- return loss, {"f1": f1}
352
-
353
- if task == "topic":
354
- topic_inputs = {"input_ids": batch["input_ids"]}
355
- if "attention_mask" in batch:
356
- topic_inputs["attention_mask"] = batch["attention_mask"]
357
- logits = self.model.forward("topic", topic_inputs)
358
- loss = self.topic_loss(logits, batch["labels"])
359
- preds = logits.argmax(dim=-1)
360
- acc = accuracy(preds.tolist(), batch["labels"].tolist())
361
- return loss, {"accuracy": acc}
362
-
363
- raise ValueError(f"Unknown task '{task}'")
364
-
365
- def _task_weight(self, task: str) -> float:
366
- if not self.config.task_weights:
367
- return 1.0
368
- return self.config.task_weights.get(task, 1.0)
369
-
370
- def _decode_predictions(self, logits: torch.Tensor) -> List[str]:
371
- generated = logits.argmax(dim=-1)
372
- return self.tokenizer.decode_batch(generated.tolist())
 
373
 
374
  def _decode_labels(self, labels: torch.Tensor) -> List[str]:
 
375
  valid = labels.clone()
376
  valid[valid == -100] = self.tokenizer.pad_token_id
377
  return self.tokenizer.decode_batch(valid.tolist())
378
 
 
 
379
  def _validate_generation(self, val_loader: DataLoader, epoch: int) -> None:
380
- """Generate and print sample summaries to monitor quality during training."""
381
  self.model.eval()
382
- samples_generated = 0
383
- print(f"\n{'=' * 80}")
384
- print(f"[Validation Generation - Epoch {epoch}]")
385
- print(f"{'=' * 80}")
 
386
 
387
  with torch.no_grad():
388
- for batch in val_loader:
389
- if samples_generated >= self.config.validation_samples:
390
  break
391
 
392
  batch = {
393
  k: v.to(self.device) if isinstance(v, torch.Tensor) else v
394
  for k, v in batch.items()
395
  }
396
- src_ids = batch["src_ids"]
397
  src_mask = batch.get("src_mask")
398
- labels = batch["labels"]
399
-
400
- # Only process first item from batch
401
- src_ids = src_ids[:1]
402
  if src_mask is not None:
403
  src_mask = src_mask[:1]
404
- labels = labels[:1]
405
 
406
- # Encode source
407
- encoder_mask = None
408
- if src_mask is not None:
409
- encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
410
- memory = self.model.encoder(src_ids, mask=encoder_mask)
411
-
412
- # DEBUG: Check encoder output statistics
413
- if samples_generated == 0:
414
- print("\n[DEBUG] Encoder output stats:")
415
- print(f" Shape: {memory.shape}")
416
- print(f" Mean: {memory.mean().item():.6f}")
417
- print(f" Std: {memory.std().item():.6f}")
418
- print(f" Min: {memory.min().item():.6f}")
419
- print(f" Max: {memory.max().item():.6f}")
420
- print(f" Has NaN: {torch.isnan(memory).any().item()}")
421
- print(f" Has Inf: {torch.isinf(memory).any().item()}")
422
-
423
- # Check first few positions
424
- print(f" First position norm: {memory[0, 0].norm().item():.4f}")
425
- print(f" Last position norm: {memory[0, -1].norm().item():.4f}")
426
-
427
- # Ban special tokens from generation
428
- ban_token_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
429
- unk_id = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
430
- if isinstance(unk_id, int):
431
- ban_token_ids.append(unk_id)
432
- ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
433
-
434
- # Generate using naive method (full forward, O(N^2)) for debugging
435
- generated = self.model.decoder.greedy_decode_naive(
436
  memory=memory,
437
  max_len=self.config.validation_max_length,
438
  start_token_id=self.tokenizer.bos_token_id,
@@ -441,139 +376,17 @@ class Trainer:
441
  memory_mask=src_mask,
442
  )
443
 
444
- # Decode
445
- source_text = self.tokenizer.decode(src_ids[0].tolist())
446
- generated_text = self.tokenizer.decode(generated[0].tolist())
447
- reference_text = self._decode_labels(labels)[0]
448
-
449
- print(f"\nSample {samples_generated + 1}:")
450
- print(
451
- f"Raw token IDs: {generated[0][:20].tolist()}..."
452
- ) # Debug: show first 20 tokens
453
- print(
454
- f"Source: {source_text[:200]}..."
455
- if len(source_text) > 200
456
- else f"Source: {source_text}"
457
- )
458
- print(f"Generated: {generated_text}")
459
- print(
460
- f"Reference: {reference_text[:200]}..."
461
- if len(reference_text) > 200
462
- else f"Reference: {reference_text}"
463
- )
464
- print("-" * 80)
465
 
466
- samples_generated += 1
 
 
 
 
 
467
 
468
- print(f"{'=' * 80}\n")
469
  self.model.train()
470
-
471
- def _print_epoch_progress(
472
- self,
473
- epoch: int,
474
- total_epochs: int,
475
- epoch_duration: float,
476
- total_elapsed: float,
477
- ) -> None:
478
- progress = epoch / total_epochs
479
- percent = progress * 100
480
- remaining_epochs = total_epochs - epoch
481
- eta = (total_elapsed / epoch) * remaining_epochs if epoch > 0 else 0.0
482
- bar = self._format_progress_bar(progress)
483
- message = (
484
- f"[progress] {bar} {percent:5.1f}% | epoch {epoch}/{total_epochs} "
485
- f"| last {epoch_duration:6.2f}s | total {total_elapsed:6.2f}s | ETA {eta:6.2f}s"
486
- )
487
- print(message, flush=True)
488
-
489
- @staticmethod
490
- def _format_progress_bar(progress: float, width: int = 20) -> str:
491
- clamped = max(0.0, min(1.0, progress))
492
- filled = int(round(clamped * width))
493
- bar = "#" * filled + "-" * (width - filled)
494
- return f"[{bar}]"
495
-
496
- def _update_epoch_progress(
497
- self,
498
- *,
499
- epoch: int,
500
- total_epochs: int,
501
- step: int,
502
- total_steps: int,
503
- epoch_start: float,
504
- global_start: float,
505
- final: bool = False,
506
- ) -> None:
507
- if total_steps <= 0 or total_epochs <= 0:
508
- return
509
- bounded_step = max(0, min(step, total_steps))
510
- step_fraction = bounded_step / total_steps
511
- epochs_completed = (epoch - 1) + step_fraction
512
- overall_progress = epochs_completed / total_epochs
513
- percent = overall_progress * 100.0
514
- epoch_elapsed = time.perf_counter() - epoch_start
515
- total_elapsed = time.perf_counter() - global_start
516
- if epochs_completed > 0:
517
- remaining_epochs = max(total_epochs - epochs_completed, 0.0)
518
- total_eta = (
519
- (total_elapsed / epochs_completed) * remaining_epochs if total_elapsed > 0 else 0.0
520
- )
521
- else:
522
- total_eta = 0.0
523
-
524
- if step > 0:
525
- epoch_eta = (epoch_elapsed / step) * (total_steps - step)
526
- else:
527
- epoch_eta = 0.0
528
-
529
- bar = self._format_progress_bar(overall_progress, width=self._progress_bar_width())
530
- message = (
531
- f"[progress] {bar} {percent:5.1f}% "
532
- f"e {epoch}/{total_epochs} "
533
- f"s {bounded_step}/{total_steps} "
534
- f"ep_eta {self._format_duration(epoch_eta)} "
535
- f"tot_eta {self._format_duration(total_eta)}"
536
- )
537
- display = self._truncate_to_terminal(message)
538
- padding = " " * max(self._progress_last_len - len(display), 0)
539
- print(f"\r{display}{padding}", end="", flush=True)
540
- if final:
541
- print()
542
- self._progress_last_len = 0
543
- else:
544
- self._progress_last_len = len(display)
545
-
546
- def _truncate_to_terminal(self, text: str) -> str:
547
- columns = self._terminal_width()
548
- if columns <= 0:
549
- return text
550
- if len(text) >= columns:
551
- return text[: max(columns - 1, 1)]
552
- return text
553
-
554
- def _progress_bar_width(self) -> int:
555
- columns = self._terminal_width()
556
- reserved = 60
557
- if columns <= reserved:
558
- return 10
559
- return max(10, min(30, columns - reserved))
560
-
561
- @staticmethod
562
- def _terminal_width() -> int:
563
- try:
564
- return shutil.get_terminal_size(fallback=(120, 20)).columns
565
- except OSError:
566
- return 120
567
-
568
- @staticmethod
569
- def _format_duration(seconds: float) -> str:
570
- seconds = max(0.0, seconds)
571
- if seconds >= 3600:
572
- hours = int(seconds // 3600)
573
- minutes = int((seconds % 3600) // 60)
574
- return f"{hours}h{minutes:02}m"
575
- if seconds >= 60:
576
- minutes = int(seconds // 60)
577
- secs = int(seconds % 60)
578
- return f"{minutes}m{secs:02}s"
579
- return f"{seconds:4.1f}s"
 
1
+ """
2
+ Multi-task Trainer for LexiMind.
3
+
4
+ Handles training across summarization, emotion, and topic heads with mixed precision,
5
+ gradient accumulation, and MLflow logging.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
 
13
  import time
14
  from collections import defaultdict
15
  from dataclasses import dataclass
16
+ from typing import Any, Callable, Dict, List
17
 
18
  import mlflow
19
  import torch
20
  import torch.nn.functional as F
21
  from torch.utils.data import DataLoader
22
+ from tqdm import tqdm
23
 
24
  from ..data.tokenization import Tokenizer
25
  from .metrics import accuracy, multilabel_f1, rouge_like
26
 
27
+ # --------------- Configuration ---------------
28
+
29
 
30
  @dataclass
31
  class TrainerConfig:
32
+ """Training hyperparameters."""
33
+
34
  max_epochs: int = 1
35
  gradient_clip_norm: float = 1.0
 
36
  task_weights: Dict[str, float] | None = None
37
  validation_samples: int = 3
38
  validation_max_length: int = 128
39
+ label_smoothing: float = 0.0
40
  experiment_name: str = "LexiMind"
41
  run_name: str | None = None
42
  gradient_accumulation_steps: int = 1
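
A hypothetical instantiation sketch for orientation (the import path and the values shown are illustrative assumptions, not project defaults):

    from src.training.trainer import TrainerConfig  # import path assumed

    config = TrainerConfig(
        max_epochs=3,
        gradient_clip_norm=1.0,
        gradient_accumulation_steps=2,
        label_smoothing=0.1,
        task_weights={"summarization": 1.0, "emotion": 1.0, "topic": 1.0},
    )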
43
 
44
 
45
+ # --------------- Trainer ---------------
46
  class Trainer:
47
+ """Multi-task trainer with AMP and gradient accumulation."""
48
 
49
  def __init__(
50
  self,
 
59
  self.config = config
60
  self.device = device
61
  self.tokenizer = tokenizer
62
+
63
+ # Task losses
64
  self.emotion_loss = torch.nn.BCEWithLogitsLoss()
65
  self.topic_loss = torch.nn.CrossEntropyLoss()
66
+
67
+ # AMP setup: bfloat16 for Ampere+ GPUs, float16 otherwise
68
+ self.use_amp = device.type == "cuda"
69
+ self.use_bfloat16 = self.use_amp and torch.cuda.is_bf16_supported()
70
+ self.scaler = torch.GradScaler("cuda", enabled=(self.use_amp and not self.use_bfloat16))
71
+
72
+ self._nan_counter = 0
 
 
 
 
 
 
73
  mlflow.set_experiment(config.experiment_name)
74
 
75
+ # CUDA optimizations
76
+ if device.type == "cuda":
77
+ torch.backends.cuda.enable_flash_sdp(True)
78
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
79
+
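
The flags set above follow the usual torch.amp split: bfloat16 needs no loss scaling, while float16 routes through a GradScaler. A self-contained sketch of that dispatch on a toy model (illustration only, not LexiMind code):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    use_bf16 = use_amp and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp and not use_bf16)

    model = torch.nn.Linear(16, 4).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 4, (8,), device=device)

    with torch.autocast(device.type, dtype=amp_dtype, enabled=use_amp):
        loss = torch.nn.functional.cross_entropy(model(x), y)

    if use_bf16:
        loss.backward()                      # bf16 has enough dynamic range: no scaling
    else:
        scaler.scale(loss).backward()        # fp16: scale to avoid gradient underflow
        scaler.unscale_(opt)                 # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    if use_bf16:
        opt.step()
    else:
        scaler.step(opt)
        scaler.update()
    opt.zero_grad()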
80
+ # --------------- Training Loop ---------------
81
+
82
  def fit(
83
  self,
84
  train_loaders: Dict[str, DataLoader],
85
  val_loaders: Dict[str, DataLoader] | None = None,
86
  checkpoint_callback: Callable | None = None,
87
  ) -> Dict[str, Dict[str, float]]:
88
+ """Train model across all tasks with progress tracking."""
 
 
 
 
 
 
 
 
 
89
  history: Dict[str, Dict[str, float]] = {}
90
+ total_start = time.perf_counter()
 
91
 
92
  with mlflow.start_run(run_name=self.config.run_name):
93
+ self._log_config()
94
+
95
+ # Epoch progress bar
96
+ epoch_pbar = tqdm(
97
+ range(1, self.config.max_epochs + 1),
98
+ desc="Training",
99
+ unit="epoch",
100
+ position=0,
 
101
  )
102
 
103
+ for epoch in epoch_pbar:
104
  epoch_start = time.perf_counter()
 
 
 
 
 
 
 
 
 
105
 
106
+ # Train
107
+ train_metrics = self._run_epoch(train_loaders, train=True, epoch=epoch)
108
+ history[f"train_epoch_{epoch}"] = train_metrics
109
+ self._log_metrics(train_metrics, "train", epoch)
110
 
111
+ # Validate
112
  if val_loaders:
113
  val_metrics = self._run_epoch(val_loaders, train=False, epoch=epoch)
114
  history[f"val_epoch_{epoch}"] = val_metrics
115
+ self._log_metrics(val_metrics, "val", epoch)
116
 
 
 
 
 
 
 
117
  if "summarization" in val_loaders:
118
  self._validate_generation(val_loaders["summarization"], epoch)
119
 
120
+ # Checkpoint
121
+ if checkpoint_callback:
122
  checkpoint_callback(epoch, self.model, history)
123
 
124
+ # Update epoch progress bar with metrics
125
+ epoch_time = time.perf_counter() - epoch_start
126
+ total_time = time.perf_counter() - total_start
127
+ desc = f"Epoch {epoch}/{self.config.max_epochs}"
128
+ if "total_loss" in train_metrics:
129
+ desc += f" | loss={train_metrics['total_loss']:.3f}"
130
+ epoch_pbar.set_description(desc)
131
+ epoch_pbar.set_postfix(
132
+ {"time": f"{epoch_time:.1f}s", "total": f"{total_time:.1f}s"}
133
+ )
134
 
135
+ total_time = time.perf_counter() - total_start
136
+ print(f"\n✓ Training complete in {total_time:.1f}s")
137
  return history
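
The checkpoint hook receives (epoch, model, history); a minimal callback sketch (the output path is an assumption):

    import torch

    def checkpoint_callback(epoch, model, history):
        # Persist weights each epoch; adapt the path to your checkpoint layout.
        torch.save(model.state_dict(), f"checkpoints/epoch_{epoch:02d}.pt")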
138
 
139
+ def _log_config(self) -> None:
140
+ """Log config to MLflow."""
141
+ mlflow.log_params(
142
+ {
143
+ "max_epochs": self.config.max_epochs,
144
+ "gradient_clip_norm": self.config.gradient_clip_norm,
145
+ "label_smoothing": self.config.label_smoothing,
146
+ "task_weights": str(self.config.task_weights),
147
+ }
148
+ )
149
+
150
+ def _log_metrics(self, metrics: Dict[str, float], prefix: str, epoch: int) -> None:
151
+ """Log metrics to MLflow."""
152
+ for k, v in metrics.items():
153
+ if k != "epoch":
154
+ mlflow.log_metric(f"{prefix}_{k}", v, step=epoch)
155
+
156
+ # --------------- Epoch Execution ---------------
157
+
158
  def _run_epoch(
159
  self,
160
  loaders: Dict[str, DataLoader],
161
  *,
162
  train: bool,
163
  epoch: int,
 
 
 
164
  ) -> Dict[str, float]:
165
+ """Run one epoch with progress bar."""
166
+ phase = "Train" if train else "Val"
167
  self.model.train(train)
168
+
169
+ metrics: Dict[str, List[float]] = defaultdict(list)
170
+ iterators = {task: iter(loader) for task, loader in loaders.items()}
 
171
  max_batches = max(len(loader) for loader in loaders.values())
172
+ accum_steps = self.config.gradient_accumulation_steps
173
+
174
+ # Batch progress bar (nested under epoch bar)
175
+ pbar = tqdm(
176
+ range(max_batches),
177
+ desc=f" {phase}",
178
+ unit="batch",
179
+ leave=False,
180
+ position=1,
181
  )
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  context = torch.enable_grad() if train else torch.no_grad()
184
  with context:
185
+ for step in pbar:
186
+ step_loss = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  for task, loader in loaders.items():
189
+ batch = self._get_batch(iterators, loader, task)
190
  if batch is None:
191
  continue
192
 
193
+ # Forward with AMP
194
+ amp_dtype = torch.bfloat16 if self.use_bfloat16 else torch.float16
195
+ with torch.autocast("cuda", dtype=amp_dtype, enabled=self.use_amp):
196
+ loss, task_metrics = self._forward_task(task, batch)
 
 
197
 
198
+ # NaN check
199
  if torch.isnan(loss):
200
+ self._nan_counter += 1
201
+ if self._nan_counter > 10:
202
+ raise RuntimeError("Training diverging - too many NaN losses")
 
 
 
 
 
 
203
  continue
204
+ self._nan_counter = 0
 
 
 
 
 
 
 
205
 
206
+ # Record metrics
207
+ metrics[f"{task}_loss"].append(loss.item())
208
+ for name, val in task_metrics.items():
209
+ metrics[f"{task}_{name}"].append(val)
210
 
211
+ # Backward
212
  if train:
213
+ weight = (self.config.task_weights or {}).get(task, 1.0)
214
+ scaled = (loss * weight) / accum_steps
215
+ step_loss += scaled.item() * accum_steps
216
+
217
+ if self.use_bfloat16:
218
+ scaled.backward()
219
  else:
220
+ self.scaler.scale(scaled).backward()
221
+
222
+ # Optimizer step
223
+ if train and (step + 1) % accum_steps == 0:
224
+ self._optimizer_step()
225
+
226
+ if step_loss > 0:
227
+ metrics["total_loss"].append(step_loss)
228
+
229
+ # Update progress bar
230
+ if metrics["total_loss"]:
231
+ pbar.set_postfix({"loss": f"{metrics['total_loss'][-1]:.3f}"})
232
+
233
+ # Average and print summary
234
+ averaged = {k: sum(v) / len(v) for k, v in metrics.items() if v}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  averaged["epoch"] = float(epoch)
236
+
237
+ summary = f"[{phase.lower()}] epoch {epoch}: "
238
+ summary += ", ".join(f"{k}={v:.4f}" for k, v in averaged.items() if k != "epoch")
239
+ tqdm.write(summary)
240
+
241
  return averaged
242
 
243
+ def _optimizer_step(self) -> None:
244
+ """Optimizer step with gradient clipping."""
245
+ if self.use_bfloat16:
246
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
247
+ self.optimizer.step()
248
+ else:
249
+ self.scaler.unscale_(self.optimizer)
250
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
251
+ self.scaler.step(self.optimizer)
252
+ self.scaler.update()
253
+ self.optimizer.zero_grad()
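
The accumulation scheme divides each micro-batch loss by accum_steps so the summed gradients approximate one large-batch step, then the optimizer fires every accum_steps micro-batches. A standalone toy version of the same pattern (not project code):

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_steps = 4

    for step in range(8):
        x = torch.randn(2, 4)
        y = torch.randint(0, 2, (2,))
        loss = torch.nn.functional.cross_entropy(model(x), y) / accum_steps
        loss.backward()                        # gradients accumulate across micro-batches
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad()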
254
+
255
+ def _get_batch(
256
+ self, iterators: Dict, loader: DataLoader, task: str
257
  ) -> Dict[str, torch.Tensor] | None:
258
+ """Get next batch, cycling iterator if exhausted."""
259
  try:
260
+ batch = next(iterators[task])
261
  except StopIteration:
262
+ iterators[task] = iter(loader)
263
  try:
264
+ batch = next(iterators[task])
265
  except StopIteration:
266
  return None
267
  return {
268
+ k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
269
+ for k, v in batch.items()
270
  }
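
The same cycling idea in isolation: every task contributes one batch per step, and shorter loaders restart so all tasks stay in the mix until the longest loader is exhausted (lists stand in for DataLoaders):

    loaders = {"summarization": [1, 2, 3, 4], "emotion": [10, 20]}
    iterators = {task: iter(seq) for task, seq in loaders.items()}

    for step in range(max(len(seq) for seq in loaders.values())):
        for task, seq in loaders.items():
            try:
                batch = next(iterators[task])
            except StopIteration:
                iterators[task] = iter(seq)    # restart the exhausted loader
                batch = next(iterators[task])
            print(step, task, batch)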
271
 
272
+ # --------------- Task Forward Passes ---------------
273
+
274
  def _forward_task(
275
+ self, task: str, batch: Dict[str, torch.Tensor]
276
  ) -> tuple[torch.Tensor, Dict[str, float]]:
277
+ """Route to task-specific forward pass."""
278
  if task == "summarization":
279
+ return self._forward_summarization(batch)
280
+ elif task == "emotion":
281
+ return self._forward_emotion(batch)
282
+ elif task == "topic":
283
+ return self._forward_topic(batch)
284
+ raise ValueError(f"Unknown task: {task}")
285
+
286
+ def _forward_summarization(
287
+ self, batch: Dict[str, torch.Tensor]
288
+ ) -> tuple[torch.Tensor, Dict[str, float]]:
289
+ """Seq2seq forward for summarization."""
290
+ inputs = {"src_ids": batch["src_ids"], "tgt_ids": batch["tgt_ids"]}
291
+ if "src_mask" in batch:
292
+ inputs["src_mask"] = batch["src_mask"]
293
+
294
+ logits = self.model.forward("summarization", inputs)
295
+ loss = F.cross_entropy(
296
+ logits.view(-1, logits.size(-1)),
297
+ batch["labels"].view(-1),
298
+ ignore_index=-100,
299
+ label_smoothing=self.config.label_smoothing,
300
+ )
301
+
302
+ # Quick ROUGE estimate
303
+ preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
304
+ refs = self._decode_labels(batch["labels"])
305
+ return loss, {"rouge_like": rouge_like(preds, refs)}
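
The -100 values are the standard ignore_index convention: padded target positions contribute nothing to the loss. A tiny standalone illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 32)                      # (batch, seq_len, vocab)
    labels = torch.tensor([[3, 7, -100, -100, -100],
                           [1, 4, 9, 2, -100]])         # -100 marks padding

    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        label_smoothing=0.1,
    )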
306
+
307
+ def _forward_emotion(
308
+ self, batch: Dict[str, torch.Tensor]
309
+ ) -> tuple[torch.Tensor, Dict[str, float]]:
310
+ """Multi-label emotion classification."""
311
+ inputs = {"input_ids": batch["input_ids"]}
312
+ if "attention_mask" in batch:
313
+ inputs["attention_mask"] = batch["attention_mask"]
314
+
315
+ logits = self.model.forward("emotion", inputs)
316
+ loss = self.emotion_loss(logits, batch["labels"].float())
317
+ preds = (torch.sigmoid(logits) > 0.5).int()
318
+ return loss, {"f1": multilabel_f1(preds, batch["labels"].int())}
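
Unlike the single-label topic head below, each emotion is an independent yes/no decision, so the loss is BCE over raw logits and predictions come from thresholding sigmoids rather than argmax. Standalone illustration:

    import torch

    logits = torch.tensor([[2.0, -1.0, 0.3],
                           [-0.5, 1.5, -2.0]])          # (batch, n_emotions)
    targets = torch.tensor([[1.0, 0.0, 1.0],
                            [0.0, 1.0, 0.0]])

    loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
    preds = (torch.sigmoid(logits) > 0.5).int()         # independent decision per label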
319
+
320
+ def _forward_topic(
321
+ self, batch: Dict[str, torch.Tensor]
322
+ ) -> tuple[torch.Tensor, Dict[str, float]]:
323
+ """Single-label topic classification."""
324
+ inputs = {"input_ids": batch["input_ids"]}
325
+ if "attention_mask" in batch:
326
+ inputs["attention_mask"] = batch["attention_mask"]
327
+
328
+ logits = self.model.forward("topic", inputs)
329
+ loss = self.topic_loss(logits, batch["labels"])
330
+ preds = logits.argmax(dim=-1)
331
+ return loss, {"accuracy": accuracy(preds.tolist(), batch["labels"].tolist())}
332
 
333
  def _decode_labels(self, labels: torch.Tensor) -> List[str]:
334
+ """Decode labels, replacing -100 with pad token."""
335
  valid = labels.clone()
336
  valid[valid == -100] = self.tokenizer.pad_token_id
337
  return self.tokenizer.decode_batch(valid.tolist())
338
 
339
+ # --------------- Validation Generation ---------------
340
+
341
  def _validate_generation(self, val_loader: DataLoader, epoch: int) -> None:
342
+ """Generate sample summaries for quality check."""
343
  self.model.eval()
344
+ n = self.config.validation_samples
345
+
346
+ tqdm.write(f"\n{'=' * 50}")
347
+ tqdm.write(f"[Validation Samples - Epoch {epoch}]")
348
+ tqdm.write(f"{'=' * 50}")
349
 
350
  with torch.no_grad():
351
+ for i, batch in enumerate(val_loader):
352
+ if i >= n:
353
  break
354
 
355
  batch = {
356
  k: v.to(self.device) if isinstance(v, torch.Tensor) else v
357
  for k, v in batch.items()
358
  }
359
+ src_ids = batch["src_ids"][:1]
360
  src_mask = batch.get("src_mask")
 
 
 
 
361
  if src_mask is not None:
362
  src_mask = src_mask[:1]
 
363
 
364
+ # Encode and generate
365
+ enc_mask = (
366
+ src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
367
+ )
368
+ model: Any = self.model
369
+ memory = model.encoder(src_ids, mask=enc_mask)
370
+ generated = model.decoder.greedy_decode_naive(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  memory=memory,
372
  max_len=self.config.validation_max_length,
373
  start_token_id=self.tokenizer.bos_token_id,
 
376
  memory_mask=src_mask,
377
  )
378
 
379
+ # Decode and display
380
+ src = self.tokenizer.decode(src_ids[0].tolist())
381
+ out = self.tokenizer.decode(generated[0].tolist())
382
+ ref = self._decode_labels(batch["labels"][:1])[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
+ tqdm.write(f"\nSample {i + 1}:")
385
+ tqdm.write(f" Source: {src[:120]}..." if len(src) > 120 else f" Source: {src}")
386
+ tqdm.write(f" Generated: {out}")
387
+ tqdm.write(
388
+ f" Reference: {ref[:120]}..." if len(ref) > 120 else f" Reference: {ref}"
389
+ )
390
 
391
+ tqdm.write(f"{'=' * 50}\n")
392
  self.model.train()
 
src/training/utils.py CHANGED
@@ -1,4 +1,12 @@
1
- """Small training helpers."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Training utilities for LexiMind.
3
+
4
+ Provides reproducibility helpers including seed management for stdlib, PyTorch,
5
+ and NumPy random number generators with thread-safe spawning support.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
src/utils/config.py CHANGED
@@ -1,4 +1,11 @@
1
- """YAML config loader."""
 
 
 
 
 
 
 
2
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
 
1
+ """
2
+ Configuration utilities for LexiMind.
3
+
4
+ Provides YAML configuration loading with validation.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from dataclasses import dataclass
11
  from pathlib import Path
src/utils/io.py CHANGED
@@ -1,4 +1,11 @@
1
- """Checkpoint IO helpers."""
 
 
 
 
 
 
 
2
 
3
  from pathlib import Path
4
 
 
1
+ """
2
+ Checkpoint I/O utilities for LexiMind.
3
+
4
+ Handles model state serialization with support for torch.compile artifacts.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  from pathlib import Path
11
 
src/utils/labels.py CHANGED
@@ -1,4 +1,12 @@
1
- """Label metadata helpers for multitask inference."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
1
+ """
2
+ Label metadata utilities for LexiMind.
3
+
4
+ Manages persistence and loading of emotion and topic label vocabularies
5
+ for multitask inference.
6
+
7
+ Author: Oliver Perrin
8
+ Date: December 2025
9
+ """
10
 
11
  from __future__ import annotations
12
 
src/utils/logging.py CHANGED
@@ -1,4 +1,11 @@
1
- """Logging setup."""
 
 
 
 
 
 
 
2
 
3
  import logging
4
 
 
1
+ """
2
+ Logging utilities for LexiMind.
3
+
4
+ Provides centralized logging configuration and logger factory.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  import logging
11
 
src/utils/random.py CHANGED
@@ -1,4 +1,11 @@
1
- """Randomness helpers."""
 
 
 
 
 
 
 
2
 
3
  import random
4
 
 
1
+ """
2
+ Randomness utilities for LexiMind.
3
+
4
+ Provides seed management for reproducibility.
5
+
6
+ Author: Oliver Perrin
7
+ Date: December 2025
8
+ """
9
 
10
  import random
11
 
tests/test_training/test_trainer.py CHANGED
@@ -17,7 +17,7 @@ class TestTrainer(unittest.TestCase):
17
  self.model = MagicMock()
18
  self.model.to.return_value = self.model # Ensure .to() returns the same mock
19
  self.optimizer = MagicMock(spec=torch.optim.Optimizer)
20
- self.config = TrainerConfig(max_epochs=1, logging_interval=1)
21
  self.device = torch.device("cpu")
22
  self.tokenizer = MagicMock()
23
  self.tokenizer.pad_token_id = 0
 
17
  self.model = MagicMock()
18
  self.model.to.return_value = self.model # Ensure .to() returns the same mock
19
  self.optimizer = MagicMock(spec=torch.optim.Optimizer)
20
+ self.config = TrainerConfig(max_epochs=1)
21
  self.device = torch.device("cpu")
22
  self.tokenizer = MagicMock()
23
  self.tokenizer.pad_token_id = 0