OliverPerrin committed on
Commit 1fbc47b · 1 Parent(s): f9edbb4

chore: snapshot current refinements

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +45 -153
  2. configs/data/datasets.yaml +26 -0
  3. configs/model/base.yaml +6 -50
  4. configs/model/large.yaml +6 -0
  5. configs/model/small.yaml +6 -23
  6. configs/training/default.yaml +12 -0
  7. configs/training/full.yaml +12 -0
  8. configs/training/quick_test.yaml +9 -0
  9. docker/Dockerfile +0 -0
  10. docker/docker-compose.yml +0 -0
  11. docs/api.md +79 -0
  12. docs/architecture.md +57 -0
  13. docs/training.md +59 -0
  14. pyproject.toml +6 -11
  15. requirements.txt +6 -16
  16. scripts/download_data.py +182 -0
  17. scripts/download_data.sh +5 -0
  18. scripts/evaluate.py +134 -0
  19. scripts/export_model.py +69 -0
  20. scripts/inference.py +112 -0
  21. scripts/preprocess_data.py +321 -0
  22. scripts/test_gpu.py +0 -27
  23. scripts/train.py +217 -6
  24. setup.py +16 -6
  25. src/__init__.py +1 -0
  26. src/api/__init__.py +1 -0
  27. src/api/app.py +10 -0
  28. src/api/dependencies.py +42 -0
  29. src/api/inference/__init__.py +0 -7
  30. src/api/inference/inference.py +0 -133
  31. src/api/routes.py +34 -0
  32. src/api/schemas.py +14 -0
  33. src/data/__init__.py +1 -0
  34. src/data/dataloader.py +117 -0
  35. src/data/dataset.py +229 -0
  36. src/data/download.py +39 -60
  37. src/data/preprocessing.py +95 -225
  38. src/data/tokenization.py +122 -0
  39. src/inference/__init__.py +10 -5
  40. src/inference/baseline_summarizer.py +0 -41
  41. src/inference/factory.py +75 -0
  42. src/inference/generation.py +14 -0
  43. src/inference/pipeline.py +166 -0
  44. src/inference/postprocessing.py +6 -0
  45. src/models/factory.py +105 -0
  46. src/models/multitask.py +48 -9
  47. src/training/__init__.py +1 -0
  48. src/training/callbacks.py +37 -0
  49. src/training/losses.py +13 -0
  50. src/training/metrics.py +36 -0
README.md CHANGED
@@ -1,175 +1,67 @@
1
- # LexiMind: Multi-Task Transformer for Document Analysis
2
 
3
- A PyTorch-based multi-task learning system that performs abstractive summarization, emotion classification, and topic clustering on textual data using a shared Transformer encoder architecture.
 
 
 
4
 
5
- ## 🎯 Project Overview
 
 
 
6
 
7
- LexiMind demonstrates multi-task learning (MTL) by training a single model to simultaneously:
8
- 1. **Abstractive Summarization**: Generate concise summaries with user-defined compression levels
9
- 2. **Emotion Classification**: Detect multiple emotions present in text (multi-label classification)
10
- 3. **Topic Clustering**: Group documents by semantic similarity for topic discovery
 
11
 
12
- ### Key Features
13
- - Custom encoder-decoder Transformer architecture with shared representations
14
- - Multi-task loss function with learnable task weighting
15
- - Attention weight visualization for model interpretability
16
- - Interactive web interface for real-time inference
17
- - Trained on diverse corpora: news articles (CNN/DailyMail, BBC) and literary texts (Project Gutenberg)
18
-
19
- ## 🏗️ Architecture
20
-
21
- ```
22
- Input Text
23
-
24
- ┌─────────────────────┐
25
- │ Shared Encoder │ ← TransformerEncoder (6 layers)
26
- │ (Multi-head Attn) │
27
- └─────────────────────┘
28
- ↓ ↓ ↓
29
- │ │ └──────────────┐
30
- │ │ │
31
- │ └─────────┐ │
32
- │ │ │
33
- ↓ ↓ ↓
34
- ┌─────────┐ ┌────────┐ ┌─────────┐
35
- │ Decoder │ │Classify│ │ Project │
36
- │ Head │ │ Head │ │ Head │
37
- └─────────┘ └────────┘ └─────────┘
38
- ↓ ↓ ↓
39
- Summary Emotions Embeddings
40
- (for clustering)
41
- ```
42
-
43
- ## 📊 Datasets
44
-
45
- - **CNN/DailyMail**: 300k+ news articles with human-written summaries
46
- - **BBC News**: 2,225 articles across 5 categories
47
- - **Project Gutenberg**: Classic literature for long-form text analysis
48
-
49
- ## 🚀 Quick Start
50
-
51
- ### Installation
52
  ```bash
53
  git clone https://github.com/OliverPerrin/LexiMind.git
54
  cd LexiMind
55
  pip install -r requirements.txt
56
- ```
57
 
58
- ### Download Data
59
- ```bash
60
- python src/download_datasets.py
61
- ```
62
 
63
- ### Train Model
64
- ```bash
65
- python src/train.py --config configs/default.yaml
66
- ```
67
-
68
- ### Launch Interface
69
- ```bash
70
- python src/app.py
71
  ```
72
 
73
- ## 📁 Project Structure
 
74
 
 
75
  ```
76
- LexiMind/
77
- ├── src/
78
- ├── models/
79
- │ │ ├── encoder.py # Shared Transformer encoder
80
- │ │ ├── summarization.py # Seq2seq decoder head
81
- │ │ ├── emotion.py # Multi-label classification head
82
- │ │ └── clustering.py # Projection head for embeddings
83
- │ ├── data/
84
- │ │ ├── download_datasets.py # Data acquisition
85
- │ │ ├── preprocessing.py # Text cleaning & tokenization
86
- │ │ └── dataset.py # PyTorch Dataset classes
87
- │ ├── training/
88
- │ │ ├── train.py # Training loop
89
- │ │ ├── losses.py # Multi-task loss functions
90
- │ │ └── metrics.py # ROUGE, F1, silhouette scores
91
- │ ├── inference/
92
- │ │ └── pipeline.py # End-to-end inference
93
- │ ├── visualization/
94
- │ │ └── attention.py # Attention heatmap generation
95
- │ └── app.py # Gradio/FastAPI interface
96
- ├── configs/
97
- │ └── default.yaml # Model & training hyperparameters
98
- ├── tests/
99
- │ └── test_*.py # Unit tests
100
- ├── notebooks/
101
- │ └── exploratory.ipynb # Data exploration & analysis
102
- ├── requirements.txt
103
- └── README.md
104
  ```
105
 
106
- ## 🧪 Evaluation Metrics
 
107
 
108
- | Task | Metric | Score |
109
- |------|--------|-------|
110
- | Summarization | ROUGE-1 / ROUGE-L | TBD |
111
- | Emotion Classification | Macro F1 | TBD |
112
- | Topic Clustering | Silhouette Score | TBD |
113
 
114
- ## 🔬 Technical Details
 
 
 
 
115
 
116
- ### Model Specifications
117
- - **Encoder**: 6-layer Transformer (d_model=512, 8 attention heads)
118
- - **Decoder**: 6-layer autoregressive Transformer
119
- - **Vocab Size**: 32,000 (SentencePiece tokenizer)
120
- - **Parameters**: ~60M total
121
-
122
- ### Training
123
- - **Optimizer**: AdamW (lr=1e-4, weight_decay=0.01)
124
- - **Scheduler**: Linear warmup (5000 steps) + cosine decay
125
- - **Loss**: Weighted sum of cross-entropy (summarization), BCE (emotions), triplet loss (clustering)
126
- - **Hardware**: Trained on single NVIDIA RTX 3090 (24GB VRAM)
127
- - **Time**: ~48 hours for 10 epochs
128
-
129
- ### Multi-Task Learning Strategy
130
- Uses uncertainty weighting ([Kendall et al., 2018](https://arxiv.org/abs/1705.07115)) to automatically balance task losses:
131
-
132
- ```
133
- L_total = Σ (1/2σ²_i * L_i + log(σ_i))
134
  ```
135
 
136
- where σ_i are learnable parameters representing task uncertainty.
137
-
138
- ## 🎨 Interface Preview
139
-
140
- The web interface provides:
141
- - Text input with real-time token count
142
- - Compression level slider (20%-80%)
143
- - Side-by-side original/summary comparison
144
- - Emotion probability bars with color coding
145
- - Interactive attention heatmap (click tokens to highlight attention)
146
- - Downloadable results (JSON/CSV)
147
-
148
- ## 📈 Future Enhancements
149
-
150
- - [ ] Add multilingual support (mBART)
151
- - [ ] Implement beam search for better summaries
152
- - [ ] Fine-tune on domain-specific corpora (medical, legal)
153
- - [ ] Add semantic search across document embeddings
154
- - [ ] Deploy as REST API with Docker
155
- - [ ] Implement model distillation for mobile deployment
156
-
157
- ## 📚 References
158
-
159
- - Vaswani et al. (2017) - [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
160
- - Lewis et al. (2019) - [BART: Denoising Sequence-to-Sequence Pre-training](https://arxiv.org/abs/1910.13461)
161
- - Caruana (1997) - [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734)
162
- - Demszky et al. (2020) - [GoEmotions Dataset](https://arxiv.org/abs/2005.00547)
163
-
164
- ## 📄 License
165
-
166
- GNU General Public License v3.0
167
-
168
- ## 👤 Author
169
-
170
- **Oliver Perrin**
171
- - Portfolio: [oliverperrin.com](https://oliverperrin.com)
172
- - LinkedIn: [linkedin.com/in/oliverperrin](https://linkedin.com/in/oliverperrin)
173
- - Email: [email protected]
174
 
175
- ---
 
 
1
+ # LexiMind (Inference Edition)
2
 
3
+ LexiMind now ships as a focused inference sandbox for the custom multitask Transformer found in
4
+ `src/models`. Training, dataset downloaders, and legacy scripts have been removed so it is easy to
5
+ load a checkpoint, run the Streamlit demo, and experiment with summarization, emotion
6
+ classification, and topic cues on your own text.
7
 
8
+ ## What Stays
9
+ - Transformer encoder/decoder and task heads under `src/models`
10
+ - Unit tests for the model stack (`tests/test_models`)
11
+ - Streamlit UI (`src/ui/streamlit_app.py`) wired to the inference helpers in `src/api/inference`
12
 
13
+ ## What Changed
14
+ - Hugging Face tokenizers provide all tokenization (see `TextPreprocessor`)
15
+ - Training, dataset downloaders, and CLI scripts have been removed
16
+ - Scikit-learn powers light text normalization (stop-word removal optional)
17
+ - Requirements trimmed to inference-only dependencies
18
 
19
+ ## Quick Start
 
20
  ```bash
21
  git clone https://github.com/OliverPerrin/LexiMind.git
22
  cd LexiMind
23
  pip install -r requirements.txt
 
24
 
25
+ # Optional extras via setup.py packaging metadata
26
+ pip install .[web] # installs streamlit + plotly
27
+ pip install .[api] # installs fastapi
28
+ pip install .[all] # installs both groups
29
 
30
+ streamlit run src/ui/streamlit_app.py
31
  ```
32
 
33
+ Configure the Streamlit app via the sidebar to point at your tokenizer directory and model
34
+ checkpoint (defaults assume `artifacts/hf_tokenizer` and `checkpoints/best.pt`).
35
 
36
+ ## Minimal Project Map
37
  ```
38
+ src/
39
+ ├── api/ # load_models + helpers
40
+ ├── data/ # TextPreprocessor using Hugging Face + sklearn
41
+ ├── inference/ # thin summarizer facade
42
+ ├── models/ # core Transformer architecture (untouched)
43
+ └── ui/ # Streamlit interface
44
  ```
45
 
46
+ Everything outside `src/` now holds optional assets such as checkpoints, tokenizer exports, and
47
+ documentation stubs.
48
 
49
+ ## Loading a Checkpoint Programmatically
50
+ ```python
51
+ from src.api.inference import load_models, summarize_text
+
+ models = load_models({
+     "checkpoint_path": "checkpoints/best.pt",
+     "tokenizer_path": "artifacts/hf_tokenizer",
+     "hf_tokenizer_name": "facebook/bart-base",
+ })
+
+ summary, _ = summarize_text("Paste any article here.", models=models)
+ print(summary)
  ```
62
 
63
+ ## License
64
+ GPL-3.0
65
 
66
+ ## Author
67
+ Oliver Perrin · [email protected]
configs/data/datasets.yaml CHANGED
@@ -0,0 +1,26 @@
+ raw:
+   summarization: data/raw/summarization/cnn_dailymail
+   emotion: data/raw/emotion
+   topic: data/raw/topic
+   books: data/raw/books
+ processed:
+   summarization: data/processed/summarization
+   emotion: data/processed/emotion
+   topic: data/processed/topic
+   books: data/processed/books
+ tokenizer:
+   pretrained_model_name: facebook/bart-base
+   max_length: 512
+   lower: false
+ downloads:
+   summarization:
+     dataset: gowrishankarp/newspaper-text-summarization-cnn-dailymail
+     output: data/raw/summarization/cnn_dailymail
+   books:
+     - name: pride_and_prejudice
+       url: https://www.gutenberg.org/cache/epub/1342/pg1342.txt
+       output: data/raw/books/pride_and_prejudice.txt
+   emotion:
+     dataset: dair-ai/emotion
+   topic:
+     dataset: ag_news
configs/model/base.yaml CHANGED
@@ -1,50 +1,6 @@
1
- model:
2
- vocab_size: 32000
3
- d_model: 512
4
- num_encoder_layers: 6
5
- num_decoder_layers: 6
6
- num_heads: 8
7
- d_ff: 2048
8
- dropout: 0.1
9
- max_seq_length: 512
10
-
11
- tasks:
12
- summarization:
13
- enabled: true
14
- decoder_layers: 6
15
-
16
- emotion:
17
- enabled: true
18
- num_classes: 27
19
- pool_strategy: "mean" # Options: mean, max, cls, attention
20
-
21
- clustering:
22
- enabled: true
23
- embedding_dim: 128
24
- normalize: true
25
-
26
- training:
27
- batch_size: 16
28
- gradient_accumulation_steps: 2 # Effective batch = 32
29
- learning_rate: 1e-4
30
- weight_decay: 0.01
31
- num_epochs: 10
32
- warmup_steps: 1000
33
- max_grad_norm: 1.0
34
-
35
- scheduler:
36
- type: "cosine" # Options: linear, cosine, polynomial
37
-
38
- mixed_precision: true # Use AMP for faster training
39
-
40
- data:
41
- max_length: 512
42
- summary_max_length: 128
43
- train_split: 0.8
44
- val_split: 0.1
45
- test_split: 0.1
46
-
47
- preprocessing:
48
- lowercase: true
49
- remove_stopwords: false
50
- min_token_length: 3
 
1
+ d_model: 512
2
+ num_encoder_layers: 6
3
+ num_decoder_layers: 6
4
+ num_attention_heads: 8
5
+ ffn_dim: 2048
6
+ dropout: 0.1
 
configs/model/large.yaml CHANGED
@@ -0,0 +1,6 @@
1
+ d_model: 768
2
+ num_encoder_layers: 12
3
+ num_decoder_layers: 12
4
+ num_attention_heads: 12
5
+ ffn_dim: 3072
6
+ dropout: 0.1
configs/model/small.yaml CHANGED
@@ -1,23 +1,6 @@
1
- # configs/model/small.yaml (for fast iteration)
2
- model:
3
- d_model: 256
4
- num_encoder_layers: 4
5
- num_decoder_layers: 4
6
- num_heads: 8
7
-
8
- training:
9
- batch_size: 32 # ~4GB VRAM
10
- gradient_accumulation_steps: 1
11
- mixed_precision: true # Essential!
12
-
13
- # configs/model/base.yaml (production)
14
- model:
15
- d_model: 512
16
- num_encoder_layers: 6
17
- num_decoder_layers: 6
18
- num_heads: 8
19
-
20
- training:
21
- batch_size: 8 # ~8GB VRAM
22
- gradient_accumulation_steps: 4 # Effective batch = 32
23
- mixed_precision: true
 
1
+ d_model: 256
2
+ num_encoder_layers: 4
3
+ num_decoder_layers: 4
4
+ num_attention_heads: 4
5
+ ffn_dim: 1024
6
+ dropout: 0.1
configs/training/default.yaml CHANGED
@@ -0,0 +1,12 @@
+ dataloader:
+   batch_size: 8
+   shuffle: true
+ optimizer:
+   name: adamw
+   lr: 3.0e-5
+ scheduler:
+   name: cosine
+   warmup_steps: 500
+ trainer:
+   max_epochs: 5
+   gradient_clip_norm: 1.0
configs/training/full.yaml CHANGED
@@ -0,0 +1,12 @@
+ dataloader:
+   batch_size: 16
+   shuffle: true
+ optimizer:
+   name: adamw
+   lr: 2.0e-5
+ scheduler:
+   name: cosine
+   warmup_steps: 1000
+ trainer:
+   max_epochs: 15
+   gradient_clip_norm: 1.0
configs/training/quick_test.yaml CHANGED
@@ -0,0 +1,9 @@
+ dataloader:
+   batch_size: 2
+   shuffle: false
+ optimizer:
+   name: adamw
+   lr: 1.0e-4
+ trainer:
+   max_epochs: 1
+   gradient_clip_norm: 0.5
docker/Dockerfile DELETED
File without changes
docker/docker-compose.yml DELETED
File without changes
docs/api.md CHANGED
@@ -0,0 +1,79 @@
1
+ # API & CLI Documentation
2
+
3
+ ## FastAPI Service
4
+ The FastAPI application is defined in `src/api/app.py` and wires routes from
5
+ `src/api/routes.py`. All dependencies resolve through `src/api/dependencies.py`, which lazily constructs the shared inference pipeline.
6
+
7
+ ### POST `/summarize`
8
+ - **Request Body** (`SummaryRequest`):
9
+ ```json
10
+ {
11
+ "text": "Your input document"
12
+ }
13
+ ```
14
+ - **Response** (`SummaryResponse`):
15
+ ```json
16
+ {
17
+ "summary": "Generated abstractive summary",
18
+ "emotion_labels": ["joy", "surprise"],
19
+ "emotion_scores": [0.91, 0.63],
20
+ "topic": "news",
21
+ "topic_confidence": 0.82
22
+ }
23
+ ```
24
+ - **Behaviour:**
25
+ 1. Text is preprocessed through `TextPreprocessor` (with optional sklearn transformer if configured).
26
+ 2. The multitask model generates a summary via greedy decoding.
27
+ 3. Emotion and topic heads produce logits which are converted to probabilities and mapped to
28
+ human-readable labels using `artifacts/labels.json`.
29
+ 4. Results are returned as structured JSON suitable for a future Gradio interface.
30
+
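+ As a quick illustration, a client might call the endpoint like this (a minimal sketch: it assumes a
+ local deployment on port 8000 and the `requests` package, neither of which ships with the repo):
+ ```python
+ import requests
+
+ # Hypothetical local deployment of src/api/app.py; adjust host/port to your setup.
+ response = requests.post(
+     "http://localhost:8000/summarize",
+     json={"text": "Your input document"},
+     timeout=60,
+ )
+ response.raise_for_status()
+ payload = response.json()
+ print(payload["summary"], payload["emotion_labels"], payload["topic"])
+ ```
+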
31
+ ### Error Handling
32
+ - If the checkpoint or label metadata is missing, the dependency raises an HTTP 503 error with
33
+ an explanatory message.
34
+ - Validation errors (missing `text`) are handled automatically by FastAPI/Pydantic.
35
+
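+ The 503 behaviour can be pictured roughly as follows (a sketch only; the real logic lives in
+ `src/api/dependencies.py`, and the helper name used here is hypothetical):
+ ```python
+ from pathlib import Path
+
+ from fastapi import HTTPException
+
+ def get_pipeline(checkpoint: str = "checkpoints/best.pt", labels: str = "artifacts/labels.json"):
+     """Illustrative dependency: refuse to serve if required artifacts are missing."""
+     if not Path(checkpoint).exists() or not Path(labels).exists():
+         raise HTTPException(status_code=503, detail="Model checkpoint or label metadata not found.")
+     ...  # construct and cache the shared inference pipeline here
+ ```
+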
36
+ ## Command-Line Interface
37
+ `scripts/inference.py` provides a CLI that mirrors the API behaviour.
38
+
39
+ ### Usage
40
+ ```bash
41
+ python scripts/inference.py "Document to analyse" \
42
+ --checkpoint checkpoints/best.pt \
43
+ --labels artifacts/labels.json \
44
+ --tokenizer artifacts/hf_tokenizer \
45
+ --model-config configs/model/base.yaml \
46
+ --device cpu
47
+ ```
48
+
49
+ Options:
50
+ - `text` – zero or more positional arguments. If omitted, use `--file` to point to a newline-delimited text file.
52
+ - `--file` – optional path containing one text per line.
53
+ - `--checkpoint` – path to the trained model weights.
54
+ - `--labels` – JSON containing emotion/topic vocabularies (defaults to `artifacts/labels.json`).
55
+ - `--tokenizer` – optional tokenizer directory; defaults to the exported artifact if present.
56
+ - `--model-config` – YAML describing the architecture.
57
+ - `--device` – `cpu` or `cuda`. Passing `cuda` attempts to run inference on GPU.
58
+ - `--summary-max-length` – overrides the default maximum generation length.
59
+
60
+ ### Output
61
+ The CLI prints a JSON array where each entry contains the original text, summary, emotion labels
62
+ with scores, and topic prediction. This format is identical to the REST response, facilitating
63
+ integration tests and future Gradio UI rendering.
64
+
65
+ ## Future Gradio UI
66
+ - The planned UI will call the same inference pipeline and display results interactively.
67
+ - Given the response schema, the UI can show:
68
+ - Generated summary text.
69
+ - Emotion chips with probability bars.
70
+ - Topic confidence gauges.
71
+ - Placeholder panel for attention heatmaps and explanations.
72
+ - Once implemented, documentation updates will add a `docs/ui.md` section and screenshots under
73
+ `docs/images/`.
74
+
75
+ ## Testing
76
+ - `tests/test_api/test_routes.py` stubs the pipeline to ensure response fields and dependency
77
+ overrides behave as expected.
78
+ - `tests/test_inference/test_pipeline.py` validates pipeline methods end-to-end with dummy models,
79
+ guaranteeing API and CLI consumers receive consistent payload shapes.
docs/architecture.md CHANGED
@@ -0,0 +1,57 @@
1
+ # LexiMind Architecture
2
+
3
+ ## Overview
4
+ LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers:
5
+
6
+ 1. **Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn
7
+ primitives and a Hugging Face tokenizer wrapper with deterministic batching helpers.
8
+ 2. **Model Composition** – the bespoke encoder/decoder stack with task heads assembled via
9
+ `MultiTaskModel`, plus `models.factory.build_multitask_model` to rebuild the network from
10
+ configuration files.
11
+ 3. **Inference & Serving** – a multi-task pipeline capable of summarization, emotion, and topic classification; surfaced through a CLI and FastAPI service with plans for a Gradio UI.
12
+
13
+ ## Custom Transformer Stack
14
+ - `src/models/encoder.py` and `src/models/decoder.py` implement Pre-LayerNorm Transformer
15
+ blocks with explicit positional encoding, masking logic, and incremental decoding support.
16
+ - `src/models/heads.py` provides modular output heads. Summarization uses an `LMHead` tied to
17
+ the decoder embedding weights; emotion and topic tasks use `ClassificationHead` instances.
18
+ - `src/models/multitask.py` routes inputs to the correct head, computes task-specific losses,
19
+ and exposes a single forward API used by the trainer and inference pipeline.
20
+ - `src/models/factory.py` rebuilds the encoder, decoder, and heads directly from YAML config
21
+ and tokenizer metadata so inference rebuilds the exact architecture used in training.
22
+
23
+ ## Data, Tokenization, and Preprocessing
24
+ - `src/data/tokenization.py` wraps `AutoTokenizer` to provide tensor-aware batching and helper
25
+ utilities for decoder input shifting, BOS/EOS resolution, and vocab size retrieval.
26
+ - `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with
27
+ optional scikit-learn transformers (via `sklearn_transformer`) before tokenization. This keeps
28
+ the default cleaning minimal while allowing future reuse of `sklearn.preprocessing` utilities
29
+ without changing calling code.
30
+ - `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and
31
+ collators that encode inputs with the shared tokenizer and set up task-specific labels (multi-label
32
+ emotions, categorical topics, seq2seq summaries).
33
+
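+ The `sklearn_transformer` hook mentioned above can be pictured like this (a sketch under assumptions:
+ the keyword arguments and the `transform` entry point are illustrative; the actual signature lives in
+ `src/data/preprocessing.py`):
+ ```python
+ from sklearn.preprocessing import FunctionTransformer
+
+ from src.data.preprocessing import BasicTextCleaner, TextPreprocessor
+
+ # Any TransformerMixin with a transform(list_of_str) -> list_of_str contract should slot in.
+ lowercase = FunctionTransformer(lambda texts: [t.lower() for t in texts])
+
+ # Keyword argument names are assumed here; check the real constructor before relying on this.
+ preprocessor = TextPreprocessor(cleaner=BasicTextCleaner(), sklearn_transformer=lowercase)
+ cleaned = preprocessor.transform(["An Example Sentence."])
+ ```
+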
34
+ ## Training Pipeline
35
+ - `src/training/trainer.py` coordinates multi-task optimization with per-task loss functions, gradient clipping, and shared tokenizer decoding for metric computation.
36
+ - Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and a ROUGE-like overlap score for summarization. These metrics mirror the trainer outputs logged per task.
37
+ - Label vocabularies are serialized to `artifacts/labels.json` after training so inference can decode class indices consistently.
38
+
39
+ ## Inference & Serving
40
+ - `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic. It expects label vocabularies from the serialized metadata file.
41
+ - `src/inference/factory.py` rebuilds the full pipeline by loading the tokenizer (preferring the exported tokenizer artifact), reconstructing the model via the factory helpers, restoring checkpoints, and injecting label metadata.
42
+ - The CLI (`scripts/inference.py`) drives the pipeline from the command line. The FastAPI app (`src/api/routes.py`) exposes the `/summarize` endpoint that returns summaries, emotion labels + scores, and topic predictions. Test coverage in `tests/test_inference` and `tests/test_api` validates both layers with lightweight stubs.
43
+
44
+ ## Gradio UI Roadmap
45
+ - The inference pipeline returns structured outputs that are already suitable for a web UI.
46
+ - Planned steps for a Gradio demo:
47
+ 1. Wrap `InferencePipeline.batch_predict` inside Gradio callbacks for text input.
48
+ 2. Display summaries alongside emotion tag chips and topic confidence bars.
49
+ 3. Surface token-level attention visualizations by extending the pipeline to emit decoder attention maps (hooks already exist in the decoder).
50
+ - Documentation and code paths were structured to keep the Gradio integration isolated in a future `src/ui/gradio_app.py` module without altering core logic.
51
+
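+ Step 1 of the roadmap could look roughly like this (a sketch only: `gradio` is not part of the
+ trimmed dependencies, and the pipeline construction mirrors the CLI defaults):
+ ```python
+ import gradio as gr
+
+ from src.inference.factory import create_inference_pipeline
+
+ pipeline, _ = create_inference_pipeline(
+     checkpoint_path="checkpoints/best.pt",
+     labels_path="artifacts/labels.json",
+     tokenizer_config=None,
+     model_config_path="configs/model/base.yaml",
+     device="cpu",
+ )
+
+ def analyse(text: str) -> dict:
+     # batch_predict returns summaries plus emotion/topic predictions for each input text.
+     results = pipeline.batch_predict([text])
+     emotion, topic = results["emotion"][0], results["topic"][0]
+     return {
+         "summary": results["summaries"][0],
+         "emotions": dict(zip(emotion.labels, emotion.scores)),
+         "topic": {topic.label: topic.confidence},
+     }
+
+ gr.Interface(fn=analyse, inputs=gr.Textbox(lines=8), outputs=gr.JSON()).launch()
+ ```
+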
52
+ ## Key Decisions
53
+ - **Custom Transformer Preservation** – all modeling remains on the bespoke encoder/decoder, satisfying the constraint to avoid Hugging Face model classes while still leveraging their tokenizer implementation.
54
+ - **Tokenizer Artifact Preference** – inference automatically favors the exported tokenizer in `artifacts/hf_tokenizer`, guaranteeing consistent vocabularies between training and serving.
55
+ - **Sklearn-friendly Preprocessing** – the text preprocessor now accepts an optional
56
+ `TransformerMixin` so additional normalization (lemmatization, custom token filters, etc.) can be injected using familiar scikit-learn tooling without rewriting the batching code.
57
+ - **Documentation Alignment** – the `docs/` folder mirrors the structure requested, capturing design reasoning and paving the way for future diagrams in `docs/images`.
docs/training.md CHANGED
@@ -0,0 +1,59 @@
1
+ # Training Procedure
2
+
3
+ ## Data Sources
4
+ - **Summarization** – expects JSONL files with `source` and `summary` fields under
5
+ `data/processed/summarization`.
6
+ - **Emotion Classification** – multi-label samples loaded from JSONL files with
7
+ `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding.
8
+ - **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`.
9
+
10
+ Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`facebook/bart-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`).
11
+
12
+ ## Dataloaders & Collators
13
+ - `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation.
14
+ - `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`.
15
+ - `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`.
16
+
17
+ These collators keep all tokenization centralized, reducing duplication and making it easy to swap in additional sklearn transformations through `TextPreprocessor` should we wish to extend cleaning or normalization.
18
+
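+ To make the `-100` convention concrete, padded target positions drop out of the summarization loss
+ because PyTorch's cross-entropy ignores that index by default (standard PyTorch, not project code):
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ vocab_size = 10
+ logits = torch.randn(1, 4, vocab_size)       # (batch, seq_len, vocab)
+ labels = torch.tensor([[3, 7, -100, -100]])  # last two positions are padding
+
+ # F.cross_entropy uses ignore_index=-100 by default, so only the two real tokens contribute.
+ loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
+ print(loss)
+ ```
+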
19
+ ## Model Assembly
20
+ - `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments.
21
+ - The model wraps:
22
+ - Transformer encoder/decoder stacks with shared positional encodings.
23
+ - LM head tied to decoder embeddings for summarization.
24
+ - Mean-pooled classification heads for emotion and topic tasks.
25
+
26
+ ## Optimisation Loop
27
+ - `src/training/trainer.Trainer` orchestrates multi-task training.
28
+ - Cross-entropy is used for summarization (seq2seq logits vs. shifted labels).
29
+ - `BCEWithLogitsLoss` handles multi-label emotions.
30
+ - `CrossEntropyLoss` handles topic classification.
31
+ - Gradient clipping ensures stability, and per-task weights can be configured via
32
+ `TrainerConfig.task_weights` to balance gradients if needed.
33
+ - Metrics tracked per task:
34
+ - **Summarization** – ROUGE-like overlap metric (`training.metrics.rouge_like`).
35
+ - **Emotion** – micro F1 score for multi-label predictions.
36
+ - **Topic** – categorical accuracy.
37
+
38
+ ## Checkpoints & Artifacts
39
+ - `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`.
40
+ - `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after
41
+ training. This file is required for inference so class indices map back to human-readable labels.
42
+ - The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies.
43
+
44
+ ## Running Training
45
+ 1. Ensure processed datasets are available (see `data/processed/` structure).
46
+ 2. Choose a configuration (e.g., `configs/training/default.yaml`) for hyperparameters and data splits.
47
+ 3. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders.
48
+ 4. Use `build_multitask_model` to construct the model, create an optimizer, and run
49
+ `Trainer.fit(train_loaders, val_loaders)`.
50
+ 5. Save checkpoints and update `artifacts/labels.json` with the dataset label order.
51
+
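+ A compressed sketch of steps 3-4 (assumptions flagged inline; in particular the `Trainer`
+ construction is illustrative, so check `src/training/trainer.py` for the real signature):
+ ```python
+ import torch
+
+ from src.data.tokenization import Tokenizer, TokenizerConfig
+ from src.models.factory import build_multitask_model, load_model_config
+ from src.training.trainer import Trainer
+
+ tokenizer = Tokenizer(TokenizerConfig(pretrained_model_name="facebook/bart-base", max_length=512))
+
+ model = build_multitask_model(
+     tokenizer,
+     num_emotions=6,  # must match the emotion vocabulary recorded in artifacts/labels.json
+     num_topics=4,    # must match the topic vocabulary
+     config=load_model_config("configs/model/base.yaml"),
+ )
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
+
+ train_loaders: dict = {}  # task name -> DataLoader, built with the collators described above
+ val_loaders: dict = {}
+
+ trainer = Trainer(model, optimizer, tokenizer)  # constructor arguments are hypothetical
+ trainer.fit(train_loaders, val_loaders)
+ ```
+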
52
+ > **Note:** A full CLI for training is forthcoming. The scripts in `scripts/` currently act as
53
+ > scaffolding; once the Gradio UI is introduced we will extend these utilities to launch
54
+ > training jobs with configuration files directly.
55
+
56
+ ## Future Enhancements
57
+ - Integrate curriculum scheduling or task-balanced sampling once empirical results dictate.
58
+ - Capture attention maps during training to support visualization in the planned Gradio UI.
59
+ - Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it.
pyproject.toml CHANGED
@@ -13,19 +13,14 @@ license = {text = "GPL-3.0"}
13
 
14
  dependencies = [
15
  "torch>=2.0.0",
16
- "transformers>=4.30.0",
17
- "datasets>=2.14.0",
18
- "tokenizers>=0.13.0",
19
  "numpy>=1.24.0",
20
  "pandas>=2.0.0",
21
- "scikit-learn>=1.3.0",
22
- "matplotlib>=3.7.0",
23
- "seaborn>=0.12.0",
24
- "tqdm>=4.65.0",
25
- "pyyaml>=6.0",
26
- "omegaconf>=2.3.0",
27
- "tensorboard>=2.13.0",
28
- "gradio>=3.35.0",
29
  ]
30
 
31
  [project.optional-dependencies]
 
13
 
14
  dependencies = [
15
  "torch>=2.0.0",
16
+ "scikit-learn>=1.4.0",
 
 
17
  "numpy>=1.24.0",
18
  "pandas>=2.0.0",
19
+ "streamlit>=1.25.0",
20
+ "plotly>=5.18.0",
21
+ "transformers>=4.40.0",
22
+ "fastapi>=0.110.0",
23
+ "datasets>=4.4.0",
 
 
 
24
  ]
25
 
26
  [project.optional-dependencies]
requirements.txt CHANGED
@@ -1,22 +1,12 @@
1
  # requirements.txt
2
  torch>=2.0.0
3
- transformers>=4.30.0
4
- datasets>=2.14.0
5
- tokenizers>=0.13.0
6
  numpy>=1.24.0
7
  pandas>=2.0.0
8
- scikit-learn>=1.3.0
9
- matplotlib>=3.7.0
10
- seaborn>=0.12.0
11
- nltk>=3.8.0
12
- tqdm>=4.65.0
13
- pyyaml>=6.0
14
- omegaconf>=2.3.0
15
- tensorboard>=2.13.0
16
- gradio>=3.35.0
17
- requests>=2.31.0
18
- kaggle>=1.5.12
19
  streamlit>=1.25.0
20
  plotly>=5.18.0
21
- faiss-cpu==1.9.0; platform_system != "Windows"
22
- faiss-cpu==1.9.0; platform_system == "Windows"
 
 
 
1
  # requirements.txt
2
  torch>=2.0.0
3
+ transformers>=4.40.0
4
+ scikit-learn>=1.4.0
 
5
  numpy>=1.24.0
6
  pandas>=2.0.0
7
  streamlit>=1.25.0
8
  plotly>=5.18.0
9
+ fastapi>=0.110.0
10
+ datasets>=4.4.0
11
+ pytest
12
+ matplotlib
scripts/download_data.py ADDED
@@ -0,0 +1,182 @@
1
+ """Download datasets used by LexiMind."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Iterable, Iterator, cast
9
+
10
+ from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
11
+
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+ if str(PROJECT_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(PROJECT_ROOT))
16
+
17
+ from src.data.download import gutenberg_download, kaggle_download
18
+ from src.utils.config import load_yaml
19
+
20
+
21
+ DEFAULT_SUMMARIZATION_DATASET = "gowrishankarp/newspaper-text-summarization-cnn-dailymail"
22
+ DEFAULT_EMOTION_DATASET = "dair-ai/emotion"
23
+ DEFAULT_TOPIC_DATASET = "ag_news"
24
+ DEFAULT_BOOK_URL = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt"
25
+ DEFAULT_BOOK_OUTPUT = "data/raw/books/pride_and_prejudice.txt"
26
+
27
+
28
+ def parse_args() -> argparse.Namespace:
29
+ parser = argparse.ArgumentParser(description="Download datasets required for LexiMind training")
30
+ parser.add_argument(
31
+ "--config",
32
+ default="configs/data/datasets.yaml",
33
+ help="Path to the dataset configuration YAML.",
34
+ )
35
+ parser.add_argument("--skip-kaggle", action="store_true", help="Skip downloading the Kaggle summarization dataset.")
36
+ parser.add_argument("--skip-book", action="store_true", help="Skip downloading Gutenberg book texts.")
37
+ return parser.parse_args()
38
+
39
+
40
+ def _safe_load_config(path: str | None) -> dict:
41
+ if not path:
42
+ return {}
43
+ config_path = Path(path)
44
+ if not config_path.exists():
45
+ raise FileNotFoundError(f"Config file not found: {config_path}")
46
+ return load_yaml(str(config_path)).data
47
+
48
+
49
+ def _write_jsonl(records: Iterable[dict[str, object]], destination: Path) -> None:
50
+ destination.parent.mkdir(parents=True, exist_ok=True)
51
+ with destination.open("w", encoding="utf-8") as handle:
52
+ for record in records:
53
+ handle.write(json.dumps(record, ensure_ascii=False) + "\n")
54
+
55
+
56
+ def _emotion_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
57
+ for item in dataset_split:
58
+ data = dict(item)
59
+ text = data.get("text", "")
60
+ label_value = data.get("label")
61
+ def resolve_label(index: object) -> str:
62
+ if isinstance(index, int) and label_names and 0 <= index < len(label_names):
63
+ return label_names[index]
64
+ return str(index)
65
+
66
+ if isinstance(label_value, list):
67
+ labels = [resolve_label(idx) for idx in label_value]
68
+ else:
69
+ labels = [resolve_label(label_value)]
70
+ yield {"text": text, "emotions": labels}
71
+
72
+
73
+ def _topic_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
74
+ for item in dataset_split:
75
+ data = dict(item)
76
+ text = data.get("text") or data.get("content") or ""
77
+ label_value = data.get("label")
78
+ def resolve_topic(raw: object) -> str:
79
+ if label_names:
80
+ idx: int | None = None
81
+ if isinstance(raw, int):
82
+ idx = raw
83
+ elif isinstance(raw, str):
84
+ try:
85
+ idx = int(raw)
86
+ except ValueError:
87
+ idx = None
88
+ if idx is not None and 0 <= idx < len(label_names):
89
+ return label_names[idx]
90
+ return str(raw) if raw is not None else ""
91
+
92
+ if isinstance(label_value, list):
93
+ topic = resolve_topic(label_value[0]) if label_value else ""
94
+ else:
95
+ topic = resolve_topic(label_value)
96
+ yield {"text": text, "topic": topic}
97
+
98
+
99
+ def main() -> None:
100
+ args = parse_args()
101
+ config = _safe_load_config(args.config)
102
+
103
+ raw_paths = config.get("raw", {}) if isinstance(config, dict) else {}
104
+ downloads_cfg = config.get("downloads", {}) if isinstance(config, dict) else {}
105
+
106
+ summarization_cfg = downloads_cfg.get("summarization", {}) if isinstance(downloads_cfg, dict) else {}
107
+ summarization_dataset = summarization_cfg.get("dataset", DEFAULT_SUMMARIZATION_DATASET)
108
+ summarization_output = summarization_cfg.get("output", raw_paths.get("summarization", "data/raw/summarization"))
109
+
110
+ if not args.skip_kaggle and summarization_dataset:
111
+ print(f"Downloading summarization dataset '{summarization_dataset}' -> {summarization_output}")
112
+ kaggle_download(summarization_dataset, summarization_output)
113
+ else:
114
+ print("Skipping Kaggle summarization download.")
115
+
116
+ books_root = Path(raw_paths.get("books", "data/raw/books"))
117
+ books_root.mkdir(parents=True, exist_ok=True)
118
+
119
+ books_entries: list[dict[str, object]] = []
120
+ if isinstance(downloads_cfg, dict):
121
+ raw_entries = downloads_cfg.get("books")
122
+ if isinstance(raw_entries, list):
123
+ books_entries = [entry for entry in raw_entries if isinstance(entry, dict)]
124
+
125
+ if not args.skip_book:
126
+ if not books_entries:
127
+ books_entries = [
128
+ {
129
+ "name": "pride_and_prejudice",
130
+ "url": DEFAULT_BOOK_URL,
131
+ "output": DEFAULT_BOOK_OUTPUT,
132
+ }
133
+ ]
134
+ for entry in books_entries:
135
+ name = str(entry.get("name") or "gutenberg_text")
136
+ url = str(entry.get("url") or DEFAULT_BOOK_URL)
137
+ output_value = entry.get("output")
138
+ destination = Path(output_value) if isinstance(output_value, str) and output_value else books_root / f"{name}.txt"
139
+ destination.parent.mkdir(parents=True, exist_ok=True)
140
+ print(f"Downloading Gutenberg text '{name}' from {url} -> {destination}")
141
+ gutenberg_download(url, str(destination))
142
+ else:
143
+ print("Skipping Gutenberg downloads.")
144
+ emotion_cfg = downloads_cfg.get("emotion", {}) if isinstance(downloads_cfg, dict) else {}
145
+ emotion_name = emotion_cfg.get("dataset", DEFAULT_EMOTION_DATASET)
146
+ emotion_dir = Path(raw_paths.get("emotion", "data/raw/emotion"))
147
+ emotion_dir.mkdir(parents=True, exist_ok=True)
148
+ print(f"Downloading emotion dataset '{emotion_name}' -> {emotion_dir}")
149
+ emotion_dataset = cast(DatasetDict, load_dataset(emotion_name))
150
+ first_emotion_key = next(iter(emotion_dataset.keys()), None) if emotion_dataset else None
151
+ emotion_label_feature = (
152
+ emotion_dataset[first_emotion_key].features.get("label")
153
+ if first_emotion_key is not None
154
+ else None
155
+ )
156
+ emotion_label_names = emotion_label_feature.names if isinstance(emotion_label_feature, ClassLabel) else None
157
+ for split_name, split in emotion_dataset.items():
158
+ output_path = emotion_dir / f"{str(split_name)}.jsonl"
159
+ _write_jsonl(_emotion_records(split, emotion_label_names), output_path)
160
+
161
+ topic_cfg = downloads_cfg.get("topic", {}) if isinstance(downloads_cfg, dict) else {}
162
+ topic_name = topic_cfg.get("dataset", DEFAULT_TOPIC_DATASET)
163
+ topic_dir = Path(raw_paths.get("topic", "data/raw/topic"))
164
+ topic_dir.mkdir(parents=True, exist_ok=True)
165
+ print(f"Downloading topic dataset '{topic_name}' -> {topic_dir}")
166
+ topic_dataset = cast(DatasetDict, load_dataset(topic_name))
167
+ first_topic_key = next(iter(topic_dataset.keys()), None) if topic_dataset else None
168
+ topic_label_feature = (
169
+ topic_dataset[first_topic_key].features.get("label")
170
+ if first_topic_key is not None
171
+ else None
172
+ )
173
+ topic_label_names = topic_label_feature.names if isinstance(topic_label_feature, ClassLabel) else None
174
+ for split_name, split in topic_dataset.items():
175
+ output_path = topic_dir / f"{str(split_name)}.jsonl"
176
+ _write_jsonl(_topic_records(split, topic_label_names), output_path)
177
+
178
+ print("Download routine finished.")
179
+
180
+
181
+ if __name__ == "__main__":
182
+ main()
scripts/download_data.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ python3 "${SCRIPT_DIR}/download_data.py"
scripts/evaluate.py ADDED
@@ -0,0 +1,134 @@
1
+ """Evaluate the multitask model on processed validation/test splits."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import List
9
+
10
+ import torch
11
+ from sklearn.preprocessing import MultiLabelBinarizer
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+ if str(PROJECT_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(PROJECT_ROOT))
16
+
17
+ from src.data.dataset import (
18
+ load_emotion_jsonl,
19
+ load_summarization_jsonl,
20
+ load_topic_jsonl,
21
+ )
22
+ from src.inference.factory import create_inference_pipeline
23
+ from src.training.metrics import accuracy, multilabel_f1, rouge_like
24
+ from src.utils.config import load_yaml
25
+
26
+
27
+ SPLIT_ALIASES = {
28
+ "train": ("train",),
29
+ "val": ("val", "validation"),
30
+ "test": ("test",),
31
+ }
32
+
33
+
34
+ def _read_split(root: Path, split: str, loader) -> list:
35
+ aliases = SPLIT_ALIASES.get(split, (split,))
36
+ for alias in aliases:
37
+ for ext in ("jsonl", "json"):
38
+ candidate = root / f"{alias}.{ext}"
39
+ if candidate.exists():
40
+ return loader(str(candidate))
41
+ raise FileNotFoundError(f"Missing {split} split under {root}")
42
+
43
+
44
+ def parse_args() -> argparse.Namespace:
45
+ parser = argparse.ArgumentParser(description="Evaluate the LexiMind multitask model")
46
+ parser.add_argument("--split", default="val", choices=["train", "val", "test"], help="Dataset split to evaluate.")
47
+ parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.")
48
+ parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON.")
49
+ parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration YAML.")
50
+ parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture YAML.")
51
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device for evaluation.")
52
+ parser.add_argument("--batch-size", type=int, default=16, help="Batch size for generation/classification during evaluation.")
53
+ return parser.parse_args()
54
+
55
+
56
+ def chunks(items: List, size: int):
57
+ for start in range(0, len(items), size):
58
+ yield items[start : start + size]
59
+
60
+
61
+ def main() -> None:
62
+ args = parse_args()
63
+ data_cfg = load_yaml(args.data_config).data
64
+
65
+ pipeline, metadata = create_inference_pipeline(
66
+ checkpoint_path=args.checkpoint,
67
+ labels_path=args.labels,
68
+ tokenizer_config=None,
69
+ model_config_path=args.model_config,
70
+ device=args.device,
71
+ )
72
+
73
+ summarization_dir = Path(data_cfg["processed"]["summarization"])
74
+ emotion_dir = Path(data_cfg["processed"]["emotion"])
75
+ topic_dir = Path(data_cfg["processed"]["topic"])
76
+
77
+ summary_examples = _read_split(summarization_dir, args.split, load_summarization_jsonl)
78
+ emotion_examples = _read_split(emotion_dir, args.split, load_emotion_jsonl)
79
+ topic_examples = _read_split(topic_dir, args.split, load_topic_jsonl)
80
+
81
+ emotion_binarizer = MultiLabelBinarizer(classes=metadata.emotion)
82
+ # Ensure scikit-learn initializes the attributes using metadata ordering.
83
+ emotion_binarizer.fit([[label] for label in metadata.emotion])
84
+
85
+ # Summarization
86
+ summaries_pred = []
87
+ summaries_ref = []
88
+ for batch in chunks(summary_examples, args.batch_size):
89
+ inputs = [example.source for example in batch]
90
+ summaries_pred.extend(pipeline.summarize(inputs))
91
+ summaries_ref.extend([example.summary for example in batch])
92
+ rouge_score = rouge_like(summaries_pred, summaries_ref)
93
+
94
+ # Emotion
95
+ emotion_preds_tensor = []
96
+ emotion_target_tensor = []
97
+ label_to_index = {label: idx for idx, label in enumerate(metadata.emotion)}
98
+ for batch in chunks(emotion_examples, args.batch_size):
99
+ inputs = [example.text for example in batch]
100
+ predictions = pipeline.predict_emotions(inputs)
101
+ target_matrix = emotion_binarizer.transform([list(example.emotions) for example in batch])
102
+ for pred, target_row in zip(predictions, target_matrix):
103
+ vector = torch.zeros(len(metadata.emotion), dtype=torch.float32)
104
+ for label in pred.labels:
105
+ idx = label_to_index.get(label)
106
+ if idx is not None:
107
+ vector[idx] = 1.0
108
+ emotion_preds_tensor.append(vector)
109
+ emotion_target_tensor.append(torch.tensor(target_row, dtype=torch.float32))
110
+ emotion_f1 = multilabel_f1(torch.stack(emotion_preds_tensor), torch.stack(emotion_target_tensor))
111
+
112
+ # Topic
113
+ topic_preds = []
114
+ topic_targets = []
115
+ for batch in chunks(topic_examples, args.batch_size):
116
+ inputs = [example.text for example in batch]
117
+ predictions = pipeline.predict_topics(inputs)
118
+ topic_preds.extend([pred.label for pred in predictions])
119
+ topic_targets.extend([example.topic for example in batch])
120
+ topic_accuracy = accuracy(topic_preds, topic_targets)
121
+
122
+ print(json.dumps(
123
+ {
124
+ "split": args.split,
125
+ "rouge_like": rouge_score,
126
+ "emotion_f1": emotion_f1,
127
+ "topic_accuracy": topic_accuracy,
128
+ },
129
+ indent=2,
130
+ ))
131
+
132
+
133
+ if __name__ == "__main__":
134
+ main()
scripts/export_model.py ADDED
@@ -0,0 +1,69 @@
1
+ """Rebuild and export the trained multitask model for downstream use."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ from src.data.tokenization import Tokenizer, TokenizerConfig
10
+ from src.models.factory import build_multitask_model, load_model_config
11
+ from src.utils.config import load_yaml
12
+ from src.utils.labels import load_label_metadata
13
+
14
+
15
+ def parse_args() -> argparse.Namespace:
16
+ parser = argparse.ArgumentParser(description="Export LexiMind model weights")
17
+ parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.")
18
+ parser.add_argument("--output", default="outputs/model.pt", help="Output path for the exported state dict.")
19
+ parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON produced after training.")
20
+ parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture configuration.")
21
+ parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration (for tokenizer settings).")
22
+ return parser.parse_args()
23
+
24
+
25
+ def main() -> None:
26
+ """Export multitask model weights from a training checkpoint to a standalone state dict."""
27
+ args = parse_args()
28
+
29
+ checkpoint = Path(args.checkpoint)
30
+ if not checkpoint.exists():
31
+ raise FileNotFoundError(checkpoint)
32
+
33
+ labels = load_label_metadata(args.labels)
34
+ data_cfg = load_yaml(args.data_config).data
35
+ tokenizer_section = data_cfg.get("tokenizer", {})
36
+ tokenizer_config = TokenizerConfig(
37
+ pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"),
38
+ max_length=int(tokenizer_section.get("max_length", 512)),
39
+ lower=bool(tokenizer_section.get("lower", False)),
40
+ )
41
+ tokenizer = Tokenizer(tokenizer_config)
42
+
43
+ model = build_multitask_model(
44
+ tokenizer,
45
+ num_emotions=labels.emotion_size,
46
+ num_topics=labels.topic_size,
47
+ config=load_model_config(args.model_config),
48
+ )
49
+
50
+ raw_state = torch.load(checkpoint, map_location="cpu")
51
+ if isinstance(raw_state, dict):
52
+ if "model_state_dict" in raw_state and isinstance(raw_state["model_state_dict"], dict):
53
+ state_dict = raw_state["model_state_dict"]
54
+ elif "state_dict" in raw_state and isinstance(raw_state["state_dict"], dict):
55
+ state_dict = raw_state["state_dict"]
56
+ else:
57
+ state_dict = raw_state
58
+ else:
59
+ raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}")
60
+ model.load_state_dict(state_dict)
61
+
62
+ output_path = Path(args.output)
63
+ output_path.parent.mkdir(parents=True, exist_ok=True)
64
+ torch.save(model.state_dict(), output_path)
65
+ print(f"Model exported to {output_path}")
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
scripts/inference.py ADDED
@@ -0,0 +1,112 @@
1
+ """Run inference with the multitask model."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ from pathlib import Path
7
+ from typing import List, cast
8
+
9
+ from src.data.tokenization import TokenizerConfig
10
+ from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline
11
+
12
+
13
+ def _load_texts(positional: List[str], file_path: Path | None) -> List[str]:
14
+ texts = [text for text in positional if text]
15
+ if file_path is not None:
16
+ if not file_path.exists():
17
+ raise FileNotFoundError(file_path)
18
+ with file_path.open("r", encoding="utf-8") as handle:
19
+ texts.extend([line.strip() for line in handle if line.strip()])
20
+ if not texts:
21
+ raise ValueError("No input texts provided. Pass text arguments or use --file.")
22
+ return texts
23
+
24
+
25
+ def parse_args() -> argparse.Namespace:
26
+ parser = argparse.ArgumentParser(description="Run LexiMind multitask inference.")
27
+ parser.add_argument("text", nargs="*", help="Input text(s) to analyse.")
28
+ parser.add_argument("--file", type=Path, help="Path to a file containing one text per line.")
29
+ parser.add_argument(
30
+ "--checkpoint",
31
+ type=Path,
32
+ default=Path("checkpoints/best.pt"),
33
+ help="Path to the model checkpoint produced during training.",
34
+ )
35
+ parser.add_argument(
36
+ "--labels",
37
+ type=Path,
38
+ default=Path("artifacts/labels.json"),
39
+ help="JSON file containing emotion/topic label vocabularies.",
40
+ )
41
+ parser.add_argument(
42
+ "--tokenizer",
43
+ type=Path,
44
+ default=None,
45
+ help="Optional path to a tokenizer directory exported during training.",
46
+ )
47
+ parser.add_argument(
48
+ "--model-config",
49
+ type=Path,
50
+ default=Path("configs/model/base.yaml"),
51
+ help="Model architecture config used to rebuild the transformer stack.",
52
+ )
53
+ parser.add_argument("--device", default="cpu", help="Device to run inference on (cpu or cuda).")
54
+ parser.add_argument(
55
+ "--summary-max-length",
56
+ type=int,
57
+ default=None,
58
+ help="Optional maximum length for generated summaries.",
59
+ )
60
+ return parser.parse_args()
61
+
62
+
63
+ def main() -> None:
64
+ args = parse_args()
65
+ texts = _load_texts(args.text, args.file)
66
+
67
+ tokenizer_config = None
68
+ if args.tokenizer is not None:
69
+ tokenizer_config = TokenizerConfig(pretrained_model_name=str(args.tokenizer))
70
+ else:
71
+ local_dir = Path("artifacts/hf_tokenizer")
72
+ if local_dir.exists():
73
+ tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_dir))
74
+
75
+ pipeline, _ = create_inference_pipeline(
76
+ checkpoint_path=args.checkpoint,
77
+ labels_path=args.labels,
78
+ tokenizer_config=tokenizer_config,
79
+ model_config_path=args.model_config,
80
+ device=args.device,
81
+ summary_max_length=args.summary_max_length,
82
+ )
83
+
84
+ results = pipeline.batch_predict(texts)
85
+ summaries = cast(List[str], results["summaries"])
86
+ emotion_preds = cast(List[EmotionPrediction], results["emotion"])
87
+ topic_preds = cast(List[TopicPrediction], results["topic"])
88
+
89
+ packaged = []
90
+ for idx, text in enumerate(texts):
91
+ emotion = emotion_preds[idx]
92
+ topic = topic_preds[idx]
93
+ packaged.append(
94
+ {
95
+ "text": text,
96
+ "summary": summaries[idx],
97
+ "emotion": {
98
+ "labels": emotion.labels,
99
+ "scores": emotion.scores,
100
+ },
101
+ "topic": {
102
+ "label": topic.label,
103
+ "confidence": topic.confidence,
104
+ },
105
+ }
106
+ )
107
+
108
+ print(json.dumps(packaged, indent=2, ensure_ascii=False))
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()
scripts/preprocess_data.py ADDED
@@ -0,0 +1,321 @@
1
+ """Preprocess raw datasets into JSONL splits for LexiMind training."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import csv
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import Dict, Iterable, Iterator, Sequence, Tuple
10
+
11
+ from sklearn.model_selection import train_test_split
12
+
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+ if str(PROJECT_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(PROJECT_ROOT))
16
+
17
+ from src.data.preprocessing import BasicTextCleaner
18
+ from src.utils.config import load_yaml
19
+
20
+
21
+ def parse_args() -> argparse.Namespace:
22
+ parser = argparse.ArgumentParser(description="Preprocess datasets configured for LexiMind")
23
+ parser.add_argument(
24
+ "--config",
25
+ default="configs/data/datasets.yaml",
26
+ help="Path to data configuration YAML.",
27
+ )
28
+ parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation split size for topic dataset when no validation split is present.")
29
+ parser.add_argument("--seed", type=int, default=17, help="Random seed for deterministic splitting.")
30
+ return parser.parse_args()
31
+
32
+
33
+ def _resolve_csv(base: Path, filename: str) -> Path | None:
34
+ primary = base / filename
35
+ if primary.exists():
36
+ return primary
37
+ nested = base / "cnn_dailymail" / filename
38
+ if nested.exists():
39
+ return nested
40
+ return None
41
+
42
+
43
+ def _write_jsonl(records: Iterable[Dict[str, object]], destination: Path) -> None:
44
+ destination.parent.mkdir(parents=True, exist_ok=True)
45
+ with destination.open("w", encoding="utf-8") as handle:
46
+ for record in records:
47
+ handle.write(json.dumps(record, ensure_ascii=False) + "\n")
48
+
49
+
50
+ def _read_jsonl(path: Path) -> Iterator[Dict[str, object]]:
51
+ with path.open("r", encoding="utf-8") as handle:
52
+ for line in handle:
53
+ row = line.strip()
54
+ if not row:
55
+ continue
56
+ yield json.loads(row)
57
+
58
+
59
+ def preprocess_books(
60
+ raw_dir: Path,
61
+ processed_dir: Path,
62
+ cleaner: BasicTextCleaner,
63
+ *,
64
+ min_tokens: int = 30,
65
+ ) -> None:
66
+ if not raw_dir.exists():
67
+ print(f"Skipping book preprocessing (missing directory: {raw_dir})")
68
+ return
69
+
70
+ processed_dir.mkdir(parents=True, exist_ok=True)
71
+ index: list[Dict[str, object]] = []
72
+
73
+ for book_path in sorted(raw_dir.glob("*.txt")):
74
+ text = book_path.read_text(encoding="utf-8").lstrip("\ufeff")
75
+ normalized = text.replace("\r\n", "\n")
76
+ paragraphs = [paragraph.strip() for paragraph in normalized.split("\n\n") if paragraph.strip()]
77
+
78
+ records: list[Dict[str, object]] = []
79
+ for paragraph_id, paragraph in enumerate(paragraphs):
80
+ cleaned = cleaner.transform([paragraph])[0]
81
+ tokens = cleaned.split()
82
+ if len(tokens) < min_tokens:
83
+ continue
84
+ record = {
85
+ "book": book_path.stem,
86
+ "title": book_path.stem.replace("_", " ").title(),
87
+ "paragraph_id": paragraph_id,
88
+ "text": paragraph,
89
+ "clean_text": cleaned,
90
+ "token_count": len(tokens),
91
+ "char_count": len(paragraph),
92
+ }
93
+ records.append(record)
94
+
95
+ if not records:
96
+ print(f"No suitably sized paragraphs found in {book_path}; skipping.")
97
+ continue
98
+
99
+ output_path = processed_dir / f"{book_path.stem}.jsonl"
100
+ print(f"Writing book segments for '{book_path.stem}' to {output_path}")
101
+ _write_jsonl(records, output_path)
102
+ index.append(
103
+ {
104
+ "book": book_path.stem,
105
+ "title": records[0]["title"],
106
+ "paragraphs": len(records),
107
+ "source": str(book_path),
108
+ "output": str(output_path),
109
+ }
110
+ )
111
+
112
+ if index:
113
+ index_path = processed_dir / "index.json"
114
+ with index_path.open("w", encoding="utf-8") as handle:
115
+ json.dump(index, handle, ensure_ascii=False, indent=2)
116
+ print(f"Book index written to {index_path}")
117
+
118
+
119
+ def preprocess_summarization(raw_dir: Path, processed_dir: Path) -> None:
120
+ if not raw_dir.exists():
121
+ print(f"Skipping summarization preprocessing (missing directory: {raw_dir})")
122
+ return
123
+
124
+ for split in ("train", "validation", "test"):
125
+ source_path = _resolve_csv(raw_dir, f"{split}.csv")
126
+ if source_path is None:
127
+ print(f"Skipping summarization split '{split}' (file not found)")
128
+ continue
129
+
130
+ output_path = processed_dir / f"{split}.jsonl"
131
+ output_path.parent.mkdir(parents=True, exist_ok=True)
132
+ print(f"Writing summarization split '{split}' to {output_path}")
133
+ with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open("w", encoding="utf-8") as sink:
134
+ reader = csv.DictReader(source_handle)
135
+ for row in reader:
136
+ article = row.get("article") or row.get("Article") or ""
137
+ highlights = row.get("highlights") or row.get("summary") or ""
138
+ payload = {"source": article.strip(), "summary": highlights.strip()}
139
+ sink.write(json.dumps(payload, ensure_ascii=False) + "\n")
140
+
141
+
142
+ def preprocess_emotion(raw_dir: Path, processed_dir: Path, cleaner: BasicTextCleaner) -> None:
143
+ if not raw_dir.exists():
144
+ print(f"Skipping emotion preprocessing (missing directory: {raw_dir})")
145
+ return
146
+
147
+ split_aliases: Dict[str, Sequence[str]] = {
148
+ "train": ("train",),
149
+ "val": ("val", "validation"),
150
+ "test": ("test",),
151
+ }
152
+
153
+ for split, aliases in split_aliases.items():
154
+ source_path: Path | None = None
155
+ for alias in aliases:
156
+ for extension in ("jsonl", "txt", "csv"):
157
+ candidate = raw_dir / f"{alias}.{extension}"
158
+ if candidate.exists():
159
+ source_path = candidate
160
+ break
161
+ if source_path is not None:
162
+ break
163
+ if source_path is None:
164
+ print(f"Skipping emotion split '{split}' (file not found)")
165
+ continue
166
+
167
+ assert source_path is not None
168
+ path = source_path
169
+
170
+ def iter_records() -> Iterator[Dict[str, object]]:
171
+ if path.suffix == ".jsonl":
172
+ for row in _read_jsonl(path):
173
+ raw_text = str(row.get("text", ""))
174
+ text = cleaner.transform([raw_text])[0]
175
+ labels = row.get("emotions") or row.get("labels") or []
176
+ if isinstance(labels, str):
177
+ labels = [label.strip() for label in labels.split(",") if label.strip()]
178
+ elif isinstance(labels, Sequence):
179
+ labels = [str(label) for label in labels]
180
+ else:
181
+ labels = [str(labels)] if labels else []
182
+ if not labels:
183
+ labels = ["neutral"]
184
+ yield {"text": text, "emotions": labels}
185
+ else:
186
+ delimiter = ";" if path.suffix == ".txt" else ","
187
+ with path.open("r", encoding="utf-8", newline="") as handle:
188
+ reader = csv.reader(handle, delimiter=delimiter)
189
+ for row in reader:
190
+ if not row:
191
+ continue
192
+ raw_text = str(row[0])
193
+ text = cleaner.transform([raw_text])[0]
194
+ raw_labels = row[1] if len(row) > 1 else ""
195
+ labels = [label.strip() for label in raw_labels.split(",") if label.strip()]
196
+ if not labels:
197
+ labels = ["neutral"]
198
+ yield {"text": text, "emotions": labels}
199
+
200
+ output_path = processed_dir / f"{split}.jsonl"
201
+ print(f"Writing emotion split '{split}' to {output_path}")
202
+ _write_jsonl(iter_records(), output_path)
203
+
204
+
205
+ def preprocess_topic(
206
+ raw_dir: Path,
207
+ processed_dir: Path,
208
+ cleaner: BasicTextCleaner,
209
+ val_ratio: float,
210
+ seed: int,
211
+ ) -> None:
212
+ if not raw_dir.exists():
213
+ print(f"Skipping topic preprocessing (missing directory: {raw_dir})")
214
+ return
215
+
216
+ def locate(*names: str) -> Path | None:
217
+ for name in names:
218
+ candidate = raw_dir / name
219
+ if candidate.exists():
220
+ return candidate
221
+ return None
222
+
223
+ train_path = locate("train.jsonl", "train.csv")
224
+ if train_path is None:
225
+ print(f"Skipping topic preprocessing (missing train split in {raw_dir})")
226
+ return
227
+
228
+ assert train_path is not None
229
+
230
+ def load_topic_rows(path: Path) -> list[Tuple[str, str]]:
231
+ rows: list[Tuple[str, str]] = []
232
+ if path.suffix == ".jsonl":
233
+ for record in _read_jsonl(path):
234
+ text = str(record.get("text") or record.get("content") or "")
235
+ topic = record.get("topic") or record.get("label")
236
+ cleaned_text = cleaner.transform([text])[0]
237
+ rows.append((cleaned_text, str(topic).strip()))
238
+ else:
239
+ with path.open("r", encoding="utf-8", newline="") as handle:
240
+ reader = csv.DictReader(handle)
241
+ for row in reader:
242
+ topic = row.get("Class Index") or row.get("topic") or row.get("label")
243
+ title = str(row.get("Title") or "")
244
+ description = str(row.get("Description") or row.get("text") or "")
245
+ text = " ".join(filter(None, (title, description)))
246
+ cleaned_text = cleaner.transform([text])[0]
247
+ rows.append((cleaned_text, str(topic).strip()))
248
+ return rows
249
+
250
+ train_rows = load_topic_rows(train_path)
251
+ if not train_rows:
252
+ print("No topic training rows found; skipping topic preprocessing.")
253
+ return
254
+
255
+ texts = [row[0] for row in train_rows]
256
+ topics = [row[1] for row in train_rows]
257
+
258
+ validation_path = locate("val.jsonl", "validation.jsonl", "val.csv", "validation.csv")
259
+ has_validation = validation_path is not None
260
+
261
+ if has_validation and validation_path:
262
+ val_rows = load_topic_rows(validation_path)
263
+ train_records = train_rows
264
+ else:
265
+ train_texts, val_texts, train_topics, val_topics = train_test_split(
266
+ texts,
267
+ topics,
268
+ test_size=val_ratio,
269
+ random_state=seed,
270
+ stratify=topics,
271
+ )
272
+ train_records = list(zip(train_texts, train_topics))
273
+ val_rows = list(zip(val_texts, val_topics))
274
+
275
+ def to_records(pairs: Sequence[Tuple[str, str]]) -> Iterator[Dict[str, object]]:
276
+ for text, topic in pairs:
277
+ yield {"text": text, "topic": topic}
278
+
279
+ print(f"Writing topic train split to {processed_dir / 'train.jsonl'}")
280
+ _write_jsonl(to_records(train_records), processed_dir / "train.jsonl")
281
+ print(f"Writing topic val split to {processed_dir / 'val.jsonl'}")
282
+ _write_jsonl(to_records(val_rows), processed_dir / "val.jsonl")
283
+
284
+ test_path = locate("test.jsonl", "test.csv")
285
+ if test_path is not None:
286
+ test_rows = load_topic_rows(test_path)
287
+ print(f"Writing topic test split to {processed_dir / 'test.jsonl'}")
288
+ _write_jsonl(to_records(test_rows), processed_dir / "test.jsonl")
289
+ else:
290
+ print(f"Skipping topic test split (missing test split in {raw_dir})")
291
+
292
+
293
+ def main() -> None:
294
+ args = parse_args()
295
+ config = load_yaml(args.config).data
296
+
297
+ raw_cfg = config.get("raw", {})
298
+ processed_cfg = config.get("processed", {})
299
+
300
+ books_raw = Path(raw_cfg.get("books", "data/raw/books"))
301
+ summarization_raw = Path(raw_cfg.get("summarization", "data/raw/summarization"))
302
+ emotion_raw = Path(raw_cfg.get("emotion", "data/raw/emotion"))
303
+ topic_raw = Path(raw_cfg.get("topic", "data/raw/topic"))
304
+
305
+ books_processed = Path(processed_cfg.get("books", "data/processed/books"))
306
+ summarization_processed = Path(processed_cfg.get("summarization", "data/processed/summarization"))
307
+ emotion_processed = Path(processed_cfg.get("emotion", "data/processed/emotion"))
308
+ topic_processed = Path(processed_cfg.get("topic", "data/processed/topic"))
309
+
310
+ cleaner = BasicTextCleaner()
311
+
312
+ preprocess_books(books_raw, books_processed, cleaner)
313
+ preprocess_summarization(summarization_raw, summarization_processed)
314
+ preprocess_emotion(emotion_raw, emotion_processed, cleaner)
315
+ preprocess_topic(topic_raw, topic_processed, cleaner, val_ratio=args.val_ratio, seed=args.seed)
316
+
317
+ print("Preprocessing complete.")
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()
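A minimal usage sketch for the new preprocessing helpers, assuming the repository root is on `PYTHONPATH` and the default `data/raw/...` layout from `configs/data/datasets.yaml`:

```python
# Sketch: run two of the preprocessing steps directly (paths are the repo defaults).
from pathlib import Path

from scripts.preprocess_data import preprocess_emotion, preprocess_summarization
from src.data.preprocessing import BasicTextCleaner

cleaner = BasicTextCleaner()
preprocess_summarization(Path("data/raw/summarization"), Path("data/processed/summarization"))
preprocess_emotion(Path("data/raw/emotion"), Path("data/processed/emotion"), cleaner)
```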
scripts/test_gpu.py DELETED
@@ -1,27 +0,0 @@
1
- # test_gpu.py
2
- import torch
3
-
4
- print("=" * 50)
5
- print("GPU Information")
6
- print("=" * 50)
7
-
8
- if torch.cuda.is_available():
9
- gpu_name = torch.cuda.get_device_name(0)
10
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
11
-
12
- print(f"✅ GPU: {gpu_name}")
13
- print(f"✅ Memory: {gpu_memory:.2f} GB")
14
-
15
- # Test tensor creation
16
- x = torch.randn(1000, 1000, device='cuda')
17
- y = torch.randn(1000, 1000, device='cuda')
18
- z = x @ y
19
-
20
- print(f"✅ CUDA operations working!")
21
- print(f"✅ Current memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
22
- print(f"✅ Max memory allocated: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
23
- else:
24
- print("❌ CUDA not available!")
25
- print("Using CPU - training will be slow!")
26
-
27
- print("=" * 50)
scripts/train.py CHANGED
@@ -1,8 +1,219 @@
1
- # scripts/train.py
2
- from src.training.trainer import Trainer
3
- from src.utils.config import load_config
4
 
5
  if __name__ == "__main__":
6
- config = load_config("configs/training/default.yaml")
7
- trainer = Trainer(config)
8
- trainer.train()
 
1
+ """End-to-end training entrypoint for the LexiMind multitask model."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Dict, Sequence
9
+
10
+ import torch
11
+
12
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
13
+ if str(PROJECT_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(PROJECT_ROOT))
15
+
16
+ from src.data.dataloader import (
17
+ build_emotion_dataloader,
18
+ build_summarization_dataloader,
19
+ build_topic_dataloader,
20
+ )
21
+ from src.data.dataset import (
22
+ EmotionDataset,
23
+ SummarizationDataset,
24
+ TopicDataset,
25
+ load_emotion_jsonl,
26
+ load_summarization_jsonl,
27
+ load_topic_jsonl,
28
+ )
29
+ from src.data.tokenization import Tokenizer, TokenizerConfig
30
+ from src.models.factory import build_multitask_model, load_model_config
31
+ from src.training.trainer import Trainer, TrainerConfig
32
+ from src.training.utils import set_seed
33
+ from src.utils.config import load_yaml
34
+ from src.utils.io import save_state
35
+ from src.utils.labels import LabelMetadata, save_label_metadata
36
+
37
+
38
+ SplitExamples = Dict[str, list]
39
+
40
+
41
+ SPLIT_ALIASES: Dict[str, Sequence[str]] = {
42
+ "train": ("train",),
43
+ "val": ("val", "validation"),
44
+ "test": ("test",),
45
+ }
46
+
47
+
48
+ def _read_examples(data_dir: Path, loader) -> SplitExamples:
49
+ splits: SplitExamples = {}
50
+ for canonical, aliases in SPLIT_ALIASES.items():
51
+ found = False
52
+ for alias in aliases:
53
+ for extension in ("jsonl", "json"):
54
+ candidate = data_dir / f"{alias}.{extension}"
55
+ if candidate.exists():
56
+ splits[canonical] = loader(str(candidate))
57
+ found = True
58
+ break
59
+ if found:
60
+ break
61
+ if not found:
62
+ raise FileNotFoundError(f"Missing {canonical} split under {data_dir}")
63
+ return splits
64
+
65
+
66
+ def parse_args() -> argparse.Namespace:
67
+ parser = argparse.ArgumentParser(description="Train the LexiMind multitask transformer")
68
+ parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Path to data configuration YAML.")
69
+ parser.add_argument("--training-config", default="configs/training/default.yaml", help="Path to training hyperparameter YAML.")
70
+ parser.add_argument("--model-config", default="configs/model/base.yaml", help="Path to model architecture YAML.")
71
+ parser.add_argument("--checkpoint-out", default="checkpoints/best.pt", help="Where to store the trained checkpoint.")
72
+ parser.add_argument("--labels-out", default="artifacts/labels.json", help="Where to persist label vocabularies.")
73
+ parser.add_argument("--history-out", default="outputs/training_history.json", help="Where to write training history.")
74
+ parser.add_argument("--device", default="cpu", help="Training device identifier (cpu or cuda).")
75
+ parser.add_argument("--seed", type=int, default=17, help="Random seed for reproducibility.")
76
+ return parser.parse_args()
77
+
78
+
79
+ def main() -> None:
80
+ args = parse_args()
81
+ set_seed(args.seed)
82
+
83
+ data_cfg = load_yaml(args.data_config).data
84
+ training_cfg = load_yaml(args.training_config).data
85
+ model_cfg = load_model_config(args.model_config)
86
+
87
+ summarization_dir = Path(data_cfg["processed"]["summarization"])
88
+ emotion_dir = Path(data_cfg["processed"]["emotion"])
89
+ topic_dir = Path(data_cfg["processed"]["topic"])
90
+
91
+ summarization_splits = _read_examples(summarization_dir, load_summarization_jsonl)
92
+ emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
93
+ topic_splits = _read_examples(topic_dir, load_topic_jsonl)
94
+
95
+ tokenizer_section = data_cfg.get("tokenizer", {})
96
+ tokenizer_config = TokenizerConfig(
97
+ pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"),
98
+ max_length=int(tokenizer_section.get("max_length", 512)),
99
+ lower=bool(tokenizer_section.get("lower", False)),
100
+ )
101
+ tokenizer = Tokenizer(tokenizer_config)
102
+
103
+ summarization_train = SummarizationDataset(summarization_splits["train"])
104
+ summarization_val = SummarizationDataset(summarization_splits["val"])
105
+
106
+ emotion_train = EmotionDataset(emotion_splits["train"])
107
+ emotion_val = EmotionDataset(emotion_splits["val"], binarizer=emotion_train.binarizer)
108
+
109
+ topic_train = TopicDataset(topic_splits["train"])
110
+ topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
111
+
112
+ dataloader_args = training_cfg.get("dataloader", {})
113
+ batch_size = int(dataloader_args.get("batch_size", 8))
114
+ shuffle = bool(dataloader_args.get("shuffle", True))
115
+ max_length = tokenizer.config.max_length
116
+
117
+ train_loaders = {
118
+ "summarization": build_summarization_dataloader(
119
+ summarization_train,
120
+ tokenizer,
121
+ batch_size=batch_size,
122
+ shuffle=shuffle,
123
+ max_source_length=max_length,
124
+ max_target_length=max_length,
125
+ ),
126
+ "emotion": build_emotion_dataloader(
127
+ emotion_train,
128
+ tokenizer,
129
+ batch_size=batch_size,
130
+ shuffle=shuffle,
131
+ max_length=max_length,
132
+ ),
133
+ "topic": build_topic_dataloader(
134
+ topic_train,
135
+ tokenizer,
136
+ batch_size=batch_size,
137
+ shuffle=shuffle,
138
+ max_length=max_length,
139
+ ),
140
+ }
141
+
142
+ val_loaders = {
143
+ "summarization": build_summarization_dataloader(
144
+ summarization_val,
145
+ tokenizer,
146
+ batch_size=batch_size,
147
+ shuffle=False,
148
+ max_source_length=max_length,
149
+ max_target_length=max_length,
150
+ ),
151
+ "emotion": build_emotion_dataloader(
152
+ emotion_val,
153
+ tokenizer,
154
+ batch_size=batch_size,
155
+ shuffle=False,
156
+ max_length=max_length,
157
+ ),
158
+ "topic": build_topic_dataloader(
159
+ topic_val,
160
+ tokenizer,
161
+ batch_size=batch_size,
162
+ shuffle=False,
163
+ max_length=max_length,
164
+ ),
165
+ }
166
+
167
+ device = torch.device(args.device)
168
+ model = build_multitask_model(
169
+ tokenizer,
170
+ num_emotions=len(emotion_train.emotion_classes),
171
+ num_topics=len(topic_train.topic_classes),
172
+ config=model_cfg,
173
+ ).to(device)
174
+
175
+ optimizer_cfg = training_cfg.get("optimizer", {})
176
+ lr = float(optimizer_cfg.get("lr", 3.0e-5))
177
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
178
+
179
+ trainer_cfg = training_cfg.get("trainer", {})
180
+ trainer = Trainer(
181
+ model=model,
182
+ optimizer=optimizer,
183
+ config=TrainerConfig(
184
+ max_epochs=int(trainer_cfg.get("max_epochs", 1)),
185
+ gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
186
+ logging_interval=int(trainer_cfg.get("logging_interval", 50)),
187
+ task_weights=trainer_cfg.get("task_weights"),
188
+ ),
189
+ device=device,
190
+ tokenizer=tokenizer,
191
+ )
192
+
193
+ history = trainer.fit(train_loaders, val_loaders)
194
+
195
+ checkpoint_path = Path(args.checkpoint_out)
196
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
197
+ save_state(model, str(checkpoint_path))
198
+
199
+ labels_path = Path(args.labels_out)
200
+ save_label_metadata(
201
+ LabelMetadata(
202
+ emotion=emotion_train.emotion_classes,
203
+ topic=topic_train.topic_classes,
204
+ ),
205
+ labels_path,
206
+ )
207
+
208
+ history_path = Path(args.history_out)
209
+ history_path.parent.mkdir(parents=True, exist_ok=True)
210
+ with history_path.open("w", encoding="utf-8") as handle:
211
+ json.dump(history, handle, indent=2)
212
+
213
+ print(f"Training complete. Checkpoint saved to {checkpoint_path}")
214
+ print(f"Label metadata saved to {labels_path}")
215
+ print(f"History saved to {history_path}")
216
+
217
 
218
  if __name__ == "__main__":
219
+ main()
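A smoke-test sketch for the new entrypoint, driving `main()` programmatically; the quick-test config filename and output paths are assumptions and may need adjusting:

```python
# Sketch: invoke scripts/train.py for a short CPU run using the quick_test config.
import sys

from scripts.train import main

sys.argv = [
    "train.py",
    "--training-config", "configs/training/quick_test.yaml",
    "--device", "cpu",
    "--checkpoint-out", "checkpoints/smoke.pt",
    "--history-out", "outputs/smoke_history.json",
]
main()
```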
 
 
setup.py CHANGED
@@ -7,13 +7,23 @@ setup(
7
  package_dir={"": "src"},
8
  install_requires=[
9
  "torch>=2.0.0",
10
- "transformers>=4.30.0",
11
- # ... (or read from requirements.txt)
 
 
12
  ],
13
- entry_points={
14
- "console_scripts": [
15
- "leximind-train=scripts.train:main",
16
- "leximind-infer=scripts.inference:main",
 
 
 
 
 
 
 
 
17
  ],
18
  },
19
  )
 
7
  package_dir={"": "src"},
8
  install_requires=[
9
  "torch>=2.0.0",
10
+ "transformers>=4.40.0",
11
+ "scikit-learn>=1.4.0",
12
+ "numpy>=1.24.0",
13
+ "pandas>=2.0.0",
14
  ],
15
+ extras_require={
16
+ "web": [
17
+ "streamlit>=1.25.0",
18
+ "plotly>=5.18.0",
19
+ ],
20
+ "api": [
21
+ "fastapi>=0.110.0",
22
+ ],
23
+ "all": [
24
+ "streamlit>=1.25.0",
25
+ "plotly>=5.18.0",
26
+ "fastapi>=0.110.0",
27
  ],
28
  },
29
  )
src/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """LexiMind core package."""
src/api/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """API surface for LexiMind."""
src/api/app.py ADDED
@@ -0,0 +1,10 @@
1
+ """FastAPI application entrypoint."""
2
+ from fastapi import FastAPI
3
+
4
+ from .routes import router
5
+
6
+
7
+ def create_app() -> FastAPI:
8
+ app = FastAPI(title="LexiMind")
9
+ app.include_router(router)
10
+ return app
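To serve the app locally, a minimal sketch (uvicorn is assumed to be installed separately; it is not pinned in setup.py):

```python
# Sketch: run the FastAPI app with uvicorn on a local port.
import uvicorn

from src.api.app import create_app

uvicorn.run(create_app(), host="127.0.0.1", port=8000)
```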
src/api/dependencies.py ADDED
@@ -0,0 +1,42 @@
1
+ """Dependency providers for the FastAPI application."""
2
+ from __future__ import annotations
3
+
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+
7
+ from fastapi import HTTPException, status
8
+
9
+ from ..inference.factory import create_inference_pipeline
10
+ from ..inference.pipeline import InferencePipeline
11
+ from ..utils.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ @lru_cache(maxsize=1)
17
+ def get_pipeline() -> InferencePipeline:
18
+ """Lazily construct and cache the inference pipeline for the API."""
19
+
20
+ checkpoint = Path("checkpoints/best.pt")
21
+ labels = Path("artifacts/labels.json")
22
+ model_config = Path("configs/model/base.yaml")
23
+
24
+ try:
25
+ pipeline, _ = create_inference_pipeline(
26
+ checkpoint_path=checkpoint,
27
+ labels_path=labels,
28
+ model_config_path=model_config,
29
+ )
30
+ except FileNotFoundError as exc:
31
+ logger.exception("Pipeline initialization failed: missing artifact")
32
+ raise HTTPException(
33
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
34
+ detail="Service temporarily unavailable",
35
+ ) from exc
36
+ except Exception as exc: # noqa: BLE001 - surface initialization issues to the caller
37
+ logger.exception("Pipeline initialization failed")
38
+ raise HTTPException(
39
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
40
+ detail="Service temporarily unavailable",
41
+ ) from exc
42
+ return pipeline
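Because `get_pipeline` is an `lru_cache`d dependency, tests can swap it out with FastAPI's `dependency_overrides`; the stub payload below is illustrative only and just mirrors the attributes `routes.py` reads:

```python
# Sketch: override the cached pipeline with a stub for API tests.
from types import SimpleNamespace

from fastapi.testclient import TestClient

from src.api.app import create_app
from src.api.dependencies import get_pipeline


class StubPipeline:
    def batch_predict(self, texts):
        return {
            "summaries": ["stub summary"] * len(texts),
            "emotion": [SimpleNamespace(labels=["joy"], scores=[0.9]) for _ in texts],
            "topic": [SimpleNamespace(label="news", confidence=0.8) for _ in texts],
        }


app = create_app()
app.dependency_overrides[get_pipeline] = lambda: StubPipeline()
client = TestClient(app)
print(client.post("/summarize", json={"text": "hello"}).json())
```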
src/api/inference/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- """
2
- API inference module for LexiMind.
3
- """
4
-
5
- from .inference import load_models, summarize_text, classify_emotion, topic_for_text
6
-
7
- __all__ = ["load_models", "summarize_text", "classify_emotion", "topic_for_text"]
src/api/inference/inference.py DELETED
@@ -1,133 +0,0 @@
1
- """Minimal inference helpers that rely on the custom transformer stack."""
2
-
3
- from __future__ import annotations
4
-
5
- from pathlib import Path
6
- from typing import Any, Dict, List, Optional, Tuple
7
-
8
- import torch
9
-
10
- from ...data.preprocessing import TextPreprocessor, TransformerTokenizer
11
- from ...models.multitask import MultiTaskModel
12
-
13
-
14
- def _load_tokenizer(tokenizer_path: Path) -> TransformerTokenizer:
15
- if not tokenizer_path.exists():
16
- raise FileNotFoundError(f"tokenizer file '{tokenizer_path}' not found")
17
- return TransformerTokenizer.load(tokenizer_path)
18
-
19
-
20
- def load_models(config: Dict[str, Any]) -> Dict[str, Any]:
21
- """Load MultiTaskModel together with the tokenizer-driven preprocessor."""
22
-
23
- device = torch.device(config.get("device", "cpu"))
24
- tokenizer_path = config.get("tokenizer_path")
25
- if tokenizer_path is None:
26
- raise ValueError("'tokenizer_path' missing in config")
27
-
28
- tokenizer = _load_tokenizer(Path(tokenizer_path))
29
- preprocessor = TextPreprocessor(
30
- max_length=int(config.get("max_length", 512)),
31
- tokenizer=tokenizer,
32
- min_freq=int(config.get("min_freq", 1)),
33
- lowercase=bool(config.get("lowercase", True)),
34
- )
35
-
36
- encoder_kwargs = dict(config.get("encoder", {}))
37
- decoder_kwargs = dict(config.get("decoder", {}))
38
-
39
- encoder = preprocessor.build_encoder(**encoder_kwargs)
40
- decoder = preprocessor.build_decoder(**decoder_kwargs)
41
- model = MultiTaskModel(encoder=encoder, decoder=decoder)
42
-
43
- checkpoint_path = config.get("checkpoint_path")
44
- if checkpoint_path:
45
- state = torch.load(checkpoint_path, map_location=device)
46
- if isinstance(state, dict) and "state_dict" in state:
47
- state = state["state_dict"]
48
- model.load_state_dict(state, strict=False)
49
-
50
- model.to(device)
51
-
52
- return {
53
- "loaded": True,
54
- "device": device,
55
- "mt": model,
56
- "preprocessor": preprocessor,
57
- }
58
-
59
-
60
- def summarize_text(
61
- text: str,
62
- compression: float = 0.25,
63
- collect_attn: bool = False,
64
- models: Optional[Dict[str, Any]] = None,
65
- ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
66
- if models is None or not models.get("loaded"):
67
- raise RuntimeError("Models must be loaded via load_models before summarize_text is called")
68
-
69
- model: MultiTaskModel = models["mt"]
70
- preprocessor: TextPreprocessor = models["preprocessor"]
71
- device: torch.device = models["device"]
72
-
73
- batch = preprocessor.batch_encode([text])
74
- tokenizer = preprocessor.tokenizer
75
- encoder = model.encoder
76
- decoder = model.decoder
77
- if tokenizer is None or encoder is None or decoder is None:
78
- raise RuntimeError("Encoder, decoder, and tokenizer must be configured before summarization")
79
- input_ids = batch.input_ids.to(device)
80
- memory = encoder(input_ids)
81
- src_len = batch.lengths[0]
82
- max_tgt = max(4, int(src_len * compression))
83
- generated = decoder.greedy_decode(
84
- memory,
85
- max_len=min(preprocessor.max_length, max_tgt),
86
- start_token_id=tokenizer.bos_id,
87
- end_token_id=tokenizer.eos_id,
88
- )
89
- summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
90
- return summary.strip(), None if not collect_attn else {}
91
-
92
-
93
- def classify_emotion(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[List[float], List[str]]:
94
- if models is None or not models.get("loaded"):
95
- raise RuntimeError("Models must be loaded via load_models before classify_emotion is called")
96
-
97
- model: MultiTaskModel = models["mt"]
98
- preprocessor: TextPreprocessor = models["preprocessor"]
99
- device: torch.device = models["device"]
100
-
101
- batch = preprocessor.batch_encode([text])
102
- input_ids = batch.input_ids.to(device)
103
- result = model.forward("emotion", {"input_ids": input_ids})
104
- logits = result[1] if isinstance(result, tuple) else result
105
- scores = torch.sigmoid(logits).squeeze(0).detach().cpu().tolist()
106
- labels = models.get("emotion_labels") or [
107
- "joy",
108
- "sadness",
109
- "anger",
110
- "fear",
111
- "surprise",
112
- "disgust",
113
- ]
114
- return scores, labels[: len(scores)]
115
-
116
-
117
- def topic_for_text(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[int, List[str]]:
118
- if models is None or not models.get("loaded"):
119
- raise RuntimeError("Models must be loaded via load_models before topic_for_text is called")
120
-
121
- model: MultiTaskModel = models["mt"]
122
- preprocessor: TextPreprocessor = models["preprocessor"]
123
- device: torch.device = models["device"]
124
-
125
- batch = preprocessor.batch_encode([text])
126
- input_ids = batch.input_ids.to(device)
127
- encoder = model.encoder
128
- if encoder is None:
129
- raise RuntimeError("Encoder must be configured before topic_for_text is called")
130
- memory = encoder(input_ids)
131
- embedding = memory.mean(dim=1).detach().cpu()
132
- _ = embedding # placeholder for downstream clustering hook
133
- return 0, ["topic_stub"]
src/api/routes.py ADDED
@@ -0,0 +1,34 @@
1
+ """API routes."""
2
+ from typing import cast
3
+
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+
6
+ from ..inference import EmotionPrediction, InferencePipeline, TopicPrediction
7
+ from .dependencies import get_pipeline
8
+ from .schemas import SummaryRequest, SummaryResponse
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.post("/summarize", response_model=SummaryResponse)
14
+ def summarize(payload: SummaryRequest, pipeline: InferencePipeline = Depends(get_pipeline)) -> SummaryResponse:
15
+ try:
16
+ outputs = pipeline.batch_predict([payload.text])
17
+ except Exception as exc: # noqa: BLE001 - surface inference error to client
18
+ raise HTTPException(
19
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
20
+ detail=str(exc),
21
+ ) from exc
22
+ summaries = cast(list[str], outputs["summaries"])
23
+ emotion_preds = cast(list[EmotionPrediction], outputs["emotion"])
24
+ topic_preds = cast(list[TopicPrediction], outputs["topic"])
25
+
26
+ emotion = emotion_preds[0]
27
+ topic = topic_preds[0]
28
+ return SummaryResponse(
29
+ summary=summaries[0],
30
+ emotion_labels=emotion.labels,
31
+ emotion_scores=emotion.scores,
32
+ topic=topic.label,
33
+ topic_confidence=topic.confidence,
34
+ )
src/api/schemas.py ADDED
@@ -0,0 +1,14 @@
1
+ """API schemas."""
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class SummaryRequest(BaseModel):
6
+ text: str
7
+
8
+
9
+ class SummaryResponse(BaseModel):
10
+ summary: str
11
+ emotion_labels: list[str]
12
+ emotion_scores: list[float]
13
+ topic: str
14
+ topic_confidence: float
src/data/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """Data utilities for LexiMind."""
src/data/dataloader.py ADDED
@@ -0,0 +1,117 @@
1
+ """Task-aware DataLoader builders for the LexiMind multitask suite."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Iterable, List
5
+
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+
9
+ from .dataset import EmotionDataset, EmotionExample, SummarizationDataset, SummarizationExample, TopicDataset, TopicExample
10
+ from .tokenization import Tokenizer
11
+
12
+
13
+ class SummarizationCollator:
14
+ """Prepare encoder-decoder batches for abstractive summarization."""
15
+
16
+ def __init__(self, tokenizer: Tokenizer, *, max_source_length: int | None = None, max_target_length: int | None = None) -> None:
17
+ self.tokenizer = tokenizer
18
+ self.max_source_length = max_source_length
19
+ self.max_target_length = max_target_length
20
+
21
+ def __call__(self, batch: List[SummarizationExample]) -> dict[str, torch.Tensor]:
22
+ sources = [example.source for example in batch]
23
+ targets = [example.summary for example in batch]
24
+
25
+ source_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
26
+ target_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
27
+
28
+ labels = target_enc["input_ids"].clone()
29
+ decoder_input_ids = self.tokenizer.prepare_decoder_inputs(target_enc["input_ids"])
30
+ labels[target_enc["attention_mask"] == 0] = -100
31
+
32
+ return {
33
+ "src_ids": source_enc["input_ids"],
34
+ "src_mask": source_enc["attention_mask"],
35
+ "tgt_ids": decoder_input_ids,
36
+ "labels": labels,
37
+ }
38
+
39
+
40
+ class EmotionCollator:
41
+ """Prepare batches for multi-label emotion classification."""
42
+
43
+ def __init__(self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None) -> None:
44
+ self.tokenizer = tokenizer
45
+ self.binarizer = dataset.binarizer
46
+ self.max_length = max_length
47
+
48
+ def __call__(self, batch: List[EmotionExample]) -> dict[str, torch.Tensor]:
49
+ texts = [example.text for example in batch]
50
+ encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
51
+ label_array = self.binarizer.transform([example.emotions for example in batch])
52
+ labels = torch.as_tensor(label_array, dtype=torch.float32)
53
+ return {
54
+ "input_ids": encoded["input_ids"],
55
+ "attention_mask": encoded["attention_mask"],
56
+ "labels": labels,
57
+ }
58
+
59
+
60
+ class TopicCollator:
61
+ """Prepare batches for topic classification using the projection head."""
62
+
63
+ def __init__(self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None) -> None:
64
+ self.tokenizer = tokenizer
65
+ self.encoder = dataset.encoder
66
+ self.max_length = max_length
67
+
68
+ def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]:
69
+ texts = [example.text for example in batch]
70
+ encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
71
+ labels = torch.as_tensor(self.encoder.transform([example.topic for example in batch]), dtype=torch.long)
72
+ return {
73
+ "input_ids": encoded["input_ids"],
74
+ "attention_mask": encoded["attention_mask"],
75
+ "labels": labels,
76
+ }
77
+
78
+
79
+ def build_summarization_dataloader(
80
+ dataset: SummarizationDataset,
81
+ tokenizer: Tokenizer,
82
+ *,
83
+ batch_size: int,
84
+ shuffle: bool = True,
85
+ max_source_length: int | None = None,
86
+ max_target_length: int | None = None,
87
+ ) -> DataLoader:
88
+ collator = SummarizationCollator(
89
+ tokenizer,
90
+ max_source_length=max_source_length,
91
+ max_target_length=max_target_length,
92
+ )
93
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)
94
+
95
+
96
+ def build_emotion_dataloader(
97
+ dataset: EmotionDataset,
98
+ tokenizer: Tokenizer,
99
+ *,
100
+ batch_size: int,
101
+ shuffle: bool = True,
102
+ max_length: int | None = None,
103
+ ) -> DataLoader:
104
+ collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
105
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)
106
+
107
+
108
+ def build_topic_dataloader(
109
+ dataset: TopicDataset,
110
+ tokenizer: Tokenizer,
111
+ *,
112
+ batch_size: int,
113
+ shuffle: bool = True,
114
+ max_length: int | None = None,
115
+ ) -> DataLoader:
116
+ collator = TopicCollator(tokenizer, dataset, max_length=max_length)
117
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)
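A short sketch wiring the summarization pieces together, assuming processed JSONL already exists at the default path:

```python
# Sketch: build a summarization DataLoader and inspect one batch.
from src.data.dataloader import build_summarization_dataloader
from src.data.dataset import SummarizationDataset, load_summarization_jsonl
from src.data.tokenization import Tokenizer, TokenizerConfig

tokenizer = Tokenizer(TokenizerConfig(max_length=256))
dataset = SummarizationDataset(load_summarization_jsonl("data/processed/summarization/val.jsonl"))
loader = build_summarization_dataloader(dataset, tokenizer, batch_size=4, shuffle=False)

batch = next(iter(loader))
print(batch["src_ids"].shape, batch["tgt_ids"].shape, batch["labels"].shape)
```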
src/data/dataset.py ADDED
@@ -0,0 +1,229 @@
1
+ """Dataset definitions for the LexiMind multitask training pipeline."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Callable, Iterable, List, Sequence, TypeVar
8
+
9
+ from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
10
+ from torch.utils.data import Dataset
11
+
12
+
13
+ @dataclass(slots=True)
14
+ class SummarizationExample:
15
+ """Container for abstractive summarization samples."""
16
+
17
+ source: str
18
+ summary: str
19
+
20
+
21
+ @dataclass(slots=True)
22
+ class EmotionExample:
23
+ """Container for multi-label emotion classification samples."""
24
+
25
+ text: str
26
+ emotions: Sequence[str]
27
+
28
+
29
+ @dataclass(slots=True)
30
+ class TopicExample:
31
+ """Container for topic clustering / classification samples."""
32
+
33
+ text: str
34
+ topic: str
35
+
36
+
37
+ class SummarizationDataset(Dataset[SummarizationExample]):
38
+ """Dataset yielding encoder-decoder training pairs."""
39
+
40
+ def __init__(self, examples: Iterable[SummarizationExample]) -> None:
41
+ self._examples = list(examples)
42
+
43
+ def __len__(self) -> int:
44
+ return len(self._examples)
45
+
46
+ def __getitem__(self, index: int) -> SummarizationExample:
47
+ return self._examples[index]
48
+
49
+
50
+ class EmotionDataset(Dataset[EmotionExample]):
51
+ """Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
52
+
53
+ def __init__(
54
+ self,
55
+ examples: Iterable[EmotionExample],
56
+ *,
57
+ binarizer: MultiLabelBinarizer | None = None,
58
+ ) -> None:
59
+ self._examples = list(examples)
60
+ all_labels = [example.emotions for example in self._examples]
61
+ if binarizer is None:
62
+ self._binarizer = MultiLabelBinarizer()
63
+ self._binarizer.fit(all_labels)
64
+ else:
65
+ self._binarizer = binarizer
66
+ if not hasattr(self._binarizer, "classes_"):
67
+ raise ValueError(
68
+ "Provided MultiLabelBinarizer must be pre-fitted with 'classes_' attribute."
69
+ )
70
+
71
+ def __len__(self) -> int:
72
+ return len(self._examples)
73
+
74
+ def __getitem__(self, index: int) -> EmotionExample:
75
+ return self._examples[index]
76
+
77
+ @property
78
+ def binarizer(self) -> MultiLabelBinarizer:
79
+ return self._binarizer
80
+
81
+ @property
82
+ def emotion_classes(self) -> List[str]:
83
+ return list(self._binarizer.classes_)
84
+
85
+
86
+ class TopicDataset(Dataset[TopicExample]):
87
+ """Dataset that owns a LabelEncoder for topic ids."""
88
+
89
+ def __init__(
90
+ self,
91
+ examples: Iterable[TopicExample],
92
+ *,
93
+ encoder: LabelEncoder | None = None,
94
+ ) -> None:
95
+ self._examples = list(examples)
96
+ topics = [example.topic for example in self._examples]
97
+ if encoder is None:
98
+ self._encoder = LabelEncoder().fit(topics)
99
+ else:
100
+ self._encoder = encoder
101
+ if not hasattr(self._encoder, "classes_"):
102
+ raise ValueError(
103
+ "Provided LabelEncoder must be pre-fitted with 'classes_' attribute."
104
+ )
105
+
106
+ def __len__(self) -> int:
107
+ return len(self._examples)
108
+
109
+ def __getitem__(self, index: int) -> TopicExample:
110
+ return self._examples[index]
111
+
112
+ @property
113
+ def encoder(self) -> LabelEncoder:
114
+ return self._encoder
115
+
116
+ @property
117
+ def topic_classes(self) -> List[str]:
118
+ return list(self._encoder.classes_)
119
+
120
+
121
+ T = TypeVar("T")
122
+
123
+
124
+ def _safe_json_load(handle, path: Path) -> object:
125
+ try:
126
+ return json.load(handle)
127
+ except json.JSONDecodeError as exc:
128
+ raise ValueError(f"Failed to parse JSON in '{path}': {exc}") from exc
129
+
130
+
131
+ def _safe_json_loads(data: str, path: Path, line_number: int) -> object:
132
+ try:
133
+ return json.loads(data)
134
+ except json.JSONDecodeError as exc:
135
+ raise ValueError(f"Failed to parse JSON in '{path}' at line {line_number}: {exc}") from exc
136
+
137
+
138
+ def _validate_keys(
139
+ payload: dict,
140
+ required_keys: Sequence[str],
141
+ position: int,
142
+ *,
143
+ path: Path,
144
+ is_array: bool = False,
145
+ ) -> None:
146
+ missing = [key for key in required_keys if key not in payload]
147
+ if missing:
148
+ keys = ", ".join(sorted(missing))
149
+ location = "index" if is_array else "line"
150
+ raise KeyError(f"Missing required keys ({keys}) at {location} {position} of '{path}'")
151
+
152
+
153
+ def _load_jsonl_generic(
154
+ path: str,
155
+ constructor: Callable[[dict], T],
156
+ required_keys: Sequence[str],
157
+ ) -> List[T]:
158
+ data_path = Path(path)
159
+ if not data_path.exists():
160
+ raise FileNotFoundError(f"Dataset file '{data_path}' does not exist")
161
+ if not data_path.is_file():
162
+ raise ValueError(f"Dataset path '{data_path}' is not a file")
163
+
164
+ items: List[T] = []
165
+ with data_path.open("r", encoding="utf-8") as handle:
166
+ first_non_ws = ""
167
+ while True:
168
+ pos = handle.tell()
169
+ char = handle.read(1)
170
+ if not char:
171
+ break
172
+ if not char.isspace():
173
+ first_non_ws = char
174
+ handle.seek(pos)
175
+ break
176
+ if not first_non_ws:
177
+ raise ValueError(f"Dataset file '{data_path}' is empty or contains only whitespace")
178
+
179
+ if first_non_ws == "[":
180
+ payloads = _safe_json_load(handle, data_path)
181
+ if not isinstance(payloads, list):
182
+ raise ValueError(f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}")
183
+ for idx, payload in enumerate(payloads):
184
+ if not isinstance(payload, dict):
185
+ raise ValueError(
186
+ f"Expected objects in array for '{data_path}', found {type(payload).__name__} at index {idx}"
187
+ )
188
+ _validate_keys(payload, required_keys, idx, path=data_path, is_array=True)
189
+ items.append(constructor(payload))
190
+ else:
191
+ handle.seek(0)
192
+ line_number = 0
193
+ for line in handle:
194
+ line_number += 1
195
+ if not line.strip():
196
+ continue
197
+ payload = _safe_json_loads(line, data_path, line_number)
198
+ if not isinstance(payload, dict):
199
+ raise ValueError(
200
+ f"Expected JSON object per line in '{data_path}', found {type(payload).__name__} at line {line_number}"
201
+ )
202
+ _validate_keys(payload, required_keys, line_number, path=data_path)
203
+ items.append(constructor(payload))
204
+
205
+ return items
206
+
207
+
208
+ def load_summarization_jsonl(path: str) -> List[SummarizationExample]:
209
+ return _load_jsonl_generic(
210
+ path,
211
+ lambda payload: SummarizationExample(source=payload["source"], summary=payload["summary"]),
212
+ required_keys=("source", "summary"),
213
+ )
214
+
215
+
216
+ def load_emotion_jsonl(path: str) -> List[EmotionExample]:
217
+ return _load_jsonl_generic(
218
+ path,
219
+ lambda payload: EmotionExample(text=payload["text"], emotions=payload.get("emotions", [])),
220
+ required_keys=("text",),
221
+ )
222
+
223
+
224
+ def load_topic_jsonl(path: str) -> List[TopicExample]:
225
+ return _load_jsonl_generic(
226
+ path,
227
+ lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]),
228
+ required_keys=("text", "topic"),
229
+ )
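The key contract here is that validation splits reuse the label transformers fitted on the training split; a minimal sketch:

```python
# Sketch: fit emotion labels on train and share the binarizer with the val split
# so label indices stay consistent across splits.
from src.data.dataset import EmotionDataset, load_emotion_jsonl

train = EmotionDataset(load_emotion_jsonl("data/processed/emotion/train.jsonl"))
val = EmotionDataset(load_emotion_jsonl("data/processed/emotion/val.jsonl"), binarizer=train.binarizer)
print(train.emotion_classes)
```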
src/data/download.py CHANGED
@@ -1,66 +1,45 @@
1
- """
2
- Download helpers for datasets.
3
 
4
- This version:
5
- - Adds robust error handling when Kaggle API is not configured.
6
- - Stores files under data/raw/ subfolders.
7
- - Keeps the Gutenberg direct download example.
 
8
 
9
- Make sure you have Kaggle credentials configured if you call Kaggle downloads.
10
- """
11
- import os
12
- import requests
13
 
14
- def download_gutenberg(out_dir="data/raw/books", gutenberg_id: int = 1342, filename: str = "pride_and_prejudice.txt"):
15
- """Download a Gutenberg text file by direct URL template (best-effort)."""
16
- url = f"https://www.gutenberg.org/files/{gutenberg_id}/{gutenberg_id}-0.txt"
17
- os.makedirs(out_dir, exist_ok=True)
18
- out_path = os.path.join(out_dir, filename)
19
- if os.path.exists(out_path):
20
- print("Already downloaded:", out_path)
21
- return out_path
22
- try:
23
- r = requests.get(url, timeout=30)
24
- r.raise_for_status()
25
- with open(out_path, "wb") as f:
26
- f.write(r.content)
27
- print("Downloaded:", out_path)
28
- return out_path
29
- except Exception as e:
30
- print("Failed to download Gutenberg file:", e)
31
- return None
32
 
33
- # Kaggle helpers: optional, wrapped to avoid hard failure when Kaggle isn't configured.
34
- def _safe_kaggle_download(dataset: str, path: str):
 
35
  try:
36
- import kaggle
37
- except Exception as e:
38
- print("Kaggle package not available or not configured. Please install 'kaggle' and configure API token. Error:", e)
39
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- os.makedirs(path, exist_ok=True)
42
- kaggle.api.authenticate()
43
- kaggle.api.dataset_download_files(dataset, path=path, unzip=True)
44
- print(f"Downloaded Kaggle dataset {dataset} to {path}")
45
- return True
46
- except Exception as e:
47
- print("Failed to download Kaggle dataset:", e)
48
- return False
49
-
50
- def download_emotion_dataset():
51
- target_dir = "data/raw/emotion"
52
- return _safe_kaggle_download('praveengovi/emotions-dataset-for-nlp', target_dir)
53
-
54
- def download_cnn_dailymail():
55
- target_dir = "data/raw/summarization"
56
- return _safe_kaggle_download('gowrishankarp/newspaper-text-summarization-cnn-dailymail', target_dir)
57
-
58
- def download_ag_news():
59
- target_dir = "data/raw/topic"
60
- return _safe_kaggle_download('amananandrai/ag-news-classification-dataset', target_dir)
61
-
62
- if __name__ == "__main__":
63
- download_gutenberg()
64
- download_emotion_dataset()
65
- download_cnn_dailymail()
66
- download_ag_news()
 
1
+ """Dataset download helpers."""
 
2
 
3
+ import socket
4
+ from pathlib import Path
5
+ from subprocess import CalledProcessError, run
6
+ from urllib.error import URLError
7
+ from urllib.request import urlopen
8
 
 
 
 
 
9
 
10
+ DOWNLOAD_TIMEOUT = 60
11
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ def kaggle_download(dataset: str, output_dir: str) -> None:
14
+ target = Path(output_dir)
15
+ target.mkdir(parents=True, exist_ok=True)
16
  try:
17
+ run([
18
+ "kaggle",
19
+ "datasets",
20
+ "download",
21
+ "-d",
22
+ dataset,
23
+ "-p",
24
+ str(target),
25
+ "--unzip",
26
+ ], check=True)
27
+ except CalledProcessError as error:
28
+ raise RuntimeError(
29
+ "Kaggle download failed. Verify that the Kaggle CLI is authenticated,"
30
+ " you have accepted the dataset terms on kaggle.com, and your kaggle.json"
31
+ " credentials are located in %USERPROFILE%/.kaggle."
32
+ ) from error
33
+
34
+
35
+ def gutenberg_download(url: str, output_path: str) -> None:
36
+ target = Path(output_path)
37
+ target.parent.mkdir(parents=True, exist_ok=True)
38
  try:
39
+ with urlopen(url, timeout=DOWNLOAD_TIMEOUT) as response, target.open("wb") as handle:
40
+ chunk = response.read(8192)
41
+ while chunk:
42
+ handle.write(chunk)
43
+ chunk = response.read(8192)
44
+ except (URLError, socket.timeout, OSError) as error:
45
+ raise RuntimeError(f"Failed to download '{url}' to '{target}': {error}") from error
 
src/data/preprocessing.py CHANGED
@@ -1,260 +1,130 @@
1
- """Lightweight preprocessing utilities built around the in-repo transformer."""
2
-
3
  from __future__ import annotations
4
 
5
- from collections import Counter
6
- from dataclasses import dataclass
7
- import json
8
- from pathlib import Path
9
  import re
10
- from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
11
 
12
  import torch
 
 
13
 
14
- from ..models.decoder import TransformerDecoder
15
- from ..models.encoder import TransformerEncoder
16
-
17
- SPECIAL_TOKENS: Tuple[str, str, str, str] = ("<pad>", "<bos>", "<eos>", "<unk>")
18
-
19
-
20
- def _normalize(text: str, lowercase: bool) -> str:
21
- text = text.strip()
22
- text = re.sub(r"\s+", " ", text)
23
- if lowercase:
24
- text = text.lower()
25
- return text
26
-
27
 
28
- def _basic_tokenize(text: str) -> List[str]:
29
- return re.findall(r"\b\w+\b|[.,;:?!]", text)
30
 
 
 
31
 
32
- class TransformerTokenizer:
33
- """Minimal tokenizer that keeps vocabulary aligned with the custom transformer."""
34
-
35
- def __init__(
36
- self,
37
- stoi: Dict[str, int],
38
- itos: List[str],
39
- specials: Sequence[str] = SPECIAL_TOKENS,
40
- lowercase: bool = True,
41
- ) -> None:
42
- self.stoi = stoi
43
- self.itos = itos
44
- self.specials = tuple(specials)
45
  self.lowercase = lowercase
46
- self.pad_id = self._lookup(self.specials[0])
47
- self.bos_id = self._lookup(self.specials[1])
48
- self.eos_id = self._lookup(self.specials[2])
49
- self.unk_id = self._lookup(self.specials[3])
50
-
51
- @classmethod
52
- def build(
53
- cls,
54
- texts: Iterable[str],
55
- min_freq: int = 1,
56
- lowercase: bool = True,
57
- specials: Sequence[str] = SPECIAL_TOKENS,
58
- ) -> "TransformerTokenizer":
59
- counter: Counter[str] = Counter()
60
- for text in texts:
61
- normalized = _normalize(text, lowercase)
62
- counter.update(_basic_tokenize(normalized))
63
-
64
- ordered_specials = list(dict.fromkeys(specials))
65
- itos: List[str] = ordered_specials.copy()
66
- for token, freq in counter.most_common():
67
- if freq < min_freq:
68
- continue
69
- if token in itos:
70
- continue
71
- itos.append(token)
72
 
73
- stoi = {token: idx for idx, token in enumerate(itos)}
74
- return cls(stoi=stoi, itos=itos, specials=ordered_specials, lowercase=lowercase)
75
 
76
- @property
77
- def vocab_size(self) -> int:
78
- return len(self.itos)
79
 
80
- def tokenize(self, text: str) -> List[str]:
81
- normalized = _normalize(text, self.lowercase)
82
- return _basic_tokenize(normalized)
 
 
83
 
84
- def encode(
85
- self,
86
- text: str,
87
- add_special_tokens: bool = True,
88
- max_length: Optional[int] = None,
89
- ) -> List[int]:
90
- tokens = self.tokenize(text)
91
- pieces = [self.stoi.get(tok, self.unk_id) for tok in tokens]
92
- if add_special_tokens:
93
- pieces = [self.bos_id] + pieces + [self.eos_id]
94
-
95
- if max_length is not None and len(pieces) > max_length:
96
- if add_special_tokens and max_length >= 2:
97
- inner_max = max_length - 2
98
- trimmed = pieces[1:-1][:inner_max]
99
- pieces = [self.bos_id] + trimmed + [self.eos_id]
100
- else:
101
- pieces = pieces[:max_length]
102
- return pieces
103
 
104
- def decode(self, ids: Sequence[int], skip_special_tokens: bool = True) -> str:
105
- tokens: List[str] = []
106
- for idx in ids:
107
- if idx < 0 or idx >= len(self.itos):
108
- continue
109
- token = self.itos[idx]
110
- if skip_special_tokens and token in self.specials:
111
- continue
112
- tokens.append(token)
113
- return " ".join(tokens).strip()
114
-
115
- def pad_batch(
116
- self,
117
- sequences: Sequence[Sequence[int]],
118
- pad_to_length: Optional[int] = None,
119
- ) -> Tuple[torch.Tensor, torch.Tensor]:
120
- if not sequences:
121
- raise ValueError("pad_batch requires at least one sequence")
122
- if pad_to_length is None:
123
- pad_to_length = max(len(seq) for seq in sequences)
124
- padded: List[List[int]] = []
125
- mask: List[List[int]] = []
126
- for seq in sequences:
127
- trimmed = list(seq[:pad_to_length])
128
- pad_len = pad_to_length - len(trimmed)
129
- padded.append(trimmed + [self.pad_id] * pad_len)
130
- mask.append([1] * len(trimmed) + [0] * pad_len)
131
- return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)
132
-
133
- def save(self, path: Path) -> None:
134
- payload = {
135
- "itos": self.itos,
136
- "specials": list(self.specials),
137
- "lowercase": self.lowercase,
138
- }
139
- path.parent.mkdir(parents=True, exist_ok=True)
140
- path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
141
-
142
- @classmethod
143
- def load(cls, path: Path) -> "TransformerTokenizer":
144
- data = json.loads(path.read_text(encoding="utf-8"))
145
- itos = list(data["itos"])
146
- stoi = {token: idx for idx, token in enumerate(itos)}
147
- specials = data.get("specials", list(SPECIAL_TOKENS))
148
- lowercase = bool(data.get("lowercase", True))
149
- return cls(stoi=stoi, itos=itos, specials=specials, lowercase=lowercase)
150
-
151
- def _lookup(self, token: str) -> int:
152
- if token not in self.stoi:
153
- raise ValueError(f"token '{token}' missing from vocabulary")
154
- return self.stoi[token]
155
-
156
-
157
- @dataclass
158
  class Batch:
 
 
159
  input_ids: torch.Tensor
160
  attention_mask: torch.Tensor
161
  lengths: List[int]
162
 
163
 
164
  class TextPreprocessor:
165
- """Prepares text so it can flow directly into the custom transformer stack."""
 
 
 
 
 
166
 
167
  def __init__(
168
  self,
169
- max_length: int = 512,
170
- tokenizer: Optional[TransformerTokenizer] = None,
171
  *,
172
- min_freq: int = 1,
 
 
173
  lowercase: bool = True,
 
 
174
  ) -> None:
175
- self.max_length = max_length
176
- self.min_freq = min_freq
177
  self.lowercase = lowercase
178
- self.tokenizer = tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  def clean_text(self, text: str) -> str:
181
- return _normalize(text, self.lowercase)
182
-
183
- def fit_tokenizer(self, texts: Iterable[str]) -> TransformerTokenizer:
184
- cleaned = [self.clean_text(text) for text in texts]
185
- self.tokenizer = TransformerTokenizer.build(
186
- cleaned,
187
- min_freq=self.min_freq,
188
- lowercase=False,
189
- )
190
- return self.tokenizer
191
-
192
- def encode(self, text: str, *, add_special_tokens: bool = True) -> List[int]:
193
- if self.tokenizer is None:
194
- raise RuntimeError("Tokenizer not fitted")
195
- cleaned = self.clean_text(text)
196
- return self.tokenizer.encode(cleaned, add_special_tokens=add_special_tokens, max_length=self.max_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def batch_encode(self, texts: Sequence[str]) -> Batch:
199
- if self.tokenizer is None:
200
- raise RuntimeError("Tokenizer not fitted")
201
- sequences = [self.encode(text) for text in texts]
202
- lengths = [len(seq) for seq in sequences]
203
- input_ids, attention_mask = self.tokenizer.pad_batch(sequences, pad_to_length=self.max_length)
204
  return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
205
 
206
- def build_encoder(self, **encoder_kwargs) -> TransformerEncoder:
207
- if self.tokenizer is None:
208
- raise RuntimeError("Tokenizer not fitted")
209
- return TransformerEncoder(
210
- vocab_size=self.tokenizer.vocab_size,
211
- max_len=self.max_length,
212
- pad_token_id=self.tokenizer.pad_id,
213
- **encoder_kwargs,
214
- )
215
-
216
- def build_decoder(self, **decoder_kwargs) -> TransformerDecoder:
217
- if self.tokenizer is None:
218
- raise RuntimeError("Tokenizer not fitted")
219
- return TransformerDecoder(
220
- vocab_size=self.tokenizer.vocab_size,
221
- max_len=self.max_length,
222
- pad_token_id=self.tokenizer.pad_id,
223
- **decoder_kwargs,
224
- )
225
-
226
- def save_tokenizer(self, path: Path) -> None:
227
- if self.tokenizer is None:
228
- raise RuntimeError("Tokenizer not fitted")
229
- self.tokenizer.save(path)
230
-
231
- def load_tokenizer(self, path: Path) -> TransformerTokenizer:
232
- self.tokenizer = TransformerTokenizer.load(path)
233
- return self.tokenizer
234
-
235
- def chunk_text(self, text: str, *, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
236
- if chunk_size <= overlap:
237
- raise ValueError("chunk_size must be larger than overlap")
238
- words = self.clean_text(text).split()
239
- chunks: List[str] = []
240
- start = 0
241
- while start < len(words):
242
- end = min(start + chunk_size, len(words))
243
- chunks.append(" ".join(words[start:end]))
244
- start += chunk_size - overlap
245
- return chunks
246
-
247
- def save_book_chunks(
248
- self,
249
- input_path: Path,
250
- out_dir: Path,
251
- *,
252
- chunk_size: int = 1000,
253
- overlap: int = 100,
254
- ) -> Path:
255
- out_dir.mkdir(parents=True, exist_ok=True)
256
- raw_text = input_path.read_text(encoding="utf-8", errors="ignore")
257
- chunks = self.chunk_text(raw_text, chunk_size=chunk_size, overlap=overlap)
258
- out_file = out_dir / f"{input_path.stem}.json"
259
- out_file.write_text(json.dumps(chunks, ensure_ascii=False, indent=2), encoding="utf-8")
260
- return out_file
 
1
+ """Text preprocessing utilities built around Hugging Face tokenizers."""
 
2
  from __future__ import annotations
3
 
 
 
 
 
4
  import re
5
+ from dataclasses import dataclass, replace
6
+ from typing import Iterable, List, Sequence
7
 
8
  import torch
9
+ from sklearn.base import BaseEstimator, TransformerMixin
10
+ from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
11
 
12
+ from .tokenization import Tokenizer, TokenizerConfig
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
14
 
15
+ class BasicTextCleaner(BaseEstimator, TransformerMixin):
16
+ """Minimal text cleaner following scikit-learn conventions."""
17
 
18
+ def __init__(self, lowercase: bool = True, strip: bool = True) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
19
  self.lowercase = lowercase
20
+ self.strip = strip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def fit(self, texts: Iterable[str], y: Iterable[str] | None = None):
23
+ return self
24
 
25
+ def transform(self, texts: Iterable[str]) -> List[str]:
26
+ return [self._clean_text(text) for text in texts]
 
27
 
28
+ def _clean_text(self, text: str) -> str:
29
+ item = text.strip() if self.strip else text
30
+ if self.lowercase:
31
+ item = item.lower()
32
+ return " ".join(item.split())
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ @dataclass(slots=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  class Batch:
37
+ """Bundle of tensors returned by the text preprocessor."""
38
+
39
  input_ids: torch.Tensor
40
  attention_mask: torch.Tensor
41
  lengths: List[int]
42
 
43
 
44
  class TextPreprocessor:
45
+ """Coordinate lightweight text cleaning and tokenization.
46
+
47
+ When supplying an already-initialized tokenizer instance, its configuration is left
48
+ untouched. If a differing ``max_length`` is requested, a ``ValueError`` is raised to
49
+ avoid mutating shared tokenizer state.
50
+ """
51
 
52
  def __init__(
53
  self,
54
+ tokenizer: Tokenizer | None = None,
 
55
  *,
56
+ tokenizer_config: TokenizerConfig | None = None,
57
+ tokenizer_name: str = "facebook/bart-base",
58
+ max_length: int | None = None,
59
  lowercase: bool = True,
60
+ remove_stopwords: bool = False,
61
+ sklearn_transformer: TransformerMixin | None = None,
62
  ) -> None:
63
+ self.cleaner = BasicTextCleaner(lowercase=lowercase, strip=True)
 
64
  self.lowercase = lowercase
65
+ if remove_stopwords:
66
+ raise ValueError(
67
+ "Stop-word removal is not supported because it conflicts with subword tokenizers; "
68
+ "clean the text externally before initializing TextPreprocessor."
69
+ )
70
+ self._stop_words = None
71
+ self._sklearn_transformer = sklearn_transformer
72
+
73
+ if tokenizer is None:
74
+ cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name)
75
+ if max_length is not None:
76
+ cfg = replace(cfg, max_length=max_length)
77
+ self.tokenizer = Tokenizer(cfg)
78
+ else:
79
+ self.tokenizer = tokenizer
80
+ if max_length is not None and max_length != tokenizer.config.max_length:
81
+ raise ValueError(
82
+ "Provided tokenizer config.max_length does not match requested max_length; "
83
+ "initialise the tokenizer with desired settings before passing it in."
84
+ )
85
+
86
+ self.max_length = max_length or self.tokenizer.config.max_length
87
 
88
  def clean_text(self, text: str) -> str:
89
+ item = self.cleaner.transform([text])[0]
90
+ return self._normalize_tokens(item)
91
+
92
+ def _normalize_tokens(self, text: str) -> str:
93
+ """Apply token-level normalization and optional stop-word filtering."""
94
+ # Note: Pre-tokenization word-splitting is incompatible with subword tokenizers.
95
+ # Stop-word filtering should be done post-tokenization or not at all for transformers.
96
+ return text
97
+
98
+ def _apply_sklearn_transform(self, texts: List[str]) -> List[str]:
99
+ if self._sklearn_transformer is None:
100
+ return texts
101
+
102
+ transform = getattr(self._sklearn_transformer, "transform", None)
103
+ if transform is None:
104
+ raise AttributeError("Provided sklearn transformer must implement a 'transform' method")
105
+ transformed = transform(texts)
106
+ if isinstance(transformed, list):
107
+ return transformed # assume downstream type is already list[str]
108
+ if hasattr(transformed, "tolist"):
109
+ transformed = transformed.tolist()
110
+
111
+ result = list(transformed)
112
+ if not all(isinstance(item, str) for item in result):
113
+ result = [str(item) for item in result]
114
+ return result
115
+
116
+ def _prepare_texts(self, texts: Sequence[str]) -> List[str]:
117
+ cleaned = self.cleaner.transform(texts)
118
+ normalized = [self._normalize_tokens(text) for text in cleaned]
119
+ return self._apply_sklearn_transform(normalized)
120
 
121
  def batch_encode(self, texts: Sequence[str]) -> Batch:
122
+ cleaned = self._prepare_texts(texts)
123
+ encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length)
124
+ input_ids: torch.Tensor = encoded["input_ids"]
125
+ attention_mask: torch.Tensor = encoded["attention_mask"].to(dtype=torch.bool)
126
+ lengths = attention_mask.sum(dim=1).tolist()
127
  return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
128
 
129
+ def __call__(self, texts: Sequence[str]) -> Batch:
130
+ return self.batch_encode(texts)
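Taken together, the cleaner and tokenizer give a one-call batching path. A minimal usage sketch, assuming the package is importable as `src` and the default `facebook/bart-base` tokenizer can be loaded:

```python
# Illustrative only: encode two short documents with the default preprocessor settings.
from src.data.preprocessing import TextPreprocessor

preprocessor = TextPreprocessor(max_length=64)  # builds a Tokenizer for "facebook/bart-base"
batch = preprocessor(["The markets rallied today.", "A quiet story about grief and memory."])

print(batch.input_ids.shape)       # (2, padded_length), dtype torch.long
print(batch.attention_mask.dtype)  # torch.bool
print(batch.lengths)               # real token counts per document
```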
 
src/data/tokenization.py ADDED
@@ -0,0 +1,122 @@
1
+ """Tokenizer wrapper around HuggingFace models used across LexiMind."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Iterable, List, Sequence, cast
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
9
+
10
+
11
+ @dataclass(slots=True)
12
+ class TokenizerConfig:
13
+ pretrained_model_name: str = "facebook/bart-base"
14
+ max_length: int = 512
15
+ padding: str = "longest"
16
+ truncation: bool = True
17
+ lower: bool = False
18
+
19
+
20
+ class Tokenizer:
21
+ """Lightweight façade over a HuggingFace tokenizer."""
22
+
23
+ def __init__(self, config: TokenizerConfig | None = None) -> None:
24
+ cfg = config or TokenizerConfig()
25
+ self.config = cfg
26
+ self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(cfg.pretrained_model_name)
27
+ self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)
28
+ self._bos_token_id = self._resolve_id(
29
+ self._tokenizer.bos_token_id if self._tokenizer.bos_token_id is not None else self._tokenizer.cls_token_id
30
+ )
31
+ self._eos_token_id = self._resolve_id(
32
+ self._tokenizer.eos_token_id if self._tokenizer.eos_token_id is not None else self._tokenizer.sep_token_id
33
+ )
34
+
35
+ @property
36
+ def tokenizer(self) -> PreTrainedTokenizerBase:
37
+ return self._tokenizer
38
+
39
+ @property
40
+ def pad_token_id(self) -> int:
41
+ return self._pad_token_id
42
+
43
+ @property
44
+ def bos_token_id(self) -> int:
45
+ return self._bos_token_id
46
+
47
+ @property
48
+ def eos_token_id(self) -> int:
49
+ return self._eos_token_id
50
+
51
+ @property
52
+ def vocab_size(self) -> int:
53
+ vocab = getattr(self._tokenizer, "vocab_size", None)
54
+ if vocab is None:
55
+ raise RuntimeError("Tokenizer must expose vocab_size")
56
+ return int(vocab)
57
+
58
+ @staticmethod
59
+ def _resolve_id(value) -> int:
60
+ if value is None:
61
+ raise ValueError("Tokenizer is missing required special token ids")
62
+ if isinstance(value, (list, tuple)):
63
+ value = value[0]
64
+ return int(value)
65
+
66
+ def encode(self, text: str) -> List[int]:
67
+ content = text.lower() if self.config.lower else text
68
+ return self._tokenizer.encode(
69
+ content,
70
+ max_length=self.config.max_length,
71
+ truncation=self.config.truncation,
72
+ padding=self.config.padding,
73
+ )
74
+
75
+ def encode_batch(self, texts: Sequence[str]) -> List[List[int]]:
76
+ normalized = (text.lower() if self.config.lower else text for text in texts)
77
+ encoded = self._tokenizer.batch_encode_plus(
78
+ list(normalized),
79
+ max_length=self.config.max_length,
80
+ padding=self.config.padding,
81
+ truncation=self.config.truncation,
82
+ return_attention_mask=False,
83
+ return_tensors=None,
84
+ )
85
+ return cast(List[List[int]], encoded["input_ids"])
86
+
87
+ def batch_encode(self, texts: Sequence[str], *, max_length: int | None = None) -> dict[str, torch.Tensor]:
88
+ normalized = [text.lower() if self.config.lower else text for text in texts]
89
+ encoded = self._tokenizer(
90
+ normalized,
91
+ padding=self.config.padding,
92
+ truncation=self.config.truncation,
93
+ max_length=max_length or self.config.max_length,
94
+ return_tensors="pt",
95
+ )
96
+ input_ids = cast(torch.Tensor, encoded["input_ids"])
97
+ attention_mask = cast(torch.Tensor, encoded["attention_mask"])
98
+ if input_ids.dtype != torch.long:
99
+ input_ids = input_ids.to(dtype=torch.long)
100
+ if attention_mask.dtype != torch.bool:
101
+ attention_mask = attention_mask.to(dtype=torch.bool)
102
+ return {
103
+ "input_ids": input_ids,
104
+ "attention_mask": attention_mask,
105
+ }
106
+
107
+ def decode(self, token_ids: Iterable[int]) -> str:
108
+ return self._tokenizer.decode(list(token_ids), skip_special_tokens=True)
109
+
110
+ def decode_batch(self, sequences: Sequence[Sequence[int]]) -> List[str]:
111
+ prepared = [list(seq) for seq in sequences]
112
+ return self._tokenizer.batch_decode(prepared, skip_special_tokens=True)
113
+
114
+ def prepare_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
115
+ """Shift decoder labels to create input ids prefixed by BOS."""
116
+
117
+ bos = self.bos_token_id
118
+ pad = self.pad_token_id
119
+ decoder_inputs = torch.full_like(labels, pad)
120
+ decoder_inputs[:, 0] = bos
121
+ decoder_inputs[:, 1:] = labels[:, :-1]
122
+ return decoder_inputs
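The shift performed by `prepare_decoder_inputs` is easiest to see on toy data. A self-contained sketch; the ids below are fabricated, and in practice BOS/PAD come from the loaded tokenizer:

```python
import torch

# Replicates the body of Tokenizer.prepare_decoder_inputs with made-up ids (2 = BOS, 1 = PAD).
labels = torch.tensor([[ 9, 10, 11,  1],
                       [12, 13,  1,  1]])
bos, pad = 2, 1

decoder_inputs = torch.full_like(labels, pad)
decoder_inputs[:, 0] = bos
decoder_inputs[:, 1:] = labels[:, :-1]

print(decoder_inputs)
# tensor([[ 2,  9, 10, 11],
#         [ 2, 12, 13,  1]])
```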
src/inference/__init__.py CHANGED
@@ -1,7 +1,12 @@
1
- """
2
- Inference utilities for LexiMind.
3
- """
4
 
5
- from .baseline_summarizer import Summarizer, TransformerSummarizer
 
6
 
7
- __all__ = ["Summarizer", "TransformerSummarizer"]
1
+ """Inference tools for LexiMind."""
 
 
2
 
3
+ from .factory import create_inference_pipeline
4
+ from .pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction
5
 
6
+ __all__ = [
7
+ "InferencePipeline",
8
+ "InferenceConfig",
9
+ "EmotionPrediction",
10
+ "TopicPrediction",
11
+ "create_inference_pipeline",
12
+ ]
src/inference/baseline_summarizer.py DELETED
@@ -1,41 +0,0 @@
1
- """Thin wrapper around the custom transformer summarizer."""
2
-
3
- from __future__ import annotations
4
- from typing import Any, Dict, Optional, Tuple
5
- import torch
6
- from ..api.inference import load_models
7
-
8
-
9
- class TransformerSummarizer:
10
- def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
11
- models = load_models(config or {})
12
- if not models.get("loaded"):
13
- raise RuntimeError("load_models returned an unloaded model; check configuration")
14
- self.model = models["mt"]
15
- self.preprocessor = models["preprocessor"]
16
- self.device = models["device"]
17
-
18
- def summarize(
19
- self,
20
- text: str,
21
- compression: float = 0.25,
22
- collect_attn: bool = False,
23
- ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
24
- batch = self.preprocessor.batch_encode([text])
25
- tokenizer = self.preprocessor.tokenizer
26
- encoder = self.model.encoder
27
- decoder = self.model.decoder
28
- if tokenizer is None or encoder is None or decoder is None:
29
- raise RuntimeError("Model components are missing; ensure encoder, decoder, and tokenizer are set")
30
- input_ids = batch.input_ids.to(self.device)
31
- memory = encoder(input_ids)
32
- src_len = batch.lengths[0]
33
- target_len = max(4, int(src_len * compression))
34
- generated = decoder.greedy_decode(
35
- memory,
36
- max_len=min(self.preprocessor.max_length, target_len),
37
- start_token_id=tokenizer.bos_id,
38
- end_token_id=tokenizer.eos_id,
39
- )
40
- summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
41
- return summary.strip(), None if not collect_attn else {}
 
src/inference/factory.py ADDED
@@ -0,0 +1,75 @@
1
+ """Helpers to assemble an inference pipeline from saved artifacts."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+ from typing import Tuple
6
+
7
+ import torch
8
+
9
+ from ..data.tokenization import Tokenizer, TokenizerConfig
10
+ from ..models.factory import ModelConfig, build_multitask_model, load_model_config
11
+ from ..utils.io import load_state
12
+ from ..utils.labels import LabelMetadata, load_label_metadata
13
+ from .pipeline import InferenceConfig, InferencePipeline
14
+
15
+
16
+ def create_inference_pipeline(
17
+ checkpoint_path: str | Path,
18
+ labels_path: str | Path,
19
+ *,
20
+ tokenizer_config: TokenizerConfig | None = None,
21
+ tokenizer_dir: str | Path | None = None,
22
+ model_config_path: str | Path | None = None,
23
+ device: str | torch.device = "cpu",
24
+ summary_max_length: int | None = None,
25
+ ) -> Tuple[InferencePipeline, LabelMetadata]:
26
+ """Build an :class:`InferencePipeline` from saved model and label metadata."""
27
+
28
+ checkpoint = Path(checkpoint_path)
29
+ if not checkpoint.exists():
30
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
31
+
32
+ labels = load_label_metadata(labels_path)
33
+
34
+ resolved_tokenizer_config = tokenizer_config
35
+ if resolved_tokenizer_config is None:
36
+ default_dir = Path(__file__).resolve().parent.parent.parent / "artifacts" / "hf_tokenizer"
37
+ chosen_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
38
+ local_tokenizer_dir = chosen_dir
39
+ if local_tokenizer_dir.exists():
40
+ resolved_tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_tokenizer_dir))
41
+ else:
42
+ raise ValueError(
43
+ "No tokenizer configuration provided and default tokenizer directory "
44
+ f"'{local_tokenizer_dir}' not found. Please provide tokenizer_config parameter or set tokenizer_dir."
45
+ )
46
+
47
+ tokenizer = Tokenizer(resolved_tokenizer_config)
48
+ model_config = load_model_config(model_config_path)
49
+ model = build_multitask_model(
50
+ tokenizer,
51
+ num_emotions=labels.emotion_size,
52
+ num_topics=labels.topic_size,
53
+ config=model_config,
54
+ )
55
+ load_state(model, str(checkpoint))
56
+
57
+ if isinstance(device, torch.device):
58
+ device_str = str(device)
59
+ else:
60
+ device_str = device
61
+
62
+ if summary_max_length is not None:
63
+ pipeline_config = InferenceConfig(summary_max_length=summary_max_length, device=device_str)
64
+ else:
65
+ pipeline_config = InferenceConfig(device=device_str)
66
+
67
+ pipeline = InferencePipeline(
68
+ model=model,
69
+ tokenizer=tokenizer,
70
+ config=pipeline_config,
71
+ emotion_labels=labels.emotion,
72
+ topic_labels=labels.topic,
73
+ device=device,
74
+ )
75
+ return pipeline, labels
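A hedged usage sketch for the factory; the checkpoint and label paths below are placeholders, not files shipped with the repo:

```python
import torch

from src.inference import create_inference_pipeline

# Placeholder artifact paths - substitute the outputs of your own training run.
pipeline, labels = create_inference_pipeline(
    checkpoint_path="artifacts/checkpoints/best.pt",
    labels_path="artifacts/labels.json",
    device="cuda" if torch.cuda.is_available() else "cpu",
    summary_max_length=96,
)
print(labels.emotion_size, labels.topic_size)
```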
src/inference/generation.py ADDED
@@ -0,0 +1,14 @@
1
+ """Generation helpers."""
2
+
3
+ import torch
4
+
5
+
6
+ def greedy_decode(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
7
+ """Run greedy decoding with ``model.generate`` and return generated token ids."""
8
+
9
+ return model.generate(
10
+ input_ids,
11
+ max_length=max_length,
12
+ do_sample=False,
13
+ num_beams=1,
14
+ )
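Since the helper only forwards to a HuggingFace-style `generate`, any seq2seq model exposing that method can exercise it; a sketch using `facebook/bart-base` purely as an example:

```python
# Illustrative only: greedy_decode works with any model exposing a HF-style .generate().
from transformers import AutoTokenizer, BartForConditionalGeneration

from src.inference.generation import greedy_decode

tok = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tok(["The quick brown fox jumps over the lazy dog."], return_tensors="pt")
ids = greedy_decode(model, inputs["input_ids"], max_length=20)
print(tok.batch_decode(ids, skip_special_tokens=True))
```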
src/inference/pipeline.py ADDED
@@ -0,0 +1,166 @@
1
+ """Inference helpers for multitask LexiMind models."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, fields, replace
5
+ from typing import Iterable, List, Sequence
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from ..data.preprocessing import Batch, TextPreprocessor
11
+ from ..data.tokenization import Tokenizer
12
+
13
+
14
+ @dataclass(slots=True)
15
+ class InferenceConfig:
16
+ """Configuration knobs for the inference pipeline."""
17
+
18
+ summary_max_length: int = 128
19
+ emotion_threshold: float = 0.5
20
+ device: str | None = None
21
+
22
+
23
+ @dataclass(slots=True)
24
+ class EmotionPrediction:
25
+ labels: List[str]
26
+ scores: List[float]
27
+
28
+
29
+ @dataclass(slots=True)
30
+ class TopicPrediction:
31
+ label: str
32
+ confidence: float
33
+
34
+
35
+ class InferencePipeline:
36
+ """Run summarization, emotion, and topic heads through a unified interface."""
37
+
38
+ def __init__(
39
+ self,
40
+ model: torch.nn.Module,
41
+ tokenizer: Tokenizer,
42
+ *,
43
+ preprocessor: TextPreprocessor | None = None,
44
+ emotion_labels: Sequence[str] | None = None,
45
+ topic_labels: Sequence[str] | None = None,
46
+ config: InferenceConfig | None = None,
47
+ device: torch.device | str | None = None,
48
+ ) -> None:
49
+ self.model = model
50
+ self.tokenizer = tokenizer
51
+ self.config = config or InferenceConfig()
52
+ chosen_device = device or self.config.device
53
+ if chosen_device is None:
54
+ first_param = next(model.parameters(), None)
55
+ chosen_device = first_param.device if first_param is not None else "cpu"
56
+ self.device = torch.device(chosen_device)
57
+ self.model.to(self.device)
58
+ self.model.eval()
59
+
60
+ self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
61
+ self.emotion_labels = list(emotion_labels) if emotion_labels is not None else None
62
+ self.topic_labels = list(topic_labels) if topic_labels is not None else None
63
+
64
+ def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
65
+ if not texts:
66
+ return []
67
+ batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
68
+ src_ids = batch.input_ids
69
+ src_mask = batch.attention_mask
70
+ max_len = max_length or self.config.summary_max_length
71
+
72
+ if not hasattr(self.model, "encoder") or not hasattr(self.model, "decoder"):
73
+ raise RuntimeError("Model must expose encoder and decoder attributes for summarization.")
74
+
75
+ with torch.inference_mode():
76
+ encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
77
+ memory = self.model.encoder(src_ids, mask=encoder_mask)
78
+ generated = self.model.decoder.greedy_decode(
79
+ memory=memory,
80
+ max_len=max_len,
81
+ start_token_id=self.tokenizer.bos_token_id,
82
+ end_token_id=self.tokenizer.eos_token_id,
83
+ device=self.device,
84
+ )
85
+
86
+ return self.tokenizer.decode_batch(generated.tolist())
87
+
88
+ def predict_emotions(
89
+ self,
90
+ texts: Sequence[str],
91
+ *,
92
+ threshold: float | None = None,
93
+ ) -> List[EmotionPrediction]:
94
+ if not texts:
95
+ return []
96
+ if self.emotion_labels is None or not self.emotion_labels:
97
+ raise RuntimeError("emotion_labels must be provided to decode emotion predictions")
98
+
99
+ batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
100
+ model_inputs = self._batch_to_model_inputs(batch)
101
+ decision_threshold = threshold or self.config.emotion_threshold
102
+
103
+ with torch.inference_mode():
104
+ logits = self.model.forward("emotion", model_inputs)
105
+ probs = torch.sigmoid(logits)
106
+
107
+ predictions: List[EmotionPrediction] = []
108
+ for row in probs.cpu():
109
+ pairs = [
110
+ (label, score)
111
+ for label, score in zip(self.emotion_labels, row.tolist())
112
+ if score >= decision_threshold
113
+ ]
114
+ labels = [label for label, _ in pairs]
115
+ scores = [score for _, score in pairs]
116
+ predictions.append(EmotionPrediction(labels=labels, scores=scores))
117
+ return predictions
118
+
119
+ def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
120
+ if not texts:
121
+ return []
122
+ if self.topic_labels is None or not self.topic_labels:
123
+ raise RuntimeError("topic_labels must be provided to decode topic predictions")
124
+
125
+ batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
126
+ model_inputs = self._batch_to_model_inputs(batch)
127
+
128
+ with torch.inference_mode():
129
+ logits = self.model.forward("topic", model_inputs)
130
+ probs = F.softmax(logits, dim=-1)
131
+
132
+ results: List[TopicPrediction] = []
133
+ for row in probs.cpu():
134
+ scores = row.tolist()
135
+ best_index = int(row.argmax().item())
136
+ results.append(TopicPrediction(label=self.topic_labels[best_index], confidence=scores[best_index]))
137
+ return results
138
+
139
+ def batch_predict(self, texts: Iterable[str]) -> dict[str, object]:
140
+ text_list = list(texts)
141
+ if self.emotion_labels is None or not self.emotion_labels:
142
+ raise RuntimeError("emotion_labels must be provided for batch predictions")
143
+ if self.topic_labels is None or not self.topic_labels:
144
+ raise RuntimeError("topic_labels must be provided for batch predictions")
145
+ return {
146
+ "summaries": self.summarize(text_list),
147
+ "emotion": self.predict_emotions(text_list),
148
+ "topic": self.predict_topics(text_list),
149
+ }
150
+
151
+ def _batch_to_device(self, batch: Batch) -> Batch:
152
+ tensor_updates: dict[str, torch.Tensor] = {}
153
+ for item in fields(batch):
154
+ value = getattr(batch, item.name)
155
+ if torch.is_tensor(value):
156
+ tensor_updates[item.name] = value.to(self.device)
157
+ if not tensor_updates:
158
+ return batch
159
+ return replace(batch, **tensor_updates)
160
+
161
+ @staticmethod
162
+ def _batch_to_model_inputs(batch: Batch) -> dict[str, torch.Tensor]:
163
+ inputs: dict[str, torch.Tensor] = {"input_ids": batch.input_ids}
164
+ if batch.attention_mask is not None:
165
+ inputs["attention_mask"] = batch.attention_mask
166
+ return inputs
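The emotion head's multi-label decision rule is just sigmoid-plus-threshold; a standalone sketch of that step, with an invented label set and logits:

```python
import torch

# Mirrors the thresholding inside predict_emotions: keep labels whose sigmoid score clears the cut-off.
emotion_labels = ["joy", "sadness", "anger", "fear"]   # example label set, not the trained one
logits = torch.tensor([[2.0, -1.5, 0.3, -3.0]])
probs = torch.sigmoid(logits)

threshold = 0.5
kept = [(label, round(float(p), 3)) for label, p in zip(emotion_labels, probs[0].tolist()) if p >= threshold]
print(kept)  # [('joy', 0.881), ('anger', 0.574)]
```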
src/inference/postprocessing.py ADDED
@@ -0,0 +1,6 @@
1
+ """Output cleaning helpers."""
2
+ from typing import List
3
+
4
+
5
+ def strip_whitespace(texts: List[str]) -> List[str]:
6
+ return [text.strip() for text in texts]
src/models/factory.py ADDED
@@ -0,0 +1,105 @@
1
+ """Factory helpers to assemble multitask models for inference/training."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from ..data.tokenization import Tokenizer
9
+ from ..utils.config import load_yaml
10
+ from .decoder import TransformerDecoder
11
+ from .encoder import TransformerEncoder
12
+ from .heads import ClassificationHead, LMHead
13
+ from .multitask import MultiTaskModel
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class ModelConfig:
18
+ """Configuration describing the transformer architecture."""
19
+
20
+ d_model: int = 512
21
+ num_encoder_layers: int = 6
22
+ num_decoder_layers: int = 6
23
+ num_attention_heads: int = 8
24
+ ffn_dim: int = 2048
25
+ dropout: float = 0.1
26
+
27
+ def __post_init__(self):
28
+ if self.d_model % self.num_attention_heads != 0:
29
+ raise ValueError(
30
+ f"d_model ({self.d_model}) must be divisible by num_attention_heads ({self.num_attention_heads})"
31
+ )
32
+ if not 0 <= self.dropout <= 1:
33
+ raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
34
+ if self.d_model <= 0 or self.num_encoder_layers <= 0 or self.num_decoder_layers <= 0:
35
+ raise ValueError("Model dimensions must be positive")
36
+ if self.num_attention_heads <= 0 or self.ffn_dim <= 0:
37
+ raise ValueError("Model dimensions must be positive")
38
+
39
+
40
+ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
41
+ """Load a model configuration from YAML with sane defaults."""
42
+
43
+ if path is None:
44
+ return ModelConfig()
45
+
46
+ data = load_yaml(str(path)).data
47
+ return ModelConfig(
48
+ d_model=int(data.get("d_model", 512)),
49
+ num_encoder_layers=int(data.get("num_encoder_layers", 6)),
50
+ num_decoder_layers=int(data.get("num_decoder_layers", 6)),
51
+ num_attention_heads=int(data.get("num_attention_heads", 8)),
52
+ ffn_dim=int(data.get("ffn_dim", 2048)),
53
+ dropout=float(data.get("dropout", 0.1)),
54
+ )
55
+
56
+
57
+ def build_multitask_model(
58
+ tokenizer: Tokenizer,
59
+ *,
60
+ num_emotions: int,
61
+ num_topics: int,
62
+ config: ModelConfig | None = None,
63
+ ) -> MultiTaskModel:
64
+ """Construct the multitask transformer with heads for the three tasks."""
65
+
66
+ cfg = config or ModelConfig()
67
+ if not isinstance(num_emotions, int) or num_emotions <= 0:
68
+ raise ValueError("num_emotions must be a positive integer")
69
+ if not isinstance(num_topics, int) or num_topics <= 0:
70
+ raise ValueError("num_topics must be a positive integer")
71
+ encoder = TransformerEncoder(
72
+ vocab_size=tokenizer.vocab_size,
73
+ d_model=cfg.d_model,
74
+ num_layers=cfg.num_encoder_layers,
75
+ num_heads=cfg.num_attention_heads,
76
+ d_ff=cfg.ffn_dim,
77
+ dropout=cfg.dropout,
78
+ max_len=tokenizer.config.max_length,
79
+ pad_token_id=tokenizer.pad_token_id,
80
+ )
81
+ decoder = TransformerDecoder(
82
+ vocab_size=tokenizer.vocab_size,
83
+ d_model=cfg.d_model,
84
+ num_layers=cfg.num_decoder_layers,
85
+ num_heads=cfg.num_attention_heads,
86
+ d_ff=cfg.ffn_dim,
87
+ dropout=cfg.dropout,
88
+ max_len=tokenizer.config.max_length,
89
+ pad_token_id=tokenizer.pad_token_id,
90
+ )
91
+
92
+ model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
93
+ model.add_head(
94
+ "summarization",
95
+ LMHead(d_model=cfg.d_model, vocab_size=tokenizer.vocab_size, tie_embedding=decoder.embedding),
96
+ )
97
+ model.add_head(
98
+ "emotion",
99
+ ClassificationHead(d_model=cfg.d_model, num_labels=num_emotions, pooler="mean", dropout=cfg.dropout),
100
+ )
101
+ model.add_head(
102
+ "topic",
103
+ ClassificationHead(d_model=cfg.d_model, num_labels=num_topics, pooler="mean", dropout=cfg.dropout),
104
+ )
105
+ return model
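A small sketch of the builder for local experimentation; the reduced dimensions and head counts (6 emotions, 4 topics) are arbitrary choices, not project defaults:

```python
from src.data.tokenization import Tokenizer, TokenizerConfig
from src.models.factory import ModelConfig, build_multitask_model

# Shrunken config so the model builds quickly on CPU; real runs use the YAML configs.
tokenizer = Tokenizer(TokenizerConfig(max_length=128))
config = ModelConfig(d_model=256, num_encoder_layers=2, num_decoder_layers=2,
                     num_attention_heads=4, ffn_dim=512)
model = build_multitask_model(tokenizer, num_emotions=6, num_topics=4, config=config)
print(sorted(model.heads))  # ['emotion', 'summarization', 'topic']
```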
src/models/multitask.py CHANGED
@@ -39,17 +39,28 @@ class MultiTaskModel(nn.Module):
39
  mt = MultiTaskModel(encoder=enc, decoder=dec)
40
  mt.add_head("summarize", LMHead(...))
41
  logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
 
 
 
 
 
 
42
  """
43
 
44
  def __init__(
45
  self,
46
  encoder: Optional[TransformerEncoder] = None,
47
  decoder: Optional[TransformerDecoder] = None,
 
 
48
  ):
49
  super().__init__()
50
  self.encoder = encoder
51
  self.decoder = decoder
52
  self.heads: Dict[str, nn.Module] = {}
 
 
 
53
 
54
  def add_head(self, name: str, module: nn.Module) -> None:
55
  """Register a head under a task name."""
@@ -99,9 +110,15 @@ class MultiTaskModel(nn.Module):
99
  raise RuntimeError("Encoder is required for encoder-side heads")
100
  # accept either input_ids or embeddings
101
  if "input_ids" in inputs:
102
- enc_out = self.encoder(inputs["input_ids"])
 
 
 
103
  elif "embeddings" in inputs:
104
- enc_out = self.encoder(inputs["embeddings"])
 
 
 
105
  else:
106
  raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
107
  logits = head(enc_out)
@@ -120,10 +137,20 @@ class MultiTaskModel(nn.Module):
120
  raise RuntimeError("Both encoder and decoder are required for LM-style heads")
121
 
122
  # Build encoder memory
 
 
 
 
 
 
 
 
 
 
123
  if "src_ids" in inputs:
124
- memory = self.encoder(inputs["src_ids"])
125
  elif "src_embeddings" in inputs:
126
- memory = self.encoder(inputs["src_embeddings"])
127
  else:
128
  raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")
129
 
@@ -137,12 +164,13 @@ class MultiTaskModel(nn.Module):
137
  # Here we don't attempt to generate when labels not provided.
138
  raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")
139
 
140
- # Run decoder. Decoder returns logits shaped (B, T, vocab) in this codebase.
141
  decoder_out = self.decoder(decoder_inputs, memory)
142
 
143
- # If decoder already returned logits matching the head vocab size, use them directly.
144
- # Otherwise, assume decoder returned hidden states and let the head project them.
145
- if isinstance(decoder_out, torch.Tensor) and decoder_out.shape[-1] == head.vocab_size:
 
 
146
  logits = decoder_out
147
  else:
148
  logits = head(decoder_out)
@@ -195,4 +223,15 @@ class MultiTaskModel(nn.Module):
195
  return F.cross_entropy(logits, labels.long())
196
 
197
  # If we can't determine, raise
198
- raise RuntimeError("Cannot compute loss for unknown head type")
 
39
  mt = MultiTaskModel(encoder=enc, decoder=dec)
40
  mt.add_head("summarize", LMHead(...))
41
  logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
42
+
43
+ Args:
44
+ encoder: optional encoder backbone.
45
+ decoder: optional decoder backbone.
46
+ decoder_outputs_logits: set True when ``decoder.forward`` already returns vocabulary logits;
47
+ set False if the decoder produces hidden states that must be projected by the LM head.
48
  """
49
 
50
  def __init__(
51
  self,
52
  encoder: Optional[TransformerEncoder] = None,
53
  decoder: Optional[TransformerDecoder] = None,
54
+ *,
55
+ decoder_outputs_logits: bool = True,
56
  ):
57
  super().__init__()
58
  self.encoder = encoder
59
  self.decoder = decoder
60
  self.heads: Dict[str, nn.Module] = {}
61
+ # When True, decoder.forward(...) is expected to return logits already projected to the vocabulary space.
62
+ # When False, decoder outputs hidden states that must be passed through the registered LM head.
63
+ self.decoder_outputs_logits = decoder_outputs_logits
64
 
65
  def add_head(self, name: str, module: nn.Module) -> None:
66
  """Register a head under a task name."""
 
110
  raise RuntimeError("Encoder is required for encoder-side heads")
111
  # accept either input_ids or embeddings
112
  if "input_ids" in inputs:
113
+ encoder_mask = None
114
+ if "attention_mask" in inputs:
115
+ encoder_mask = self._expand_attention_mask(inputs["attention_mask"], inputs["input_ids"].device)
116
+ enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask)
117
  elif "embeddings" in inputs:
118
+ encoder_mask = inputs.get("attention_mask")
119
+ if encoder_mask is not None:
120
+ encoder_mask = self._expand_attention_mask(encoder_mask, inputs["embeddings"].device)
121
+ enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask)
122
  else:
123
  raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
124
  logits = head(enc_out)
 
137
  raise RuntimeError("Both encoder and decoder are required for LM-style heads")
138
 
139
  # Build encoder memory
140
+ src_mask = inputs.get("src_mask")
141
+ if src_mask is None:
142
+ src_mask = inputs.get("attention_mask")
143
+ encoder_mask = None
144
+ reference_tensor = inputs.get("src_ids")
145
+ if reference_tensor is None:
146
+ reference_tensor = inputs.get("src_embeddings")
147
+ if src_mask is not None and reference_tensor is not None:
148
+ encoder_mask = self._expand_attention_mask(src_mask, reference_tensor.device)
149
+
150
  if "src_ids" in inputs:
151
+ memory = self.encoder(inputs["src_ids"], mask=encoder_mask)
152
  elif "src_embeddings" in inputs:
153
+ memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask)
154
  else:
155
  raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")
156
 
 
164
  # Here we don't attempt to generate when labels not provided.
165
  raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")
166
 
 
167
  decoder_out = self.decoder(decoder_inputs, memory)
168
 
169
+ if self.decoder_outputs_logits:
170
+ if not isinstance(decoder_out, torch.Tensor):
171
+ raise TypeError(
172
+ "Decoder is configured to return logits, but forward returned a non-tensor value."
173
+ )
174
  logits = decoder_out
175
  else:
176
  logits = head(decoder_out)
 
223
  return F.cross_entropy(logits, labels.long())
224
 
225
  # If we can't determine, raise
226
+ raise RuntimeError("Cannot compute loss for unknown head type")
227
+
228
+ @staticmethod
229
+ def _expand_attention_mask(mask: torch.Tensor, device: torch.device) -> torch.Tensor:
230
+ if mask is None:
231
+ return None # type: ignore[return-value]
232
+ bool_mask = mask.to(device=device, dtype=torch.bool)
233
+ if bool_mask.dim() == 2:
234
+ return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2)
235
+ if bool_mask.dim() in (3, 4):
236
+ return bool_mask
237
+ raise ValueError("Attention mask must be 2D, 3D, or 4D tensor")
src/training/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """Training utilities for LexiMind."""
src/training/callbacks.py ADDED
@@ -0,0 +1,37 @@
1
+ """Callback hooks for training."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional
5
+
6
+ import torch
7
+
8
+
9
+ def save_checkpoint(
10
+ model: torch.nn.Module,
11
+ optimizer: torch.optim.Optimizer,
12
+ epoch: int,
13
+ output_path: str,
14
+ *,
15
+ metrics: Optional[Dict[str, Any]] = None,
16
+ ) -> None:
17
+ """Persist model and optimizer state for resuming training."""
18
+
19
+ checkpoint = {
20
+ "model_state_dict": model.state_dict(),
21
+ "optimizer_state_dict": optimizer.state_dict(),
22
+ "epoch": int(epoch),
23
+ }
24
+ if metrics:
25
+ checkpoint["metrics"] = metrics
26
+
27
+ target = Path(output_path)
28
+ target.parent.mkdir(parents=True, exist_ok=True)
29
+ temp_path = target.parent / f"{target.name}.tmp"
30
+ try:
31
+ torch.save(checkpoint, temp_path)
32
+ temp_path.replace(target)
33
+ except Exception:
34
+ raise
35
+ finally:
36
+ if temp_path.exists():
37
+ temp_path.unlink(missing_ok=True)
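A minimal checkpointing sketch with a toy model; the output path is a placeholder:

```python
import torch

from src.training.callbacks import save_checkpoint

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

save_checkpoint(model, optimizer, epoch=3, output_path="artifacts/checkpoints/epoch_3.pt",
                metrics={"val_loss": 0.42})
state = torch.load("artifacts/checkpoints/epoch_3.pt", map_location="cpu")
print(state["epoch"], state["metrics"])
```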
src/training/losses.py ADDED
@@ -0,0 +1,13 @@
1
+ """Loss helpers."""
2
+ import torch
3
+
4
+
5
+ def multitask_loss(losses: dict[str, torch.Tensor]) -> torch.Tensor:
6
+ iterator = iter(losses.values())
7
+ try:
8
+ total = next(iterator).clone()
9
+ except StopIteration:
10
+ raise ValueError("losses is empty")
11
+ for value in iterator:
12
+ total = total + value
13
+ return total / len(losses)
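The combined objective is simply the mean of the per-task losses; a toy check:

```python
import torch

from src.training.losses import multitask_loss

losses = {
    "summarization": torch.tensor(2.0),
    "emotion": torch.tensor(0.5),
    "topic": torch.tensor(1.0),
}
print(multitask_loss(losses))  # tensor(1.1667)
```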
src/training/metrics.py ADDED
@@ -0,0 +1,36 @@
1
+ """Metric helpers used during training and evaluation."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Sequence
5
+
6
+ import torch
7
+
8
+
9
+ def accuracy(predictions: Sequence[int], targets: Sequence[int]) -> float:
10
+ matches = sum(int(pred == target) for pred, target in zip(predictions, targets))
11
+ return matches / max(1, len(predictions))
12
+
13
+
14
+ def multilabel_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
15
+ preds = predictions.float()
16
+ gold = targets.float()
17
+ true_positive = (preds * gold).sum(dim=1)
18
+ precision = true_positive / (preds.sum(dim=1).clamp(min=1.0))
19
+ recall = true_positive / (gold.sum(dim=1).clamp(min=1.0))
20
+ f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
21
+ return float(f1.mean().item())
22
+
23
+
24
+ def rouge_like(predictions: Sequence[str], references: Sequence[str]) -> float:
25
+ if not predictions or not references:
26
+ return 0.0
27
+ scores = []
28
+ for pred, ref in zip(predictions, references):
29
+ pred_tokens = pred.split()
30
+ ref_tokens = ref.split()
31
+ if not ref_tokens:
32
+ scores.append(0.0)
33
+ continue
34
+ overlap = len(set(pred_tokens) & set(ref_tokens))
35
+ scores.append(overlap / len(ref_tokens))
36
+ return sum(scores) / len(scores)
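A few toy calls showing what the metric helpers expect and return; the inputs are invented:

```python
import torch

from src.training.metrics import accuracy, multilabel_f1, rouge_like

print(accuracy([1, 0, 2], [1, 0, 1]))                     # 0.666...
preds = torch.tensor([[1, 0, 1], [0, 1, 0]])
gold = torch.tensor([[1, 0, 0], [0, 1, 0]])
print(multilabel_f1(preds, gold))                         # ~0.83, averaged per-sample F1
print(rouge_like(["the cat sat"], ["the cat sat down"]))  # 0.75 unigram recall
```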