Commit 1fbc47b · Parent: f9edbb4
chore: snapshot current refinements
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
- README.md +45 -153
- configs/data/datasets.yaml +26 -0
- configs/model/base.yaml +6 -50
- configs/model/large.yaml +6 -0
- configs/model/small.yaml +6 -23
- configs/training/default.yaml +12 -0
- configs/training/full.yaml +12 -0
- configs/training/quick_test.yaml +9 -0
- docker/Dockerfile +0 -0
- docker/docker-compose.yml +0 -0
- docs/api.md +79 -0
- docs/architecture.md +57 -0
- docs/training.md +59 -0
- pyproject.toml +6 -11
- requirements.txt +6 -16
- scripts/download_data.py +182 -0
- scripts/download_data.sh +5 -0
- scripts/evaluate.py +134 -0
- scripts/export_model.py +69 -0
- scripts/inference.py +112 -0
- scripts/preprocess_data.py +321 -0
- scripts/test_gpu.py +0 -27
- scripts/train.py +217 -6
- setup.py +16 -6
- src/__init__.py +1 -0
- src/api/__init__.py +1 -0
- src/api/app.py +10 -0
- src/api/dependencies.py +42 -0
- src/api/inference/__init__.py +0 -7
- src/api/inference/inference.py +0 -133
- src/api/routes.py +34 -0
- src/api/schemas.py +14 -0
- src/data/__init__.py +1 -0
- src/data/dataloader.py +117 -0
- src/data/dataset.py +229 -0
- src/data/download.py +39 -60
- src/data/preprocessing.py +95 -225
- src/data/tokenization.py +122 -0
- src/inference/__init__.py +10 -5
- src/inference/baseline_summarizer.py +0 -41
- src/inference/factory.py +75 -0
- src/inference/generation.py +14 -0
- src/inference/pipeline.py +166 -0
- src/inference/postprocessing.py +6 -0
- src/models/factory.py +105 -0
- src/models/multitask.py +48 -9
- src/training/__init__.py +1 -0
- src/training/callbacks.py +37 -0
- src/training/losses.py +13 -0
- src/training/metrics.py +36 -0
README.md
CHANGED

@@ -1,175 +1,67 @@

Removed:

# LexiMind

##

- Custom encoder-decoder Transformer architecture with shared representations
- Multi-task loss function with learnable task weighting
- Attention weight visualization for model interpretability
- Interactive web interface for real-time inference
- Trained on diverse corpora: news articles (CNN/DailyMail, BBC) and literary texts (Project Gutenberg)

## 🏗️ Architecture

```
Input Text
    ↓
┌─────────────────────┐
│   Shared Encoder    │ ← TransformerEncoder (6 layers)
│  (Multi-head Attn)  │
└─────────────────────┘
    ↓        ↓        ↓
    │        │        └──────────────┐
    │        │                       │
    │        └─────────┐             │
    │                  │             │
    ↓                  ↓             ↓
┌─────────┐      ┌────────┐    ┌─────────┐
│ Decoder │      │Classify│    │ Project │
│  Head   │      │  Head  │    │  Head   │
└─────────┘      └────────┘    └─────────┘
    ↓                  ↓             ↓
 Summary           Emotions     Embeddings
                               (for clustering)
```

## 📊 Datasets

- **CNN/DailyMail**: 300k+ news articles with human-written summaries
- **BBC News**: 2,225 articles across 5 categories
- **Project Gutenberg**: Classic literature for long-form text analysis

## 🚀 Quick Start

### Installation
```bash
git clone https://github.com/OliverPerrin/LexiMind.git
cd LexiMind
pip install -r requirements.txt
```

```bash
python src/train.py --config configs/default.yaml
```

### Launch Interface
```bash
python src/app.py
```

```
├── ...
│   │   └── clustering.py        # Projection head for embeddings
│   ├── data/
│   │   ├── download_datasets.py # Data acquisition
│   │   ├── preprocessing.py     # Text cleaning & tokenization
│   │   └── dataset.py           # PyTorch Dataset classes
│   ├── training/
│   │   ├── train.py             # Training loop
│   │   ├── losses.py            # Multi-task loss functions
│   │   └── metrics.py           # ROUGE, F1, silhouette scores
│   ├── inference/
│   │   └── pipeline.py          # End-to-end inference
│   ├── visualization/
│   │   └── attention.py         # Attention heatmap generation
│   └── app.py                   # Gradio/FastAPI interface
├── configs/
│   └── default.yaml             # Model & training hyperparameters
├── tests/
│   └── test_*.py                # Unit tests
├── notebooks/
│   └── exploratory.ipynb        # Data exploration & analysis
├── requirements.txt
└── README.md
```

| Task | Metric | Score |
|------|--------|-------|
| Emotion Classification | Macro F1 | TBD |
| Topic Clustering | Silhouette Score | TBD |

- **Decoder**: 6-layer autoregressive Transformer
- **Vocab Size**: 32,000 (SentencePiece tokenizer)
- **Parameters**: ~60M total

### Training
- **Optimizer**: AdamW (lr=1e-4, weight_decay=0.01)
- **Scheduler**: Linear warmup (5000 steps) + cosine decay
- **Loss**: Weighted sum of cross-entropy (summarization), BCE (emotions), triplet loss (clustering)
- **Hardware**: Trained on single NVIDIA RTX 3090 (24GB VRAM)
- **Time**: ~48 hours for 10 epochs

### Multi-Task Learning Strategy
Uses uncertainty weighting ([Kendall et al., 2018](https://arxiv.org/abs/1705.07115)) to automatically balance task losses:

```
L_total = Σ_i ( L_i / (2σ_i²) + log σ_i )
```

## 🎨 Interface Preview

The web interface provides:
- Text input with real-time token count
- Compression level slider (20%-80%)
- Side-by-side original/summary comparison
- Emotion probability bars with color coding
- Interactive attention heatmap (click tokens to highlight attention)
- Downloadable results (JSON/CSV)

## 📈 Future Enhancements

- [ ] Add multilingual support (mBART)
- [ ] Implement beam search for better summaries
- [ ] Fine-tune on domain-specific corpora (medical, legal)
- [ ] Add semantic search across document embeddings
- [ ] Deploy as REST API with Docker
- [ ] Implement model distillation for mobile deployment

## 📚 References

- Vaswani et al. (2017) - [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- Lewis et al. (2019) - [BART: Denoising Sequence-to-Sequence Pre-training](https://arxiv.org/abs/1910.13461)
- Caruana (1997) - [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734)
- Demszky et al. (2020) - [GoEmotions Dataset](https://arxiv.org/abs/2005.00547)

## 📄 License

GNU General Public License v3.0

## 👤 Author

**Oliver Perrin**
- Portfolio: [oliverperrin.com](https://oliverperrin.com)
- LinkedIn: [linkedin.com/in/oliverperrin](https://linkedin.com/in/oliverperrin)
- Email: [email protected]

Added:

# LexiMind (Inference Edition)

LexiMind now ships as a focused inference sandbox for the custom multitask Transformer found in `src/models`. Training, dataset downloaders, and legacy scripts have been removed so it is easy to load a checkpoint, run the Streamlit demo, and experiment with summarization, emotion classification, and topic cues on your own text.

## What Stays
- Transformer encoder/decoder and task heads under `src/models`
- Unit tests for the model stack (`tests/test_models`)
- Streamlit UI (`src/ui/streamlit_app.py`) wired to the inference helpers in `src/api/inference`

## What Changed
- Hugging Face tokenizers provide all tokenization (see `TextPreprocessor`)
- Training, dataset downloaders, and CLI scripts have been removed
- Scikit-learn powers light text normalization (stop-word removal optional)
- Requirements trimmed to inference-only dependencies

## Quick Start
```bash
git clone https://github.com/OliverPerrin/LexiMind.git
cd LexiMind
pip install -r requirements.txt

# Optional extras via setup.py packaging metadata
pip install .[web]  # installs streamlit + plotly
pip install .[api]  # installs fastapi
pip install .[all]  # installs both groups

streamlit run src/ui/streamlit_app.py
```

Configure the Streamlit app via the sidebar to point at your tokenizer directory and model checkpoint (defaults assume `artifacts/hf_tokenizer` and `checkpoints/best.pt`).

## Minimal Project Map
```
src/
├── api/        # load_models + helpers
├── data/       # TextPreprocessor using Hugging Face + sklearn
├── inference/  # thin summarizer facade
├── models/     # core Transformer architecture (untouched)
└── ui/         # Streamlit interface
```

Everything outside `src/` now holds optional assets such as checkpoints, tokenizer exports, and documentation stubs.

## Loading a Checkpoint Programmatically
```python
from src.api.inference import load_models, summarize_text

models = load_models({
    "checkpoint_path": "checkpoints/best.pt",
    "tokenizer_path": "artifacts/hf_tokenizer",
    "hf_tokenizer_name": "facebook/bart-base",
})

summary, _ = summarize_text("Paste any article here.", models=models)
print(summary)
```

## License
GPL-3.0

## Author
Oliver Perrin · [email protected]
configs/data/datasets.yaml
CHANGED

@@ -0,0 +1,26 @@
```yaml
raw:
  summarization: data/raw/summarization/cnn_dailymail
  emotion: data/raw/emotion
  topic: data/raw/topic
  books: data/raw/books
processed:
  summarization: data/processed/summarization
  emotion: data/processed/emotion
  topic: data/processed/topic
  books: data/processed/books
tokenizer:
  pretrained_model_name: facebook/bart-base
  max_length: 512
  lower: false
downloads:
  summarization:
    dataset: gowrishankarp/newspaper-text-summarization-cnn-dailymail
    output: data/raw/summarization/cnn_dailymail
  books:
    - name: pride_and_prejudice
      url: https://www.gutenberg.org/cache/epub/1342/pg1342.txt
      output: data/raw/books/pride_and_prejudice.txt
  emotion:
    dataset: dair-ai/emotion
  topic:
    dataset: ag_news
```
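For context, these paths are consumed through the repo's `load_yaml` helper, the same call `scripts/download_data.py` makes further down this diff. A minimal sketch, assuming only that `load_yaml(...).data` returns this mapping as a plain dict (the fallback defaults mirror those in the download script):

```python
# Hedged sketch: resolve dataset paths and tokenizer settings from
# configs/data/datasets.yaml via the repo's load_yaml helper.
from pathlib import Path

from src.utils.config import load_yaml

cfg = load_yaml("configs/data/datasets.yaml").data
raw_paths = cfg.get("raw", {})
tokenizer_cfg = cfg.get("tokenizer", {})

# Fallbacks follow the defaults used in scripts/download_data.py.
emotion_dir = Path(raw_paths.get("emotion", "data/raw/emotion"))
max_length = int(tokenizer_cfg.get("max_length", 512))
print(emotion_dir, max_length)
```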
configs/model/base.yaml
CHANGED

@@ -1,50 +1,6 @@

Removed (the first six lines of the old file did not survive this view):
```yaml
d_ff: 2048
dropout: 0.1
max_seq_length: 512

tasks:
  summarization:
    enabled: true
    decoder_layers: 6
  emotion:
    enabled: true
    num_classes: 27
    pool_strategy: "mean"  # Options: mean, max, cls, attention
  clustering:
    enabled: true
    embedding_dim: 128
    normalize: true

training:
  batch_size: 16
  gradient_accumulation_steps: 2  # Effective batch = 32
  learning_rate: 1e-4
  weight_decay: 0.01
  num_epochs: 10
  warmup_steps: 1000
  max_grad_norm: 1.0

  scheduler:
    type: "cosine"  # Options: linear, cosine, polynomial

  mixed_precision: true  # Use AMP for faster training

data:
  max_length: 512
  summary_max_length: 128
  train_split: 0.8
  val_split: 0.1
  test_split: 0.1

preprocessing:
  lowercase: true
  remove_stopwords: false
  min_token_length: 3
```

Added:
```yaml
d_model: 512
num_encoder_layers: 6
num_decoder_layers: 6
num_attention_heads: 8
ffn_dim: 2048
dropout: 0.1
```
configs/model/large.yaml
CHANGED

@@ -0,0 +1,6 @@
```yaml
d_model: 768
num_encoder_layers: 12
num_decoder_layers: 12
num_attention_heads: 12
ffn_dim: 3072
dropout: 0.1
```
configs/model/small.yaml
CHANGED

@@ -1,23 +1,6 @@

Removed (the first seven lines of the old file did not survive this view):
```yaml
training:
  batch_size: 32                  # ~4GB VRAM
  gradient_accumulation_steps: 1
  mixed_precision: true           # Essential!

# configs/model/base.yaml (production)
model:
  d_model: 512
  num_encoder_layers: 6
  num_decoder_layers: 6
  num_heads: 8

training:
  batch_size: 8                   # ~8GB VRAM
  gradient_accumulation_steps: 4  # Effective batch = 32
  mixed_precision: true
```

Added:
```yaml
d_model: 256
num_encoder_layers: 4
num_decoder_layers: 4
num_attention_heads: 4
ffn_dim: 1024
dropout: 0.1
```
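The three model configs (small/base/large) share the same flat schema of six keys. A minimal sketch of reading one into a typed structure, using plain PyYAML; the repo's own `src.models.factory.load_model_config` is the canonical loader, so the dataclass here is purely illustrative:

```python
# Hedged sketch: parse a flat model config such as configs/model/base.yaml.
# The dataclass fields mirror the YAML keys above; this is NOT the repo's
# load_model_config, just an illustration of the schema.
from dataclasses import dataclass

import yaml


@dataclass
class ModelConfig:
    d_model: int
    num_encoder_layers: int
    num_decoder_layers: int
    num_attention_heads: int
    ffn_dim: int
    dropout: float


with open("configs/model/base.yaml", "r", encoding="utf-8") as fh:
    cfg = ModelConfig(**yaml.safe_load(fh))

# The model dimension must split evenly across attention heads.
assert cfg.d_model % cfg.num_attention_heads == 0
```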
configs/training/default.yaml
CHANGED

@@ -0,0 +1,12 @@
```yaml
dataloader:
  batch_size: 8
  shuffle: true
optimizer:
  name: adamw
  lr: 3.0e-5
scheduler:
  name: cosine
  warmup_steps: 500
trainer:
  max_epochs: 5
  gradient_clip_norm: 1.0
```
configs/training/full.yaml
CHANGED

@@ -0,0 +1,12 @@
```yaml
dataloader:
  batch_size: 16
  shuffle: true
optimizer:
  name: adamw
  lr: 2.0e-5
scheduler:
  name: cosine
  warmup_steps: 1000
trainer:
  max_epochs: 15
  gradient_clip_norm: 1.0
```
configs/training/quick_test.yaml
CHANGED

@@ -0,0 +1,9 @@
```yaml
dataloader:
  batch_size: 2
  shuffle: false
optimizer:
  name: adamw
  lr: 1.0e-4
trainer:
  max_epochs: 1
  gradient_clip_norm: 0.5
```
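These training configs name an `adamw` optimizer and a `cosine` schedule with warmup. A minimal sketch of how such keys could be wired into torch objects; the mapping is an assumption (the repo's `Trainer` owns the real wiring), and `total_steps` here is a placeholder that would normally be derived from the dataloader length:

```python
# Hedged sketch: turn configs/training/default.yaml into an AdamW optimizer
# plus linear-warmup + cosine-decay schedule. Illustrative only.
import math

import torch
import yaml

with open("configs/training/default.yaml", "r", encoding="utf-8") as fh:
    cfg = yaml.safe_load(fh)

model = torch.nn.Linear(8, 8)  # stand-in for the multitask model
optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg["optimizer"]["lr"]))

warmup = int(cfg["scheduler"]["warmup_steps"])
total_steps = 10_000  # assumption; derive from len(dataloader) * max_epochs in practice


def lr_lambda(step: int) -> float:
    # Linear warmup, then cosine decay to zero.
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```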
docker/Dockerfile
DELETED (file without changes)

docker/docker-compose.yml
DELETED (file without changes)
docs/api.md
CHANGED

@@ -0,0 +1,79 @@

# API & CLI Documentation

## FastAPI Service
The FastAPI application is defined in `src/api/app.py` and wires routes from `src/api/routes.py`. All dependencies resolve through `src/api/dependencies.py`, which lazily constructs the shared inference pipeline.

### POST `/summarize`
- **Request Body** (`SummaryRequest`):

  ```json
  {
    "text": "Your input document"
  }
  ```
- **Response** (`SummaryResponse`):

  ```json
  {
    "summary": "Generated abstractive summary",
    "emotion_labels": ["joy", "surprise"],
    "emotion_scores": [0.91, 0.63],
    "topic": "news",
    "topic_confidence": 0.82
  }
  ```
- **Behaviour:**
  1. Text is preprocessed through `TextPreprocessor` (with an optional sklearn transformer if configured).
  2. The multitask model generates a summary via greedy decoding.
  3. Emotion and topic heads produce logits, which are converted to probabilities and mapped to human-readable labels using `artifacts/labels.json`.
  4. Results are returned as structured JSON suitable for a future Gradio interface.

### Error Handling
- If the checkpoint or label metadata is missing, the dependency raises an HTTP 503 error with an explanatory message.
- Validation errors (missing `text`) are handled automatically by FastAPI/Pydantic.

## Command-Line Interface
`scripts/inference.py` provides a CLI that mirrors the API behaviour.

### Usage
```bash
python scripts/inference.py "Document to analyse" \
  --checkpoint checkpoints/best.pt \
  --labels artifacts/labels.json \
  --tokenizer artifacts/hf_tokenizer \
  --model-config configs/model/base.yaml \
  --device cpu
```

Options:
- `text` – zero or more positional arguments. If omitted, use `--file` to point to a newline-delimited text file.
- `--file` – optional path containing one text per line.
- `--checkpoint` – path to the trained model weights.
- `--labels` – JSON containing emotion/topic vocabularies (defaults to `artifacts/labels.json`).
- `--tokenizer` – optional tokenizer directory; defaults to the exported artifact if present.
- `--model-config` – YAML describing the architecture.
- `--device` – `cpu` or `cuda`. Passing `cuda` attempts to run inference on GPU.
- `--summary-max-length` – overrides the default maximum generation length.

### Output
The CLI prints a JSON array where each entry contains the original text, summary, emotion labels with scores, and topic prediction. This format is identical to the REST response, facilitating integration tests and future Gradio UI rendering.

## Future Gradio UI
- The planned UI will call the same inference pipeline and display results interactively.
- Given the response schema, the UI can show:
  - Generated summary text.
  - Emotion chips with probability bars.
  - Topic confidence gauges.
  - Placeholder panel for attention heatmaps and explanations.
- Once implemented, documentation updates will add a `docs/ui.md` section and screenshots under `docs/images/`.

## Testing
- `tests/test_api/test_routes.py` stubs the pipeline to ensure response fields and dependency overrides behave as expected.
- `tests/test_inference/test_pipeline.py` validates pipeline methods end-to-end with dummy models, guaranteeing API and CLI consumers receive consistent payload shapes.
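A minimal client sketch for the `/summarize` endpoint documented above, using `requests` (not part of the trimmed requirements, so install it separately); the host and port are assumptions, while the request and response fields follow the `SummaryRequest`/`SummaryResponse` schemas:

```python
# Hedged sketch: call POST /summarize against a locally running FastAPI app.
import requests

resp = requests.post(
    "http://localhost:8000/summarize",  # assumed host/port
    json={"text": "Your input document"},
    timeout=30,
)
resp.raise_for_status()
payload = resp.json()

# Fields mirror SummaryResponse in the docs above.
print(payload["summary"])
print(list(zip(payload["emotion_labels"], payload["emotion_scores"])))
print(payload["topic"], payload["topic_confidence"])
```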
docs/architecture.md
CHANGED

@@ -0,0 +1,57 @@

# LexiMind Architecture

## Overview
LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers:

1. **Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn primitives and a Hugging Face tokenizer wrapper with deterministic batching helpers.
2. **Model Composition** – the bespoke encoder/decoder stack with task heads assembled via `MultiTaskModel`, plus `models.factory.build_multitask_model` to rebuild the network from configuration files.
3. **Inference & Serving** – a multi-task pipeline capable of summarization, emotion, and topic classification; surfaced through a CLI and FastAPI service with plans for a Gradio UI.

## Custom Transformer Stack
- `src/models/encoder.py` and `src/models/decoder.py` implement Pre-LayerNorm Transformer blocks with explicit positional encoding, masking logic, and incremental decoding support.
- `src/models/heads.py` provides modular output heads. Summarization uses an `LMHead` tied to the decoder embedding weights; emotion and topic tasks use `ClassificationHead` instances.
- `src/models/multitask.py` routes inputs to the correct head, computes task-specific losses, and exposes a single forward API used by the trainer and inference pipeline.
- `src/models/factory.py` rebuilds the encoder, decoder, and heads directly from YAML config and tokenizer metadata so inference rebuilds the exact architecture used in training.

## Data, Tokenization, and Preprocessing
- `src/data/tokenization.py` wraps `AutoTokenizer` to provide tensor-aware batching and helper utilities for decoder input shifting, BOS/EOS resolution, and vocab size retrieval.
- `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with optional scikit-learn transformers (via `sklearn_transformer`) before tokenization. This keeps the default cleaning minimal while allowing future reuse of `sklearn.preprocessing` utilities without changing calling code.
- `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and collators that encode inputs with the shared tokenizer and set up task-specific labels (multi-label emotions, categorical topics, seq2seq summaries).

## Training Pipeline
- `src/training/trainer.py` coordinates multi-task optimization with per-task loss functions, gradient clipping, and shared tokenizer decoding for metric computation.
- Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and a ROUGE-like overlap score for summarization. These metrics mirror the trainer outputs logged per task.
- Label vocabularies are serialized to `artifacts/labels.json` after training so inference can decode class indices consistently.

## Inference & Serving
- `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic. It expects label vocabularies from the serialized metadata file.
- `src/inference/factory.py` rebuilds the full pipeline by loading the tokenizer (preferring the exported tokenizer artifact), reconstructing the model via the factory helpers, restoring checkpoints, and injecting label metadata.
- The CLI (`scripts/inference.py`) drives the pipeline from the command line. The FastAPI app (`src/api/routes.py`) exposes the `/summarize` endpoint that returns summaries, emotion labels + scores, and topic predictions. Test coverage in `tests/test_inference` and `tests/test_api` validates both layers with lightweight stubs.

## Gradio UI Roadmap
- The inference pipeline returns structured outputs that are already suitable for a web UI.
- Planned steps for a Gradio demo:
  1. Wrap `InferencePipeline.batch_predict` inside Gradio callbacks for text input.
  2. Display summaries alongside emotion tag chips and topic confidence bars.
  3. Surface token-level attention visualizations by extending the pipeline to emit decoder attention maps (hooks already exist in the decoder).
- Documentation and code paths were structured to keep the Gradio integration isolated in a future `src/ui/gradio_app.py` module without altering core logic.

## Key Decisions
- **Custom Transformer Preservation** – all modeling remains on the bespoke encoder/decoder, satisfying the constraint to avoid Hugging Face model classes while still leveraging their tokenizer implementation.
- **Tokenizer Artifact Preference** – inference automatically favors the exported tokenizer in `artifacts/hf_tokenizer`, guaranteeing consistent vocabularies between training and serving.
- **Sklearn-friendly Preprocessing** – the text preprocessor now accepts an optional `TransformerMixin` so additional normalization (lemmatization, custom token filters, etc.) can be injected using familiar scikit-learn tooling without rewriting the batching code.
- **Documentation Alignment** – the `docs/` folder mirrors the structure requested, capturing design reasoning and paving the way for future diagrams in `docs/images`.
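To make the "Sklearn-friendly Preprocessing" decision concrete: the kind of `TransformerMixin` the docs describe could look like the sketch below. The sklearn protocol (`fit`/`transform`) is standard; the constructor keyword and the lowercasing rule are illustrative assumptions, not the repo's actual API:

```python
# Hedged sketch: a stateless sklearn transformer of the sort that could be
# injected into TextPreprocessor via its documented sklearn_transformer hook.
from sklearn.base import BaseEstimator, TransformerMixin


class LowercaseNormalizer(BaseEstimator, TransformerMixin):
    """Lowercase and strip each text; follows the sklearn fit/transform protocol."""

    def fit(self, X, y=None):
        return self  # nothing to learn; kept for protocol compatibility

    def transform(self, X):
        return [text.lower().strip() for text in X]


# Hypothetical wiring, mirroring the documented hook:
# preprocessor = TextPreprocessor(sklearn_transformer=LowercaseNormalizer())
```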
docs/training.md
CHANGED

@@ -0,0 +1,59 @@

# Training Procedure

## Data Sources
- **Summarization** – expects JSONL files with `source` and `summary` fields under `data/processed/summarization`.
- **Emotion Classification** – multi-label samples loaded from JSONL files with `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding.
- **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`.

Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`facebook/bart-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`).

## Dataloaders & Collators
- `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation.
- `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`.
- `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`.

These collators keep all tokenization centralized, reducing duplication and making it easy to swap in additional sklearn transformations through `TextPreprocessor` should we wish to extend cleaning or normalization.

## Model Assembly
- `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments.
- The model wraps:
  - Transformer encoder/decoder stacks with shared positional encodings.
  - LM head tied to decoder embeddings for summarization.
  - Mean-pooled classification heads for emotion and topic tasks.

## Optimisation Loop
- `src/training/trainer.Trainer` orchestrates multi-task training.
- Cross-entropy is used for summarization (seq2seq logits vs. shifted labels).
- `BCEWithLogitsLoss` handles multi-label emotions.
- `CrossEntropyLoss` handles topic classification.
- Gradient clipping ensures stability, and per-task weights can be configured via `TrainerConfig.task_weights` to balance gradients if needed.
- Metrics tracked per task:
  - **Summarization** – ROUGE-like overlap metric (`training.metrics.rouge_like`).
  - **Emotion** – micro F1 score for multi-label predictions.
  - **Topic** – categorical accuracy.

## Checkpoints & Artifacts
- `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`.
- `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after training. This file is required for inference so class indices map back to human-readable labels.
- The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies.

## Running Training
1. Ensure processed datasets are available (see `data/processed/` structure).
2. Choose a configuration (e.g., `configs/training/default.yaml`) for hyperparameters and data splits.
3. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders.
4. Use `build_multitask_model` to construct the model, create an optimizer, and run `Trainer.fit(train_loaders, val_loaders)`.
5. Save checkpoints and update `artifacts/labels.json` with the dataset label order.

> **Note:** A full CLI for training is forthcoming. The scripts in `scripts/` currently act as scaffolding; once the Gradio UI is introduced we will extend these utilities to launch training jobs with configuration files directly.

## Future Enhancements
- Integrate curriculum scheduling or task-balanced sampling once empirical results dictate.
- Capture attention maps during training to support visualization in the planned Gradio UI.
- Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it.
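The `-100` masking rule mentioned in the collator section above relies on a PyTorch default worth seeing in isolation: positions labeled `-100` are skipped by `CrossEntropyLoss` (its default `ignore_index`). A self-contained sketch with toy tensors:

```python
# Sketch of the padding-masking convention SummarizationCollator uses:
# label positions set to -100 contribute nothing to the loss.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 100)          # (batch, seq_len, vocab_size)
labels = torch.randint(0, 100, (2, 5))   # target token ids
labels[:, 3:] = -100                     # pretend the last two positions are padding

loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),    # flatten to (batch * seq_len, vocab_size)
    labels.view(-1),
    ignore_index=-100,                   # the default; shown explicitly for clarity
)
print(loss)
```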
pyproject.toml
CHANGED

@@ -13,19 +13,14 @@ license = {text = "GPL-3.0"}

 dependencies = [
     "torch>=2.0.0",
-    "
-    "datasets>=2.14.0",
-    "tokenizers>=0.13.0",
+    "scikit-learn>=1.4.0",
     "numpy>=1.24.0",
     "pandas>=2.0.0",
-    "
-    "
-    "
-    "
-    "
-    "omegaconf>=2.3.0",
-    "tensorboard>=2.13.0",
-    "gradio>=3.35.0",
+    "streamlit>=1.25.0",
+    "plotly>=5.18.0",
+    "transformers>=4.40.0",
+    "fastapi>=0.110.0",
+    "datasets>=4.4.0",
 ]

 [project.optional-dependencies]
requirements.txt
CHANGED

@@ -1,22 +1,12 @@

 # requirements.txt
 torch>=2.0.0
-transformers>=4.
-tokenizers>=0.13.0
+transformers>=4.40.0
+scikit-learn>=1.4.0
 numpy>=1.24.0
 pandas>=2.0.0
-scikit-learn>=1.3.0
-matplotlib>=3.7.0
-seaborn>=0.12.0
-nltk>=3.8.0
-tqdm>=4.65.0
-pyyaml>=6.0
-omegaconf>=2.3.0
-tensorboard>=2.13.0
-gradio>=3.35.0
-requests>=2.31.0
-kaggle>=1.5.12
 streamlit>=1.25.0
 plotly>=5.18.0
+fastapi>=0.110.0
+datasets>=4.4.0
+pytest
+matplotlib
scripts/download_data.py
ADDED

@@ -0,0 +1,182 @@
```python
"""Download datasets used by LexiMind."""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Iterable, Iterator, cast

from datasets import ClassLabel, Dataset, DatasetDict, load_dataset


PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data.download import gutenberg_download, kaggle_download
from src.utils.config import load_yaml


DEFAULT_SUMMARIZATION_DATASET = "gowrishankarp/newspaper-text-summarization-cnn-dailymail"
DEFAULT_EMOTION_DATASET = "dair-ai/emotion"
DEFAULT_TOPIC_DATASET = "ag_news"
DEFAULT_BOOK_URL = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt"
DEFAULT_BOOK_OUTPUT = "data/raw/books/pride_and_prejudice.txt"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Download datasets required for LexiMind training")
    parser.add_argument(
        "--config",
        default="configs/data/datasets.yaml",
        help="Path to the dataset configuration YAML.",
    )
    parser.add_argument("--skip-kaggle", action="store_true", help="Skip downloading the Kaggle summarization dataset.")
    parser.add_argument("--skip-book", action="store_true", help="Skip downloading Gutenberg book texts.")
    return parser.parse_args()


def _safe_load_config(path: str | None) -> dict:
    if not path:
        return {}
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    return load_yaml(str(config_path)).data


def _write_jsonl(records: Iterable[dict[str, object]], destination: Path) -> None:
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")


def _emotion_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
    for item in dataset_split:
        data = dict(item)
        text = data.get("text", "")
        label_value = data.get("label")

        def resolve_label(index: object) -> str:
            if isinstance(index, int) and label_names and 0 <= index < len(label_names):
                return label_names[index]
            return str(index)

        if isinstance(label_value, list):
            labels = [resolve_label(idx) for idx in label_value]
        else:
            labels = [resolve_label(label_value)]
        yield {"text": text, "emotions": labels}


def _topic_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
    for item in dataset_split:
        data = dict(item)
        text = data.get("text") or data.get("content") or ""
        label_value = data.get("label")

        def resolve_topic(raw: object) -> str:
            if label_names:
                idx: int | None = None
                if isinstance(raw, int):
                    idx = raw
                elif isinstance(raw, str):
                    try:
                        idx = int(raw)
                    except ValueError:
                        idx = None
                if idx is not None and 0 <= idx < len(label_names):
                    return label_names[idx]
            return str(raw) if raw is not None else ""

        if isinstance(label_value, list):
            topic = resolve_topic(label_value[0]) if label_value else ""
        else:
            topic = resolve_topic(label_value)
        yield {"text": text, "topic": topic}


def main() -> None:
    args = parse_args()
    config = _safe_load_config(args.config)

    raw_paths = config.get("raw", {}) if isinstance(config, dict) else {}
    downloads_cfg = config.get("downloads", {}) if isinstance(config, dict) else {}

    summarization_cfg = downloads_cfg.get("summarization", {}) if isinstance(downloads_cfg, dict) else {}
    summarization_dataset = summarization_cfg.get("dataset", DEFAULT_SUMMARIZATION_DATASET)
    summarization_output = summarization_cfg.get("output", raw_paths.get("summarization", "data/raw/summarization"))

    if not args.skip_kaggle and summarization_dataset:
        print(f"Downloading summarization dataset '{summarization_dataset}' -> {summarization_output}")
        kaggle_download(summarization_dataset, summarization_output)
    else:
        print("Skipping Kaggle summarization download.")

    books_root = Path(raw_paths.get("books", "data/raw/books"))
    books_root.mkdir(parents=True, exist_ok=True)

    books_entries: list[dict[str, object]] = []
    if isinstance(downloads_cfg, dict):
        raw_entries = downloads_cfg.get("books")
        if isinstance(raw_entries, list):
            books_entries = [entry for entry in raw_entries if isinstance(entry, dict)]

    if not args.skip_book:
        if not books_entries:
            books_entries = [
                {
                    "name": "pride_and_prejudice",
                    "url": DEFAULT_BOOK_URL,
                    "output": DEFAULT_BOOK_OUTPUT,
                }
            ]
        for entry in books_entries:
            name = str(entry.get("name") or "gutenberg_text")
            url = str(entry.get("url") or DEFAULT_BOOK_URL)
            output_value = entry.get("output")
            destination = Path(output_value) if isinstance(output_value, str) and output_value else books_root / f"{name}.txt"
            destination.parent.mkdir(parents=True, exist_ok=True)
            print(f"Downloading Gutenberg text '{name}' from {url} -> {destination}")
            gutenberg_download(url, str(destination))
    else:
        print("Skipping Gutenberg downloads.")

    emotion_cfg = downloads_cfg.get("emotion", {}) if isinstance(downloads_cfg, dict) else {}
    emotion_name = emotion_cfg.get("dataset", DEFAULT_EMOTION_DATASET)
    emotion_dir = Path(raw_paths.get("emotion", "data/raw/emotion"))
    emotion_dir.mkdir(parents=True, exist_ok=True)
    print(f"Downloading emotion dataset '{emotion_name}' -> {emotion_dir}")
    emotion_dataset = cast(DatasetDict, load_dataset(emotion_name))
    first_emotion_key = next(iter(emotion_dataset.keys()), None) if emotion_dataset else None
    emotion_label_feature = (
        emotion_dataset[first_emotion_key].features.get("label")
        if first_emotion_key is not None
        else None
    )
    emotion_label_names = emotion_label_feature.names if isinstance(emotion_label_feature, ClassLabel) else None
    for split_name, split in emotion_dataset.items():
        output_path = emotion_dir / f"{str(split_name)}.jsonl"
        _write_jsonl(_emotion_records(split, emotion_label_names), output_path)

    topic_cfg = downloads_cfg.get("topic", {}) if isinstance(downloads_cfg, dict) else {}
    topic_name = topic_cfg.get("dataset", DEFAULT_TOPIC_DATASET)
    topic_dir = Path(raw_paths.get("topic", "data/raw/topic"))
    topic_dir.mkdir(parents=True, exist_ok=True)
    print(f"Downloading topic dataset '{topic_name}' -> {topic_dir}")
    topic_dataset = cast(DatasetDict, load_dataset(topic_name))
    first_topic_key = next(iter(topic_dataset.keys()), None) if topic_dataset else None
    topic_label_feature = (
        topic_dataset[first_topic_key].features.get("label")
        if first_topic_key is not None
        else None
    )
    topic_label_names = topic_label_feature.names if isinstance(topic_label_feature, ClassLabel) else None
    for split_name, split in topic_dataset.items():
        output_path = topic_dir / f"{str(split_name)}.jsonl"
        _write_jsonl(_topic_records(split, topic_label_names), output_path)

    print("Download routine finished.")


if __name__ == "__main__":
    main()
```
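Usage follows directly from the argparse definitions above; for example, to pull only the Hugging Face emotion/topic datasets while skipping Kaggle credentials and Gutenberg fetches:

```bash
python scripts/download_data.py \
  --config configs/data/datasets.yaml \
  --skip-kaggle \
  --skip-book
```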
scripts/download_data.sh
ADDED

@@ -0,0 +1,5 @@
```bash
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
python3 "${SCRIPT_DIR}/download_data.py"
```
scripts/evaluate.py
ADDED

@@ -0,0 +1,134 @@
```python
"""Evaluate the multitask model on processed validation/test splits."""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import List

import torch
from sklearn.preprocessing import MultiLabelBinarizer

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data.dataset import (
    load_emotion_jsonl,
    load_summarization_jsonl,
    load_topic_jsonl,
)
from src.inference.factory import create_inference_pipeline
from src.training.metrics import accuracy, multilabel_f1, rouge_like
from src.utils.config import load_yaml


SPLIT_ALIASES = {
    "train": ("train",),
    "val": ("val", "validation"),
    "test": ("test",),
}


def _read_split(root: Path, split: str, loader) -> list:
    aliases = SPLIT_ALIASES.get(split, (split,))
    for alias in aliases:
        for ext in ("jsonl", "json"):
            candidate = root / f"{alias}.{ext}"
            if candidate.exists():
                return loader(str(candidate))
    raise FileNotFoundError(f"Missing {split} split under {root}")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate the LexiMind multitask model")
    parser.add_argument("--split", default="val", choices=["train", "val", "test"], help="Dataset split to evaluate.")
    parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.")
    parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON.")
    parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration YAML.")
    parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture YAML.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device for evaluation.")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size for generation/classification during evaluation.")
    return parser.parse_args()


def chunks(items: List, size: int):
    for start in range(0, len(items), size):
        yield items[start : start + size]


def main() -> None:
    args = parse_args()
    data_cfg = load_yaml(args.data_config).data

    pipeline, metadata = create_inference_pipeline(
        checkpoint_path=args.checkpoint,
        labels_path=args.labels,
        tokenizer_config=None,
        model_config_path=args.model_config,
        device=args.device,
    )

    summarization_dir = Path(data_cfg["processed"]["summarization"])
    emotion_dir = Path(data_cfg["processed"]["emotion"])
    topic_dir = Path(data_cfg["processed"]["topic"])

    summary_examples = _read_split(summarization_dir, args.split, load_summarization_jsonl)
    emotion_examples = _read_split(emotion_dir, args.split, load_emotion_jsonl)
    topic_examples = _read_split(topic_dir, args.split, load_topic_jsonl)

    emotion_binarizer = MultiLabelBinarizer(classes=metadata.emotion)
    # Ensure scikit-learn initializes the attributes using metadata ordering.
    emotion_binarizer.fit([[label] for label in metadata.emotion])

    # Summarization
    summaries_pred = []
    summaries_ref = []
    for batch in chunks(summary_examples, args.batch_size):
        inputs = [example.source for example in batch]
        summaries_pred.extend(pipeline.summarize(inputs))
        summaries_ref.extend([example.summary for example in batch])
    rouge_score = rouge_like(summaries_pred, summaries_ref)

    # Emotion
    emotion_preds_tensor = []
    emotion_target_tensor = []
    label_to_index = {label: idx for idx, label in enumerate(metadata.emotion)}
    for batch in chunks(emotion_examples, args.batch_size):
        inputs = [example.text for example in batch]
        predictions = pipeline.predict_emotions(inputs)
        target_matrix = emotion_binarizer.transform([list(example.emotions) for example in batch])
        for pred, target_row in zip(predictions, target_matrix):
            vector = torch.zeros(len(metadata.emotion), dtype=torch.float32)
            for label in pred.labels:
                idx = label_to_index.get(label)
                if idx is not None:
                    vector[idx] = 1.0
            emotion_preds_tensor.append(vector)
            emotion_target_tensor.append(torch.tensor(target_row, dtype=torch.float32))
    emotion_f1 = multilabel_f1(torch.stack(emotion_preds_tensor), torch.stack(emotion_target_tensor))

    # Topic
    topic_preds = []
    topic_targets = []
    for batch in chunks(topic_examples, args.batch_size):
        inputs = [example.text for example in batch]
        predictions = pipeline.predict_topics(inputs)
        topic_preds.extend([pred.label for pred in predictions])
        topic_targets.extend([example.topic for example in batch])
    topic_accuracy = accuracy(topic_preds, topic_targets)

    print(json.dumps(
        {
            "split": args.split,
            "rouge_like": rouge_score,
            "emotion_f1": emotion_f1,
            "topic_accuracy": topic_accuracy,
        },
        indent=2,
    ))


if __name__ == "__main__":
    main()
```
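A typical invocation, built from the script's own defaults and flags; it prints one JSON object with `split`, `rouge_like`, `emotion_f1`, and `topic_accuracy` keys:

```bash
python scripts/evaluate.py \
  --split test \
  --checkpoint checkpoints/best.pt \
  --labels artifacts/labels.json \
  --batch-size 32
```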
scripts/export_model.py
ADDED

@@ -0,0 +1,69 @@
```python
"""Rebuild and export the trained multitask model for downstream use."""
from __future__ import annotations

import argparse
from pathlib import Path

import torch

from src.data.tokenization import Tokenizer, TokenizerConfig
from src.models.factory import build_multitask_model, load_model_config
from src.utils.config import load_yaml
from src.utils.labels import load_label_metadata


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Export LexiMind model weights")
    parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.")
    parser.add_argument("--output", default="outputs/model.pt", help="Output path for the exported state dict.")
    parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON produced after training.")
    parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture configuration.")
    parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration (for tokenizer settings).")
    return parser.parse_args()


def main() -> None:
    """Export multitask model weights from a training checkpoint to a standalone state dict."""
    args = parse_args()

    checkpoint = Path(args.checkpoint)
    if not checkpoint.exists():
        raise FileNotFoundError(checkpoint)

    labels = load_label_metadata(args.labels)
    data_cfg = load_yaml(args.data_config).data
    tokenizer_section = data_cfg.get("tokenizer", {})
    tokenizer_config = TokenizerConfig(
        pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"),
        max_length=int(tokenizer_section.get("max_length", 512)),
        lower=bool(tokenizer_section.get("lower", False)),
    )
    tokenizer = Tokenizer(tokenizer_config)

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=load_model_config(args.model_config),
    )

    raw_state = torch.load(checkpoint, map_location="cpu")
    if isinstance(raw_state, dict):
        if "model_state_dict" in raw_state and isinstance(raw_state["model_state_dict"], dict):
            state_dict = raw_state["model_state_dict"]
        elif "state_dict" in raw_state and isinstance(raw_state["state_dict"], dict):
            state_dict = raw_state["state_dict"]
        else:
            state_dict = raw_state
    else:
        raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}")
    model.load_state_dict(state_dict)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_path)
    print(f"Model exported to {output_path}")


if __name__ == "__main__":
    main()
```
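A minimal sketch of consuming the exported state dict on the other side. The model has to be rebuilt with the same config and label sizes before loading; the helper calls mirror the script above, and the literal paths are the script's defaults rather than requirements:

```python
# Hedged sketch: reload outputs/model.pt produced by scripts/export_model.py.
import torch

from src.data.tokenization import Tokenizer, TokenizerConfig
from src.models.factory import build_multitask_model, load_model_config
from src.utils.labels import load_label_metadata

labels = load_label_metadata("artifacts/labels.json")
tokenizer = Tokenizer(TokenizerConfig(pretrained_model_name="artifacts/hf_tokenizer"))

# Rebuild the same architecture the weights were exported from.
model = build_multitask_model(
    tokenizer,
    num_emotions=labels.emotion_size,
    num_topics=labels.topic_size,
    config=load_model_config("configs/model/base.yaml"),
)
model.load_state_dict(torch.load("outputs/model.pt", map_location="cpu"))
model.eval()
```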
scripts/inference.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Run inference with the multitask model."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import List, cast

from src.data.tokenization import TokenizerConfig
from src.inference import EmotionPrediction, TopicPrediction, create_inference_pipeline


def _load_texts(positional: List[str], file_path: Path | None) -> List[str]:
    texts = [text for text in positional if text]
    if file_path is not None:
        if not file_path.exists():
            raise FileNotFoundError(file_path)
        with file_path.open("r", encoding="utf-8") as handle:
            texts.extend([line.strip() for line in handle if line.strip()])
    if not texts:
        raise ValueError("No input texts provided. Pass text arguments or use --file.")
    return texts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run LexiMind multitask inference.")
    parser.add_argument("text", nargs="*", help="Input text(s) to analyse.")
    parser.add_argument("--file", type=Path, help="Path to a file containing one text per line.")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=Path("checkpoints/best.pt"),
        help="Path to the model checkpoint produced during training.",
    )
    parser.add_argument(
        "--labels",
        type=Path,
        default=Path("artifacts/labels.json"),
        help="JSON file containing emotion/topic label vocabularies.",
    )
    parser.add_argument(
        "--tokenizer",
        type=Path,
        default=None,
        help="Optional path to a tokenizer directory exported during training.",
    )
    parser.add_argument(
        "--model-config",
        type=Path,
        default=Path("configs/model/base.yaml"),
        help="Model architecture config used to rebuild the transformer stack.",
    )
    parser.add_argument("--device", default="cpu", help="Device to run inference on (cpu or cuda).")
    parser.add_argument(
        "--summary-max-length",
        type=int,
        default=None,
        help="Optional maximum length for generated summaries.",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    texts = _load_texts(args.text, args.file)

    tokenizer_config = None
    if args.tokenizer is not None:
        tokenizer_config = TokenizerConfig(pretrained_model_name=str(args.tokenizer))
    else:
        local_dir = Path("artifacts/hf_tokenizer")
        if local_dir.exists():
            tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_dir))

    pipeline, _ = create_inference_pipeline(
        checkpoint_path=args.checkpoint,
        labels_path=args.labels,
        tokenizer_config=tokenizer_config,
        model_config_path=args.model_config,
        device=args.device,
        summary_max_length=args.summary_max_length,
    )

    results = pipeline.batch_predict(texts)
    summaries = cast(List[str], results["summaries"])
    emotion_preds = cast(List[EmotionPrediction], results["emotion"])
    topic_preds = cast(List[TopicPrediction], results["topic"])

    packaged = []
    for idx, text in enumerate(texts):
        emotion = emotion_preds[idx]
        topic = topic_preds[idx]
        packaged.append(
            {
                "text": text,
                "summary": summaries[idx],
                "emotion": {
                    "labels": emotion.labels,
                    "scores": emotion.scores,
                },
                "topic": {
                    "label": topic.label,
                    "confidence": topic.confidence,
                },
            }
        )

    print(json.dumps(packaged, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
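The script prints a JSON array with one object per input, mirroring the `packaged` dict built above. An illustrative element (the label names and scores here are invented):

```python
# Shape of one output element; values are made up for illustration only.
example_output = {
    "text": "the raw input text",
    "summary": "the generated summary",
    "emotion": {"labels": ["joy", "surprise"], "scores": [0.91, 0.40]},
    "topic": {"label": "business", "confidence": 0.87},
}
```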
scripts/preprocess_data.py
ADDED
@@ -0,0 +1,321 @@
"""Preprocess raw datasets into JSONL splits for LexiMind training."""
from __future__ import annotations

import argparse
import csv
import json
import sys
from pathlib import Path
from typing import Dict, Iterable, Iterator, Sequence, Tuple

from sklearn.model_selection import train_test_split

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data.preprocessing import BasicTextCleaner
from src.utils.config import load_yaml


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Preprocess datasets configured for LexiMind")
    parser.add_argument(
        "--config",
        default="configs/data/datasets.yaml",
        help="Path to data configuration YAML.",
    )
    parser.add_argument(
        "--val-ratio",
        type=float,
        default=0.1,
        help="Validation split size for the topic dataset when no validation split is present.",
    )
    parser.add_argument("--seed", type=int, default=17, help="Random seed for deterministic splitting.")
    return parser.parse_args()


def _resolve_csv(base: Path, filename: str) -> Path | None:
    primary = base / filename
    if primary.exists():
        return primary
    nested = base / "cnn_dailymail" / filename
    if nested.exists():
        return nested
    return None


def _write_jsonl(records: Iterable[Dict[str, object]], destination: Path) -> None:
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")


def _read_jsonl(path: Path) -> Iterator[Dict[str, object]]:
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            row = line.strip()
            if not row:
                continue
            yield json.loads(row)


def preprocess_books(
    raw_dir: Path,
    processed_dir: Path,
    cleaner: BasicTextCleaner,
    *,
    min_tokens: int = 30,
) -> None:
    if not raw_dir.exists():
        print(f"Skipping book preprocessing (missing directory: {raw_dir})")
        return

    processed_dir.mkdir(parents=True, exist_ok=True)
    index: list[Dict[str, object]] = []

    for book_path in sorted(raw_dir.glob("*.txt")):
        text = book_path.read_text(encoding="utf-8").lstrip("\ufeff")
        normalized = text.replace("\r\n", "\n")
        paragraphs = [paragraph.strip() for paragraph in normalized.split("\n\n") if paragraph.strip()]

        records: list[Dict[str, object]] = []
        for paragraph_id, paragraph in enumerate(paragraphs):
            cleaned = cleaner.transform([paragraph])[0]
            tokens = cleaned.split()
            if len(tokens) < min_tokens:
                continue
            record = {
                "book": book_path.stem,
                "title": book_path.stem.replace("_", " ").title(),
                "paragraph_id": paragraph_id,
                "text": paragraph,
                "clean_text": cleaned,
                "token_count": len(tokens),
                "char_count": len(paragraph),
            }
            records.append(record)

        if not records:
            print(f"No suitably sized paragraphs found in {book_path}; skipping.")
            continue

        output_path = processed_dir / f"{book_path.stem}.jsonl"
        print(f"Writing book segments for '{book_path.stem}' to {output_path}")
        _write_jsonl(records, output_path)
        index.append(
            {
                "book": book_path.stem,
                "title": records[0]["title"],
                "paragraphs": len(records),
                "source": str(book_path),
                "output": str(output_path),
            }
        )

    if index:
        index_path = processed_dir / "index.json"
        with index_path.open("w", encoding="utf-8") as handle:
            json.dump(index, handle, ensure_ascii=False, indent=2)
        print(f"Book index written to {index_path}")


def preprocess_summarization(raw_dir: Path, processed_dir: Path) -> None:
    if not raw_dir.exists():
        print(f"Skipping summarization preprocessing (missing directory: {raw_dir})")
        return

    for split in ("train", "validation", "test"):
        source_path = _resolve_csv(raw_dir, f"{split}.csv")
        if source_path is None:
            print(f"Skipping summarization split '{split}' (file not found)")
            continue

        output_path = processed_dir / f"{split}.jsonl"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Writing summarization split '{split}' to {output_path}")
        with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open("w", encoding="utf-8") as sink:
            reader = csv.DictReader(source_handle)
            for row in reader:
                article = row.get("article") or row.get("Article") or ""
                highlights = row.get("highlights") or row.get("summary") or ""
                payload = {"source": article.strip(), "summary": highlights.strip()}
                sink.write(json.dumps(payload, ensure_ascii=False) + "\n")


def preprocess_emotion(raw_dir: Path, processed_dir: Path, cleaner: BasicTextCleaner) -> None:
    if not raw_dir.exists():
        print(f"Skipping emotion preprocessing (missing directory: {raw_dir})")
        return

    split_aliases: Dict[str, Sequence[str]] = {
        "train": ("train",),
        "val": ("val", "validation"),
        "test": ("test",),
    }

    for split, aliases in split_aliases.items():
        source_path: Path | None = None
        for alias in aliases:
            for extension in ("jsonl", "txt", "csv"):
                candidate = raw_dir / f"{alias}.{extension}"
                if candidate.exists():
                    source_path = candidate
                    break
            if source_path is not None:
                break
        if source_path is None:
            print(f"Skipping emotion split '{split}' (file not found)")
            continue

        assert source_path is not None
        path = source_path

        def iter_records() -> Iterator[Dict[str, object]]:
            if path.suffix == ".jsonl":
                for row in _read_jsonl(path):
                    raw_text = str(row.get("text", ""))
                    text = cleaner.transform([raw_text])[0]
                    labels = row.get("emotions") or row.get("labels") or []
                    if isinstance(labels, str):
                        labels = [label.strip() for label in labels.split(",") if label.strip()]
                    elif isinstance(labels, Sequence):
                        labels = [str(label) for label in labels]
                    else:
                        labels = [str(labels)] if labels else []
                    if not labels:
                        labels = ["neutral"]
                    yield {"text": text, "emotions": labels}
            else:
                delimiter = ";" if path.suffix == ".txt" else ","
                with path.open("r", encoding="utf-8", newline="") as handle:
                    reader = csv.reader(handle, delimiter=delimiter)
                    for row in reader:
                        if not row:
                            continue
                        raw_text = str(row[0])
                        text = cleaner.transform([raw_text])[0]
                        raw_labels = row[1] if len(row) > 1 else ""
                        labels = [label.strip() for label in raw_labels.split(",") if label.strip()]
                        if not labels:
                            labels = ["neutral"]
                        yield {"text": text, "emotions": labels}

        output_path = processed_dir / f"{split}.jsonl"
        print(f"Writing emotion split '{split}' to {output_path}")
        _write_jsonl(iter_records(), output_path)


def preprocess_topic(
    raw_dir: Path,
    processed_dir: Path,
    cleaner: BasicTextCleaner,
    val_ratio: float,
    seed: int,
) -> None:
    if not raw_dir.exists():
        print(f"Skipping topic preprocessing (missing directory: {raw_dir})")
        return

    def locate(*names: str) -> Path | None:
        for name in names:
            candidate = raw_dir / name
            if candidate.exists():
                return candidate
        return None

    train_path = locate("train.jsonl", "train.csv")
    if train_path is None:
        print(f"Skipping topic preprocessing (missing train split in {raw_dir})")
        return

    assert train_path is not None

    def load_topic_rows(path: Path) -> list[Tuple[str, str]]:
        rows: list[Tuple[str, str]] = []
        if path.suffix == ".jsonl":
            for record in _read_jsonl(path):
                text = str(record.get("text") or record.get("content") or "")
                topic = record.get("topic") or record.get("label")
                cleaned_text = cleaner.transform([text])[0]
                rows.append((cleaned_text, str(topic).strip()))
        else:
            with path.open("r", encoding="utf-8", newline="") as handle:
                reader = csv.DictReader(handle)
                for row in reader:
                    topic = row.get("Class Index") or row.get("topic") or row.get("label")
                    title = str(row.get("Title") or "")
                    description = str(row.get("Description") or row.get("text") or "")
                    text = " ".join(filter(None, (title, description)))
                    cleaned_text = cleaner.transform([text])[0]
                    rows.append((cleaned_text, str(topic).strip()))
        return rows

    train_rows = load_topic_rows(train_path)
    if not train_rows:
        print("No topic training rows found; skipping topic preprocessing.")
        return

    texts = [row[0] for row in train_rows]
    topics = [row[1] for row in train_rows]

    validation_path = locate("val.jsonl", "validation.jsonl", "val.csv", "validation.csv")
    has_validation = validation_path is not None

    if has_validation and validation_path:
        val_rows = load_topic_rows(validation_path)
        train_records = train_rows
    else:
        train_texts, val_texts, train_topics, val_topics = train_test_split(
            texts,
            topics,
            test_size=val_ratio,
            random_state=seed,
            stratify=topics,
        )
        train_records = list(zip(train_texts, train_topics))
        val_rows = list(zip(val_texts, val_topics))

    def to_records(pairs: Sequence[Tuple[str, str]]) -> Iterator[Dict[str, object]]:
        for text, topic in pairs:
            yield {"text": text, "topic": topic}

    print(f"Writing topic train split to {processed_dir / 'train.jsonl'}")
    _write_jsonl(to_records(train_records), processed_dir / "train.jsonl")
    print(f"Writing topic val split to {processed_dir / 'val.jsonl'}")
    _write_jsonl(to_records(val_rows), processed_dir / "val.jsonl")

    test_path = locate("test.jsonl", "test.csv")
    if test_path is not None:
        test_rows = load_topic_rows(test_path)
        print(f"Writing topic test split to {processed_dir / 'test.jsonl'}")
        _write_jsonl(to_records(test_rows), processed_dir / "test.jsonl")
    else:
        print(f"Skipping topic test split (missing test split in {raw_dir})")


def main() -> None:
    args = parse_args()
    config = load_yaml(args.config).data

    raw_cfg = config.get("raw", {})
    processed_cfg = config.get("processed", {})

    books_raw = Path(raw_cfg.get("books", "data/raw/books"))
    summarization_raw = Path(raw_cfg.get("summarization", "data/raw/summarization"))
    emotion_raw = Path(raw_cfg.get("emotion", "data/raw/emotion"))
    topic_raw = Path(raw_cfg.get("topic", "data/raw/topic"))

    books_processed = Path(processed_cfg.get("books", "data/processed/books"))
    summarization_processed = Path(processed_cfg.get("summarization", "data/processed/summarization"))
    emotion_processed = Path(processed_cfg.get("emotion", "data/processed/emotion"))
    topic_processed = Path(processed_cfg.get("topic", "data/processed/topic"))

    cleaner = BasicTextCleaner()

    preprocess_books(books_raw, books_processed, cleaner)
    preprocess_summarization(summarization_raw, summarization_processed)
    preprocess_emotion(emotion_raw, emotion_processed, cleaner)
    preprocess_topic(topic_raw, topic_processed, cleaner, val_ratio=args.val_ratio, seed=args.seed)

    print("Preprocessing complete.")


if __name__ == "__main__":
    main()
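Each processed split is line-delimited JSON. The per-task record schemas written by the functions above, shown as Python literals (values illustrative, book name hypothetical):

```python
summarization_record = {"source": "article text ...", "summary": "highlights ..."}
emotion_record = {"text": "cleaned text ...", "emotions": ["joy"]}
topic_record = {"text": "cleaned title + description ...", "topic": "1"}
# Book segments additionally carry provenance and size metadata:
book_record = {
    "book": "example_book",  # file stem; hypothetical name
    "title": "Example Book",
    "paragraph_id": 0,
    "text": "original paragraph ...",
    "clean_text": "cleaned paragraph ...",
    "token_count": 42,
    "char_count": 250,
}
```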
scripts/test_gpu.py
DELETED
@@ -1,27 +0,0 @@
# test_gpu.py
import torch

print("=" * 50)
print("GPU Information")
print("=" * 50)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"✅ GPU: {gpu_name}")
    print(f"✅ Memory: {gpu_memory:.2f} GB")

    # Test tensor creation
    x = torch.randn(1000, 1000, device='cuda')
    y = torch.randn(1000, 1000, device='cuda')
    z = x @ y

    print("✅ CUDA operations working!")
    print(f"✅ Current memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"✅ Max memory allocated: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
else:
    print("❌ CUDA not available!")
    print("Using CPU - training will be slow!")

print("=" * 50)
scripts/train.py
CHANGED
@@ -1,8 +1,219 @@
-
-from
-
+"""End-to-end training entrypoint for the LexiMind multitask model."""
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Dict, Sequence
+
+import torch
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.data.dataloader import (
+    build_emotion_dataloader,
+    build_summarization_dataloader,
+    build_topic_dataloader,
+)
+from src.data.dataset import (
+    EmotionDataset,
+    SummarizationDataset,
+    TopicDataset,
+    load_emotion_jsonl,
+    load_summarization_jsonl,
+    load_topic_jsonl,
+)
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.models.factory import build_multitask_model, load_model_config
+from src.training.trainer import Trainer, TrainerConfig
+from src.training.utils import set_seed
+from src.utils.config import load_yaml
+from src.utils.io import save_state
+from src.utils.labels import LabelMetadata, save_label_metadata
+
+SplitExamples = Dict[str, list]
+
+SPLIT_ALIASES: Dict[str, Sequence[str]] = {
+    "train": ("train",),
+    "val": ("val", "validation"),
+    "test": ("test",),
+}
+
+
+def _read_examples(data_dir: Path, loader) -> SplitExamples:
+    splits: SplitExamples = {}
+    for canonical, aliases in SPLIT_ALIASES.items():
+        found = False
+        for alias in aliases:
+            for extension in ("jsonl", "json"):
+                candidate = data_dir / f"{alias}.{extension}"
+                if candidate.exists():
+                    splits[canonical] = loader(str(candidate))
+                    found = True
+                    break
+            if found:
+                break
+        if not found:
+            raise FileNotFoundError(f"Missing {canonical} split under {data_dir}")
+    return splits
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Train the LexiMind multitask transformer")
+    parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Path to data configuration YAML.")
+    parser.add_argument("--training-config", default="configs/training/default.yaml", help="Path to training hyperparameter YAML.")
+    parser.add_argument("--model-config", default="configs/model/base.yaml", help="Path to model architecture YAML.")
+    parser.add_argument("--checkpoint-out", default="checkpoints/best.pt", help="Where to store the trained checkpoint.")
+    parser.add_argument("--labels-out", default="artifacts/labels.json", help="Where to persist label vocabularies.")
+    parser.add_argument("--history-out", default="outputs/training_history.json", help="Where to write training history.")
+    parser.add_argument("--device", default="cpu", help="Training device identifier (cpu or cuda).")
+    parser.add_argument("--seed", type=int, default=17, help="Random seed for reproducibility.")
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    set_seed(args.seed)
+
+    data_cfg = load_yaml(args.data_config).data
+    training_cfg = load_yaml(args.training_config).data
+    model_cfg = load_model_config(args.model_config)
+
+    summarization_dir = Path(data_cfg["processed"]["summarization"])
+    emotion_dir = Path(data_cfg["processed"]["emotion"])
+    topic_dir = Path(data_cfg["processed"]["topic"])
+
+    summarization_splits = _read_examples(summarization_dir, load_summarization_jsonl)
+    emotion_splits = _read_examples(emotion_dir, load_emotion_jsonl)
+    topic_splits = _read_examples(topic_dir, load_topic_jsonl)
+
+    tokenizer_section = data_cfg.get("tokenizer", {})
+    tokenizer_config = TokenizerConfig(
+        pretrained_model_name=tokenizer_section.get("pretrained_model_name", "facebook/bart-base"),
+        max_length=int(tokenizer_section.get("max_length", 512)),
+        lower=bool(tokenizer_section.get("lower", False)),
+    )
+    tokenizer = Tokenizer(tokenizer_config)
+
+    summarization_train = SummarizationDataset(summarization_splits["train"])
+    summarization_val = SummarizationDataset(summarization_splits["val"])
+
+    emotion_train = EmotionDataset(emotion_splits["train"])
+    emotion_val = EmotionDataset(emotion_splits["val"], binarizer=emotion_train.binarizer)
+
+    topic_train = TopicDataset(topic_splits["train"])
+    topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
+
+    dataloader_args = training_cfg.get("dataloader", {})
+    batch_size = int(dataloader_args.get("batch_size", 8))
+    shuffle = bool(dataloader_args.get("shuffle", True))
+    max_length = tokenizer.config.max_length
+
+    train_loaders = {
+        "summarization": build_summarization_dataloader(
+            summarization_train,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            max_source_length=max_length,
+            max_target_length=max_length,
+        ),
+        "emotion": build_emotion_dataloader(
+            emotion_train,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            max_length=max_length,
+        ),
+        "topic": build_topic_dataloader(
+            topic_train,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            max_length=max_length,
+        ),
+    }
+
+    val_loaders = {
+        "summarization": build_summarization_dataloader(
+            summarization_val,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=False,
+            max_source_length=max_length,
+            max_target_length=max_length,
+        ),
+        "emotion": build_emotion_dataloader(
+            emotion_val,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=False,
+            max_length=max_length,
+        ),
+        "topic": build_topic_dataloader(
+            topic_val,
+            tokenizer,
+            batch_size=batch_size,
+            shuffle=False,
+            max_length=max_length,
+        ),
+    }
+
+    device = torch.device(args.device)
+    model = build_multitask_model(
+        tokenizer,
+        num_emotions=len(emotion_train.emotion_classes),
+        num_topics=len(topic_train.topic_classes),
+        config=model_cfg,
+    ).to(device)
+
+    optimizer_cfg = training_cfg.get("optimizer", {})
+    lr = float(optimizer_cfg.get("lr", 3.0e-5))
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+    trainer_cfg = training_cfg.get("trainer", {})
+    trainer = Trainer(
+        model=model,
+        optimizer=optimizer,
+        config=TrainerConfig(
+            max_epochs=int(trainer_cfg.get("max_epochs", 1)),
+            gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
+            logging_interval=int(trainer_cfg.get("logging_interval", 50)),
+            task_weights=trainer_cfg.get("task_weights"),
+        ),
+        device=device,
+        tokenizer=tokenizer,
+    )
+
+    history = trainer.fit(train_loaders, val_loaders)
+
+    checkpoint_path = Path(args.checkpoint_out)
+    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+    save_state(model, str(checkpoint_path))
+
+    labels_path = Path(args.labels_out)
+    save_label_metadata(
+        LabelMetadata(
+            emotion=emotion_train.emotion_classes,
+            topic=topic_train.topic_classes,
+        ),
+        labels_path,
+    )

+    history_path = Path(args.history_out)
+    history_path.parent.mkdir(parents=True, exist_ok=True)
+    with history_path.open("w", encoding="utf-8") as handle:
+        json.dump(history, handle, indent=2)
+
+    print(f"Training complete. Checkpoint saved to {checkpoint_path}")
+    print(f"Label metadata saved to {labels_path}")
+    print(f"History saved to {history_path}")
+
 
 if __name__ == "__main__":
-
-    trainer = Trainer(config)
-    trainer.train()
+    main()
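For reference, these are the keys the new entrypoint reads from the training YAML, with the in-code fallbacks shown as an equivalent Python dict; configs/training/default.yaml may override any of them:

```python
# Fallbacks applied by scripts/train.py when a key is missing from the YAML.
training_cfg_defaults = {
    "dataloader": {"batch_size": 8, "shuffle": True},
    "optimizer": {"lr": 3.0e-5},
    "trainer": {
        "max_epochs": 1,
        "gradient_clip_norm": 1.0,
        "logging_interval": 50,
        "task_weights": None,  # forwarded to TrainerConfig unchanged
    },
}
```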
setup.py
CHANGED
@@ -7,13 +7,23 @@ setup(
     package_dir={"": "src"},
     install_requires=[
         "torch>=2.0.0",
-        "transformers>=4.
+        "transformers>=4.40.0",
+        "scikit-learn>=1.4.0",
+        "numpy>=1.24.0",
+        "pandas>=2.0.0",
     ],
+    extras_require={
+        "web": [
+            "streamlit>=1.25.0",
+            "plotly>=5.18.0",
+        ],
+        "api": [
+            "fastapi>=0.110.0",
+        ],
+        "all": [
+            "streamlit>=1.25.0",
+            "plotly>=5.18.0",
+            "fastapi>=0.110.0",
+        ],
     },
 )
src/__init__.py
CHANGED
@@ -0,0 +1 @@
+"""LexiMind core package."""
src/api/__init__.py
CHANGED
@@ -0,0 +1 @@
+"""API surface for LexiMind."""
src/api/app.py
ADDED
@@ -0,0 +1,10 @@
"""FastAPI application entrypoint."""
from fastapi import FastAPI

from .routes import router


def create_app() -> FastAPI:
    app = FastAPI(title="LexiMind")
    app.include_router(router)
    return app
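A minimal sketch for serving the app locally, assuming `uvicorn` is installed separately (it is not pinned in the dependency lists in this commit):

```python
# Local dev server sketch; uvicorn is an assumed extra dependency.
import uvicorn

from src.api.app import create_app

if __name__ == "__main__":
    uvicorn.run(create_app(), host="127.0.0.1", port=8000)
```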
src/api/dependencies.py
ADDED
@@ -0,0 +1,42 @@
"""Dependency providers for the FastAPI application."""
from __future__ import annotations

from functools import lru_cache
from pathlib import Path

from fastapi import HTTPException, status

from ..inference.factory import create_inference_pipeline
from ..inference.pipeline import InferencePipeline
from ..utils.logging import get_logger

logger = get_logger(__name__)


@lru_cache(maxsize=1)
def get_pipeline() -> InferencePipeline:
    """Lazily construct and cache the inference pipeline for the API."""

    checkpoint = Path("checkpoints/best.pt")
    labels = Path("artifacts/labels.json")
    model_config = Path("configs/model/base.yaml")

    try:
        pipeline, _ = create_inference_pipeline(
            checkpoint_path=checkpoint,
            labels_path=labels,
            model_config_path=model_config,
        )
    except FileNotFoundError as exc:
        logger.exception("Pipeline initialization failed: missing artifact")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service temporarily unavailable",
        ) from exc
    except Exception as exc:  # noqa: BLE001 - surface initialization issues to the caller
        logger.exception("Pipeline initialization failed")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service temporarily unavailable",
        ) from exc
    return pipeline
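Design note: because `get_pipeline` is wrapped in `lru_cache(maxsize=1)`, the pipeline is built once per process and reused across requests; and since `functools.lru_cache` does not cache raised exceptions, a request that fails with 503 (missing artifacts) will retry the build on the next call.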
src/api/inference/__init__.py
DELETED
@@ -1,7 +0,0 @@
"""
API inference module for LexiMind.
"""

from .inference import load_models, summarize_text, classify_emotion, topic_for_text

__all__ = ["load_models", "summarize_text", "classify_emotion", "topic_for_text"]
src/api/inference/inference.py
DELETED
@@ -1,133 +0,0 @@
"""Minimal inference helpers that rely on the custom transformer stack."""

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch

from ...data.preprocessing import TextPreprocessor, TransformerTokenizer
from ...models.multitask import MultiTaskModel


def _load_tokenizer(tokenizer_path: Path) -> TransformerTokenizer:
    if not tokenizer_path.exists():
        raise FileNotFoundError(f"tokenizer file '{tokenizer_path}' not found")
    return TransformerTokenizer.load(tokenizer_path)


def load_models(config: Dict[str, Any]) -> Dict[str, Any]:
    """Load MultiTaskModel together with the tokenizer-driven preprocessor."""

    device = torch.device(config.get("device", "cpu"))
    tokenizer_path = config.get("tokenizer_path")
    if tokenizer_path is None:
        raise ValueError("'tokenizer_path' missing in config")

    tokenizer = _load_tokenizer(Path(tokenizer_path))
    preprocessor = TextPreprocessor(
        max_length=int(config.get("max_length", 512)),
        tokenizer=tokenizer,
        min_freq=int(config.get("min_freq", 1)),
        lowercase=bool(config.get("lowercase", True)),
    )

    encoder_kwargs = dict(config.get("encoder", {}))
    decoder_kwargs = dict(config.get("decoder", {}))

    encoder = preprocessor.build_encoder(**encoder_kwargs)
    decoder = preprocessor.build_decoder(**decoder_kwargs)
    model = MultiTaskModel(encoder=encoder, decoder=decoder)

    checkpoint_path = config.get("checkpoint_path")
    if checkpoint_path:
        state = torch.load(checkpoint_path, map_location=device)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        model.load_state_dict(state, strict=False)

    model.to(device)

    return {
        "loaded": True,
        "device": device,
        "mt": model,
        "preprocessor": preprocessor,
    }


def summarize_text(
    text: str,
    compression: float = 0.25,
    collect_attn: bool = False,
    models: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
    if models is None or not models.get("loaded"):
        raise RuntimeError("Models must be loaded via load_models before summarize_text is called")

    model: MultiTaskModel = models["mt"]
    preprocessor: TextPreprocessor = models["preprocessor"]
    device: torch.device = models["device"]

    batch = preprocessor.batch_encode([text])
    tokenizer = preprocessor.tokenizer
    encoder = model.encoder
    decoder = model.decoder
    if tokenizer is None or encoder is None or decoder is None:
        raise RuntimeError("Encoder, decoder, and tokenizer must be configured before summarization")
    input_ids = batch.input_ids.to(device)
    memory = encoder(input_ids)
    src_len = batch.lengths[0]
    max_tgt = max(4, int(src_len * compression))
    generated = decoder.greedy_decode(
        memory,
        max_len=min(preprocessor.max_length, max_tgt),
        start_token_id=tokenizer.bos_id,
        end_token_id=tokenizer.eos_id,
    )
    summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
    return summary.strip(), None if not collect_attn else {}


def classify_emotion(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[List[float], List[str]]:
    if models is None or not models.get("loaded"):
        raise RuntimeError("Models must be loaded via load_models before classify_emotion is called")

    model: MultiTaskModel = models["mt"]
    preprocessor: TextPreprocessor = models["preprocessor"]
    device: torch.device = models["device"]

    batch = preprocessor.batch_encode([text])
    input_ids = batch.input_ids.to(device)
    result = model.forward("emotion", {"input_ids": input_ids})
    logits = result[1] if isinstance(result, tuple) else result
    scores = torch.sigmoid(logits).squeeze(0).detach().cpu().tolist()
    labels = models.get("emotion_labels") or [
        "joy",
        "sadness",
        "anger",
        "fear",
        "surprise",
        "disgust",
    ]
    return scores, labels[: len(scores)]


def topic_for_text(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[int, List[str]]:
    if models is None or not models.get("loaded"):
        raise RuntimeError("Models must be loaded via load_models before topic_for_text is called")

    model: MultiTaskModel = models["mt"]
    preprocessor: TextPreprocessor = models["preprocessor"]
    device: torch.device = models["device"]

    batch = preprocessor.batch_encode([text])
    input_ids = batch.input_ids.to(device)
    encoder = model.encoder
    if encoder is None:
        raise RuntimeError("Encoder must be configured before topic_for_text is called")
    memory = encoder(input_ids)
    embedding = memory.mean(dim=1).detach().cpu()
    _ = embedding  # placeholder for downstream clustering hook
    return 0, ["topic_stub"]
src/api/routes.py
ADDED
@@ -0,0 +1,34 @@
"""API routes."""
from typing import cast

from fastapi import APIRouter, Depends, HTTPException, status

from ..inference import EmotionPrediction, InferencePipeline, TopicPrediction
from .dependencies import get_pipeline
from .schemas import SummaryRequest, SummaryResponse

router = APIRouter()


@router.post("/summarize", response_model=SummaryResponse)
def summarize(payload: SummaryRequest, pipeline: InferencePipeline = Depends(get_pipeline)) -> SummaryResponse:
    try:
        outputs = pipeline.batch_predict([payload.text])
    except Exception as exc:  # noqa: BLE001 - surface inference error to client
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc
    summaries = cast(list[str], outputs["summaries"])
    emotion_preds = cast(list[EmotionPrediction], outputs["emotion"])
    topic_preds = cast(list[TopicPrediction], outputs["topic"])

    emotion = emotion_preds[0]
    topic = topic_preds[0]
    return SummaryResponse(
        summary=summaries[0],
        emotion_labels=emotion.labels,
        emotion_scores=emotion.scores,
        topic=topic.label,
        topic_confidence=topic.confidence,
    )
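A minimal in-process exercise of the endpoint, assuming the checkpoint and label artifacts referenced by `get_pipeline` exist on disk and `httpx` (required by FastAPI's TestClient) is installed:

```python
from fastapi.testclient import TestClient

from src.api.app import create_app

client = TestClient(create_app())
response = client.post("/summarize", json={"text": "Some article text."})
print(response.status_code, response.json())
```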
src/api/schemas.py
ADDED
@@ -0,0 +1,14 @@
"""API schemas."""
from pydantic import BaseModel


class SummaryRequest(BaseModel):
    text: str


class SummaryResponse(BaseModel):
    summary: str
    emotion_labels: list[str]
    emotion_scores: list[float]
    topic: str
    topic_confidence: float
src/data/__init__.py
CHANGED
@@ -0,0 +1 @@
+"""Data utilities for LexiMind."""
src/data/dataloader.py
ADDED
@@ -0,0 +1,117 @@
"""Task-aware DataLoader builders for the LexiMind multitask suite."""
from __future__ import annotations

from typing import List

import torch
from torch.utils.data import DataLoader

from .dataset import (
    EmotionDataset,
    EmotionExample,
    SummarizationDataset,
    SummarizationExample,
    TopicDataset,
    TopicExample,
)
from .tokenization import Tokenizer


class SummarizationCollator:
    """Prepare encoder-decoder batches for abstractive summarization."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        *,
        max_source_length: int | None = None,
        max_target_length: int | None = None,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __call__(self, batch: List[SummarizationExample]) -> dict[str, torch.Tensor]:
        sources = [example.source for example in batch]
        targets = [example.summary for example in batch]

        source_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
        target_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)

        labels = target_enc["input_ids"].clone()
        decoder_input_ids = self.tokenizer.prepare_decoder_inputs(target_enc["input_ids"])
        labels[target_enc["attention_mask"] == 0] = -100

        return {
            "src_ids": source_enc["input_ids"],
            "src_mask": source_enc["attention_mask"],
            "tgt_ids": decoder_input_ids,
            "labels": labels,
        }


class EmotionCollator:
    """Prepare batches for multi-label emotion classification."""

    def __init__(self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None) -> None:
        self.tokenizer = tokenizer
        self.binarizer = dataset.binarizer
        self.max_length = max_length

    def __call__(self, batch: List[EmotionExample]) -> dict[str, torch.Tensor]:
        texts = [example.text for example in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        label_array = self.binarizer.transform([example.emotions for example in batch])
        labels = torch.as_tensor(label_array, dtype=torch.float32)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }


class TopicCollator:
    """Prepare batches for topic classification using the projection head."""

    def __init__(self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None) -> None:
        self.tokenizer = tokenizer
        self.encoder = dataset.encoder
        self.max_length = max_length

    def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]:
        texts = [example.text for example in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        labels = torch.as_tensor(self.encoder.transform([example.topic for example in batch]), dtype=torch.long)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }


def build_summarization_dataloader(
    dataset: SummarizationDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_source_length: int | None = None,
    max_target_length: int | None = None,
) -> DataLoader:
    collator = SummarizationCollator(
        tokenizer,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)


def build_emotion_dataloader(
    dataset: EmotionDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
) -> DataLoader:
    collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)


def build_topic_dataloader(
    dataset: TopicDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
) -> DataLoader:
    collator = TopicCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)
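A minimal sketch of building one summarization batch and inspecting the tensors the collator produces, assuming the processed splits from scripts/preprocess_data.py exist on disk:

```python
from src.data.dataloader import build_summarization_dataloader
from src.data.dataset import SummarizationDataset, load_summarization_jsonl
from src.data.tokenization import Tokenizer, TokenizerConfig

tokenizer = Tokenizer(TokenizerConfig(pretrained_model_name="facebook/bart-base"))
dataset = SummarizationDataset(
    load_summarization_jsonl("data/processed/summarization/validation.jsonl")
)
loader = build_summarization_dataloader(
    dataset,
    tokenizer,
    batch_size=4,
    shuffle=False,
    max_source_length=512,
    max_target_length=512,
)
batch = next(iter(loader))
# Keys produced by SummarizationCollator; all are (batch, seq_len) tensors.
print({key: tuple(value.shape) for key, value in batch.items()})
```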
src/data/dataset.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset definitions for the LexiMind multitask training pipeline."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Callable, Iterable, List, Sequence, TypeVar
|
| 8 |
+
|
| 9 |
+
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(slots=True)
|
| 14 |
+
class SummarizationExample:
|
| 15 |
+
"""Container for abstractive summarization samples."""
|
| 16 |
+
|
| 17 |
+
source: str
|
| 18 |
+
summary: str
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(slots=True)
|
| 22 |
+
class EmotionExample:
|
| 23 |
+
"""Container for multi-label emotion classification samples."""
|
| 24 |
+
|
| 25 |
+
text: str
|
| 26 |
+
emotions: Sequence[str]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(slots=True)
|
| 30 |
+
class TopicExample:
|
| 31 |
+
"""Container for topic clustering / classification samples."""
|
| 32 |
+
|
| 33 |
+
text: str
|
| 34 |
+
topic: str
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SummarizationDataset(Dataset[SummarizationExample]):
|
| 38 |
+
"""Dataset yielding encoder-decoder training pairs."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, examples: Iterable[SummarizationExample]) -> None:
|
| 41 |
+
self._examples = list(examples)
|
| 42 |
+
|
| 43 |
+
def __len__(self) -> int:
|
| 44 |
+
return len(self._examples)
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, index: int) -> SummarizationExample:
|
| 47 |
+
return self._examples[index]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class EmotionDataset(Dataset[EmotionExample]):
|
| 51 |
+
"""Dataset that owns a scikit-learn MultiLabelBinarizer for emissions."""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
examples: Iterable[EmotionExample],
|
| 56 |
+
*,
|
| 57 |
+
binarizer: MultiLabelBinarizer | None = None,
|
| 58 |
+
) -> None:
|
| 59 |
+
self._examples = list(examples)
|
| 60 |
+
all_labels = [example.emotions for example in self._examples]
|
| 61 |
+
        if binarizer is None:
            self._binarizer = MultiLabelBinarizer()
            self._binarizer.fit(all_labels)
        else:
            self._binarizer = binarizer
            if not hasattr(self._binarizer, "classes_"):
                raise ValueError(
                    "Provided MultiLabelBinarizer must be pre-fitted with 'classes_' attribute."
                )

    def __len__(self) -> int:
        return len(self._examples)

    def __getitem__(self, index: int) -> EmotionExample:
        return self._examples[index]

    @property
    def binarizer(self) -> MultiLabelBinarizer:
        return self._binarizer

    @property
    def emotion_classes(self) -> List[str]:
        return list(self._binarizer.classes_)


class TopicDataset(Dataset[TopicExample]):
    """Dataset that owns a LabelEncoder for topic ids."""

    def __init__(
        self,
        examples: Iterable[TopicExample],
        *,
        encoder: LabelEncoder | None = None,
    ) -> None:
        self._examples = list(examples)
        topics = [example.topic for example in self._examples]
        if encoder is None:
            self._encoder = LabelEncoder().fit(topics)
        else:
            self._encoder = encoder
            if not hasattr(self._encoder, "classes_"):
                raise ValueError(
                    "Provided LabelEncoder must be pre-fitted with 'classes_' attribute."
                )

    def __len__(self) -> int:
        return len(self._examples)

    def __getitem__(self, index: int) -> TopicExample:
        return self._examples[index]

    @property
    def encoder(self) -> LabelEncoder:
        return self._encoder

    @property
    def topic_classes(self) -> List[str]:
        return list(self._encoder.classes_)


T = TypeVar("T")


def _safe_json_load(handle, path: Path) -> object:
    try:
        return json.load(handle)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to parse JSON in '{path}': {exc}") from exc


def _safe_json_loads(data: str, path: Path, line_number: int) -> object:
    try:
        return json.loads(data)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to parse JSON in '{path}' at line {line_number}: {exc}") from exc


def _validate_keys(
    payload: dict,
    required_keys: Sequence[str],
    position: int,
    *,
    path: Path,
    is_array: bool = False,
) -> None:
    missing = [key for key in required_keys if key not in payload]
    if missing:
        keys = ", ".join(sorted(missing))
        location = "index" if is_array else "line"
        raise KeyError(f"Missing required keys ({keys}) at {location} {position} of '{path}'")


def _load_jsonl_generic(
    path: str,
    constructor: Callable[[dict], T],
    required_keys: Sequence[str],
) -> List[T]:
    data_path = Path(path)
    if not data_path.exists():
        raise FileNotFoundError(f"Dataset file '{data_path}' does not exist")
    if not data_path.is_file():
        raise ValueError(f"Dataset path '{data_path}' is not a file")

    items: List[T] = []
    with data_path.open("r", encoding="utf-8") as handle:
        first_non_ws = ""
        while True:
            pos = handle.tell()
            char = handle.read(1)
            if not char:
                break
            if not char.isspace():
                first_non_ws = char
                handle.seek(pos)
                break
        if not first_non_ws:
            raise ValueError(f"Dataset file '{data_path}' is empty or contains only whitespace")

        if first_non_ws == "[":
            payloads = _safe_json_load(handle, data_path)
            if not isinstance(payloads, list):
                raise ValueError(f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}")
            for idx, payload in enumerate(payloads):
                if not isinstance(payload, dict):
                    raise ValueError(
                        f"Expected objects in array for '{data_path}', found {type(payload).__name__} at index {idx}"
                    )
                _validate_keys(payload, required_keys, idx, path=data_path, is_array=True)
                items.append(constructor(payload))
        else:
            handle.seek(0)
            line_number = 0
            for line in handle:
                line_number += 1
                if not line.strip():
                    continue
                payload = _safe_json_loads(line, data_path, line_number)
                if not isinstance(payload, dict):
                    raise ValueError(
                        f"Expected JSON object per line in '{data_path}', found {type(payload).__name__} at line {line_number}"
                    )
                _validate_keys(payload, required_keys, line_number, path=data_path)
                items.append(constructor(payload))

    return items


def load_summarization_jsonl(path: str) -> List[SummarizationExample]:
    return _load_jsonl_generic(
        path,
        lambda payload: SummarizationExample(source=payload["source"], summary=payload["summary"]),
        required_keys=("source", "summary"),
    )


def load_emotion_jsonl(path: str) -> List[EmotionExample]:
    return _load_jsonl_generic(
        path,
        lambda payload: EmotionExample(text=payload["text"], emotions=payload.get("emotions", [])),
        required_keys=("text",),
    )


def load_topic_jsonl(path: str) -> List[TopicExample]:
    return _load_jsonl_generic(
        path,
        lambda payload: TopicExample(text=payload["text"], topic=payload["topic"]),
        required_keys=("text", "topic"),
    )
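For orientation, a hypothetical end-to-end use of the loaders and dataset wrappers above. It assumes the repo root is on `PYTHONPATH` (so `src` imports as a package), that the JSONL file exists, and that `EmotionDataset` takes the example list as its first argument, mirroring `TopicDataset`.

```python
# Hypothetical usage; file path and record contents are illustrative.
# Each JSONL line is expected to look like:
#   {"text": "I loved this book", "emotions": ["joy"]}
from src.data.dataset import EmotionDataset, load_emotion_jsonl

examples = load_emotion_jsonl("data/processed/emotion/train.jsonl")
dataset = EmotionDataset(examples)  # fits a MultiLabelBinarizer over all labels seen

print(len(dataset))
print(dataset.emotion_classes)  # sorted label vocabulary from the fitted binarizer
```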
src/data/download.py
CHANGED
@@ -1,66 +1,45 @@
"""Dataset download helpers."""

import socket
from pathlib import Path
from subprocess import CalledProcessError, run
from urllib.error import URLError
from urllib.request import urlopen

DOWNLOAD_TIMEOUT = 60


def kaggle_download(dataset: str, output_dir: str) -> None:
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)
    try:
        run([
            "kaggle",
            "datasets",
            "download",
            "-d",
            dataset,
            "-p",
            str(target),
            "--unzip",
        ], check=True)
    except CalledProcessError as error:
        raise RuntimeError(
            "Kaggle download failed. Verify that the Kaggle CLI is authenticated,"
            " you have accepted the dataset terms on kaggle.com, and your kaggle.json"
            " credentials are located in %USERPROFILE%/.kaggle."
        ) from error


def gutenberg_download(url: str, output_path: str) -> None:
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        with urlopen(url, timeout=DOWNLOAD_TIMEOUT) as response, target.open("wb") as handle:
            chunk = response.read(8192)
            while chunk:
                handle.write(chunk)
                chunk = response.read(8192)
    except (URLError, socket.timeout, OSError) as error:
        raise RuntimeError(f"Failed to download '{url}' to '{target}': {error}") from error
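A minimal driver for the two helpers above; the Kaggle slug and target directory are ones the replaced module hard-coded, and Gutenberg id 1342 (Pride and Prejudice) matches its old URL pattern. Running this requires an authenticated Kaggle CLI and network access, and assumes `src` is importable from the repo root.

```python
# Usage sketch; the dataset slug and paths are examples, not values this module fixes.
from src.data.download import gutenberg_download, kaggle_download

kaggle_download("amananandrai/ag-news-classification-dataset", "data/raw/topic")
gutenberg_download(
    "https://www.gutenberg.org/files/1342/1342-0.txt",
    "data/raw/books/pride_and_prejudice.txt",
)
```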
src/data/preprocessing.py
CHANGED
@@ -1,260 +1,130 @@
"""Text preprocessing utilities built around Hugging Face tokenizers."""
from __future__ import annotations

import re
from dataclasses import dataclass, replace
from typing import Iterable, List, Sequence

import torch
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

from .tokenization import Tokenizer, TokenizerConfig


class BasicTextCleaner(BaseEstimator, TransformerMixin):
    """Minimal text cleaner following scikit-learn conventions."""

    def __init__(self, lowercase: bool = True, strip: bool = True) -> None:
        self.lowercase = lowercase
        self.strip = strip

    def fit(self, texts: Iterable[str], y: Iterable[str] | None = None):
        return self

    def transform(self, texts: Iterable[str]) -> List[str]:
        return [self._clean_text(text) for text in texts]

    def _clean_text(self, text: str) -> str:
        item = text.strip() if self.strip else text
        if self.lowercase:
            item = item.lower()
        return " ".join(item.split())


@dataclass(slots=True)
class Batch:
    """Bundle of tensors returned by the text preprocessor."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    lengths: List[int]


class TextPreprocessor:
    """Coordinate lightweight text cleaning and tokenization.

    When supplying an already-initialized tokenizer instance, its configuration is left
    untouched. If a differing ``max_length`` is requested, a ``ValueError`` is raised to
    avoid mutating shared tokenizer state.
    """

    def __init__(
        self,
        tokenizer: Tokenizer | None = None,
        *,
        tokenizer_config: TokenizerConfig | None = None,
        tokenizer_name: str = "facebook/bart-base",
        max_length: int | None = None,
        lowercase: bool = True,
        remove_stopwords: bool = False,
        sklearn_transformer: TransformerMixin | None = None,
    ) -> None:
        self.cleaner = BasicTextCleaner(lowercase=lowercase, strip=True)
        self.lowercase = lowercase
        if remove_stopwords:
            raise ValueError(
                "Stop-word removal is not supported because it conflicts with subword tokenizers; "
                "clean the text externally before initializing TextPreprocessor."
            )
        self._stop_words = None
        self._sklearn_transformer = sklearn_transformer

        if tokenizer is None:
            cfg = tokenizer_config or TokenizerConfig(pretrained_model_name=tokenizer_name)
            if max_length is not None:
                cfg = replace(cfg, max_length=max_length)
            self.tokenizer = Tokenizer(cfg)
        else:
            self.tokenizer = tokenizer
            if max_length is not None and max_length != tokenizer.config.max_length:
                raise ValueError(
                    "Provided tokenizer config.max_length does not match requested max_length; "
                    "initialise the tokenizer with desired settings before passing it in."
                )

        self.max_length = max_length or self.tokenizer.config.max_length

    def clean_text(self, text: str) -> str:
        item = self.cleaner.transform([text])[0]
        return self._normalize_tokens(item)

    def _normalize_tokens(self, text: str) -> str:
        """Apply token-level normalization and optional stop-word filtering."""
        # Note: Pre-tokenization word-splitting is incompatible with subword tokenizers.
        # Stop-word filtering should be done post-tokenization or not at all for transformers.
        return text

    def _apply_sklearn_transform(self, texts: List[str]) -> List[str]:
        if self._sklearn_transformer is None:
            return texts

        transform = getattr(self._sklearn_transformer, "transform", None)
        if transform is None:
            raise AttributeError("Provided sklearn transformer must implement a 'transform' method")
        transformed = transform(texts)
        if isinstance(transformed, list):
            return transformed  # assume downstream type is already list[str]
        if hasattr(transformed, "tolist"):
            transformed = transformed.tolist()

        result = list(transformed)
        if not all(isinstance(item, str) for item in result):
            result = [str(item) for item in result]
        return result

    def _prepare_texts(self, texts: Sequence[str]) -> List[str]:
        cleaned = self.cleaner.transform(texts)
        normalized = [self._normalize_tokens(text) for text in cleaned]
        return self._apply_sklearn_transform(normalized)

    def batch_encode(self, texts: Sequence[str]) -> Batch:
        cleaned = self._prepare_texts(texts)
        encoded = self.tokenizer.batch_encode(cleaned, max_length=self.max_length)
        input_ids: torch.Tensor = encoded["input_ids"]
        attention_mask: torch.Tensor = encoded["attention_mask"].to(dtype=torch.bool)
        lengths = attention_mask.sum(dim=1).tolist()
        return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)

    def __call__(self, texts: Sequence[str]) -> Batch:
        return self.batch_encode(texts)
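A quick sketch of the round trip through `TextPreprocessor`. The default `facebook/bart-base` tokenizer is downloaded on first use, so this needs network access or a local cache; the printed shapes are illustrative.

```python
from src.data.preprocessing import TextPreprocessor

preprocessor = TextPreprocessor(max_length=128)  # builds a facebook/bart-base tokenizer
batch = preprocessor(["  An   article about MARKETS. ", "Another document."])

print(batch.input_ids.shape)       # (2, longest_sequence_in_batch)
print(batch.attention_mask.dtype)  # torch.bool
print(batch.lengths)               # unpadded token counts per text
```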
src/data/tokenization.py
ADDED
@@ -0,0 +1,122 @@
"""Tokenizer wrapper around HuggingFace models used across LexiMind."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Sequence, cast

import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase


@dataclass(slots=True)
class TokenizerConfig:
    pretrained_model_name: str = "facebook/bart-base"
    max_length: int = 512
    padding: str = "longest"
    truncation: bool = True
    lower: bool = False


class Tokenizer:
    """Lightweight façade over a HuggingFace tokenizer."""

    def __init__(self, config: TokenizerConfig | None = None) -> None:
        cfg = config or TokenizerConfig()
        self.config = cfg
        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(cfg.pretrained_model_name)
        self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)
        self._bos_token_id = self._resolve_id(
            self._tokenizer.bos_token_id if self._tokenizer.bos_token_id is not None else self._tokenizer.cls_token_id
        )
        self._eos_token_id = self._resolve_id(
            self._tokenizer.eos_token_id if self._tokenizer.eos_token_id is not None else self._tokenizer.sep_token_id
        )

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        return self._tokenizer

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def vocab_size(self) -> int:
        vocab = getattr(self._tokenizer, "vocab_size", None)
        if vocab is None:
            raise RuntimeError("Tokenizer must expose vocab_size")
        return int(vocab)

    @staticmethod
    def _resolve_id(value) -> int:
        if value is None:
            raise ValueError("Tokenizer is missing required special token ids")
        if isinstance(value, (list, tuple)):
            value = value[0]
        return int(value)

    def encode(self, text: str) -> List[int]:
        content = text.lower() if self.config.lower else text
        return self._tokenizer.encode(
            content,
            max_length=self.config.max_length,
            truncation=self.config.truncation,
            padding=self.config.padding,
        )

    def encode_batch(self, texts: Sequence[str]) -> List[List[int]]:
        normalized = (text.lower() if self.config.lower else text for text in texts)
        encoded = self._tokenizer.batch_encode_plus(
            list(normalized),
            max_length=self.config.max_length,
            padding=self.config.padding,
            truncation=self.config.truncation,
            return_attention_mask=False,
            return_tensors=None,
        )
        return cast(List[List[int]], encoded["input_ids"])

    def batch_encode(self, texts: Sequence[str], *, max_length: int | None = None) -> dict[str, torch.Tensor]:
        normalized = [text.lower() if self.config.lower else text for text in texts]
        encoded = self._tokenizer(
            normalized,
            padding=self.config.padding,
            truncation=self.config.truncation,
            max_length=max_length or self.config.max_length,
            return_tensors="pt",
        )
        input_ids = cast(torch.Tensor, encoded["input_ids"])
        attention_mask = cast(torch.Tensor, encoded["attention_mask"])
        if input_ids.dtype != torch.long:
            input_ids = input_ids.to(dtype=torch.long)
        if attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.to(dtype=torch.bool)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def decode(self, token_ids: Iterable[int]) -> str:
        return self._tokenizer.decode(list(token_ids), skip_special_tokens=True)

    def decode_batch(self, sequences: Sequence[Sequence[int]]) -> List[str]:
        prepared = [list(seq) for seq in sequences]
        return self._tokenizer.batch_decode(prepared, skip_special_tokens=True)

    def prepare_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
        """Shift decoder labels to create input ids prefixed by BOS."""

        bos = self.bos_token_id
        pad = self.pad_token_id
        decoder_inputs = torch.full_like(labels, pad)
        decoder_inputs[:, 0] = bos
        decoder_inputs[:, 1:] = labels[:, :-1]
        return decoder_inputs
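The shifting convention in `prepare_decoder_inputs` is easy to verify by hand. A standalone illustration with placeholder ids (the real `bos`/`pad` values come from the loaded tokenizer):

```python
import torch

labels = torch.tensor([[10, 11, 12, 13]])
bos, pad = 0, 1  # placeholder special-token ids

# Same shift the method performs: BOS first, labels moved one step right.
decoder_inputs = torch.full_like(labels, pad)
decoder_inputs[:, 0] = bos
decoder_inputs[:, 1:] = labels[:, :-1]

print(decoder_inputs)  # tensor([[ 0, 10, 11, 12]])
```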
src/inference/__init__.py
CHANGED
@@ -1,7 +1,12 @@
"""Inference tools for LexiMind."""

from .factory import create_inference_pipeline
from .pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction

__all__ = [
    "InferencePipeline",
    "InferenceConfig",
    "EmotionPrediction",
    "TopicPrediction",
    "create_inference_pipeline",
]
src/inference/baseline_summarizer.py
DELETED
@@ -1,41 +0,0 @@
"""Thin wrapper around the custom transformer summarizer."""

from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
from ..api.inference import load_models


class TransformerSummarizer:
    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        models = load_models(config or {})
        if not models.get("loaded"):
            raise RuntimeError("load_models returned an unloaded model; check configuration")
        self.model = models["mt"]
        self.preprocessor = models["preprocessor"]
        self.device = models["device"]

    def summarize(
        self,
        text: str,
        compression: float = 0.25,
        collect_attn: bool = False,
    ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
        batch = self.preprocessor.batch_encode([text])
        tokenizer = self.preprocessor.tokenizer
        encoder = self.model.encoder
        decoder = self.model.decoder
        if tokenizer is None or encoder is None or decoder is None:
            raise RuntimeError("Model components are missing; ensure encoder, decoder, and tokenizer are set")
        input_ids = batch.input_ids.to(self.device)
        memory = encoder(input_ids)
        src_len = batch.lengths[0]
        target_len = max(4, int(src_len * compression))
        generated = decoder.greedy_decode(
            memory,
            max_len=min(self.preprocessor.max_length, target_len),
            start_token_id=tokenizer.bos_id,
            end_token_id=tokenizer.eos_id,
        )
        summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
        return summary.strip(), None if not collect_attn else {}
src/inference/factory.py
ADDED
@@ -0,0 +1,75 @@
"""Helpers to assemble an inference pipeline from saved artifacts."""
from __future__ import annotations

from pathlib import Path
from typing import Tuple

import torch

from ..data.tokenization import Tokenizer, TokenizerConfig
from ..models.factory import ModelConfig, build_multitask_model, load_model_config
from ..utils.io import load_state
from ..utils.labels import LabelMetadata, load_label_metadata
from .pipeline import InferenceConfig, InferencePipeline


def create_inference_pipeline(
    checkpoint_path: str | Path,
    labels_path: str | Path,
    *,
    tokenizer_config: TokenizerConfig | None = None,
    tokenizer_dir: str | Path | None = None,
    model_config_path: str | Path | None = None,
    device: str | torch.device = "cpu",
    summary_max_length: int | None = None,
) -> Tuple[InferencePipeline, LabelMetadata]:
    """Build an :class:`InferencePipeline` from saved model and label metadata."""

    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    labels = load_label_metadata(labels_path)

    resolved_tokenizer_config = tokenizer_config
    if resolved_tokenizer_config is None:
        default_dir = Path(__file__).resolve().parent.parent.parent / "artifacts" / "hf_tokenizer"
        chosen_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
        local_tokenizer_dir = chosen_dir
        if local_tokenizer_dir.exists():
            resolved_tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_tokenizer_dir))
        else:
            raise ValueError(
                "No tokenizer configuration provided and default tokenizer directory "
                f"'{local_tokenizer_dir}' not found. Please provide tokenizer_config parameter or set tokenizer_dir."
            )

    tokenizer = Tokenizer(resolved_tokenizer_config)
    model_config = load_model_config(model_config_path)
    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=model_config,
    )
    load_state(model, str(checkpoint))

    if isinstance(device, torch.device):
        device_str = str(device)
    else:
        device_str = device

    if summary_max_length is not None:
        pipeline_config = InferenceConfig(summary_max_length=summary_max_length, device=device_str)
    else:
        pipeline_config = InferenceConfig(device=device_str)

    pipeline = InferencePipeline(
        model=model,
        tokenizer=tokenizer,
        config=pipeline_config,
        emotion_labels=labels.emotion,
        topic_labels=labels.topic,
        device=device,
    )
    return pipeline, labels
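A hypothetical invocation of the factory. The artifact paths below are examples only; they depend on where training and the label-metadata export wrote their outputs.

```python
from src.inference import create_inference_pipeline

pipeline, labels = create_inference_pipeline(
    checkpoint_path="artifacts/checkpoints/best.pt",  # example path
    labels_path="artifacts/labels.json",              # example path
    device="cpu",
    summary_max_length=96,
)
print(labels.emotion_size, labels.topic_size)
```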
src/inference/generation.py
ADDED
@@ -0,0 +1,14 @@
"""Generation helpers."""

import torch


def greedy_decode(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
    """Run greedy decoding with ``model.generate`` and return generated token ids."""

    return model.generate(
        input_ids,
        max_length=max_length,
        do_sample=False,
        num_beams=1,
    )
src/inference/pipeline.py
ADDED
@@ -0,0 +1,166 @@
"""Inference helpers for multitask LexiMind models."""
from __future__ import annotations

from dataclasses import dataclass, fields, replace
from typing import Iterable, List, Sequence

import torch
import torch.nn.functional as F

from ..data.preprocessing import Batch, TextPreprocessor
from ..data.tokenization import Tokenizer


@dataclass(slots=True)
class InferenceConfig:
    """Configuration knobs for the inference pipeline."""

    summary_max_length: int = 128
    emotion_threshold: float = 0.5
    device: str | None = None


@dataclass(slots=True)
class EmotionPrediction:
    labels: List[str]
    scores: List[float]


@dataclass(slots=True)
class TopicPrediction:
    label: str
    confidence: float


class InferencePipeline:
    """Run summarization, emotion, and topic heads through a unified interface."""

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Tokenizer,
        *,
        preprocessor: TextPreprocessor | None = None,
        emotion_labels: Sequence[str] | None = None,
        topic_labels: Sequence[str] | None = None,
        config: InferenceConfig | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or InferenceConfig()
        chosen_device = device or self.config.device
        if chosen_device is None:
            first_param = next(model.parameters(), None)
            chosen_device = first_param.device if first_param is not None else "cpu"
        self.device = torch.device(chosen_device)
        self.model.to(self.device)
        self.model.eval()

        self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
        self.emotion_labels = list(emotion_labels) if emotion_labels is not None else None
        self.topic_labels = list(topic_labels) if topic_labels is not None else None

    def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
        if not texts:
            return []
        batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
        src_ids = batch.input_ids
        src_mask = batch.attention_mask
        max_len = max_length or self.config.summary_max_length

        if not hasattr(self.model, "encoder") or not hasattr(self.model, "decoder"):
            raise RuntimeError("Model must expose encoder and decoder attributes for summarization.")

        with torch.inference_mode():
            encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
            memory = self.model.encoder(src_ids, mask=encoder_mask)
            generated = self.model.decoder.greedy_decode(
                memory=memory,
                max_len=max_len,
                start_token_id=self.tokenizer.bos_token_id,
                end_token_id=self.tokenizer.eos_token_id,
                device=self.device,
            )

        return self.tokenizer.decode_batch(generated.tolist())

    def predict_emotions(
        self,
        texts: Sequence[str],
        *,
        threshold: float | None = None,
    ) -> List[EmotionPrediction]:
        if not texts:
            return []
        if self.emotion_labels is None or not self.emotion_labels:
            raise RuntimeError("emotion_labels must be provided to decode emotion predictions")

        batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
        model_inputs = self._batch_to_model_inputs(batch)
        decision_threshold = threshold or self.config.emotion_threshold

        with torch.inference_mode():
            logits = self.model.forward("emotion", model_inputs)
            probs = torch.sigmoid(logits)

        predictions: List[EmotionPrediction] = []
        for row in probs.cpu():
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist())
                if score >= decision_threshold
            ]
            labels = [label for label, _ in pairs]
            scores = [score for _, score in pairs]
            predictions.append(EmotionPrediction(labels=labels, scores=scores))
        return predictions

    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
        if not texts:
            return []
        if self.topic_labels is None or not self.topic_labels:
            raise RuntimeError("topic_labels must be provided to decode topic predictions")

        batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
        model_inputs = self._batch_to_model_inputs(batch)

        with torch.inference_mode():
            logits = self.model.forward("topic", model_inputs)
            probs = F.softmax(logits, dim=-1)

        results: List[TopicPrediction] = []
        for row in probs.cpu():
            scores = row.tolist()
            best_index = int(row.argmax().item())
            results.append(TopicPrediction(label=self.topic_labels[best_index], confidence=scores[best_index]))
        return results

    def batch_predict(self, texts: Iterable[str]) -> dict[str, object]:
        text_list = list(texts)
        if self.emotion_labels is None or not self.emotion_labels:
            raise RuntimeError("emotion_labels must be provided for batch predictions")
        if self.topic_labels is None or not self.topic_labels:
            raise RuntimeError("topic_labels must be provided for batch predictions")
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }

    def _batch_to_device(self, batch: Batch) -> Batch:
        tensor_updates: dict[str, torch.Tensor] = {}
        for item in fields(batch):
            value = getattr(batch, item.name)
            if torch.is_tensor(value):
                tensor_updates[item.name] = value.to(self.device)
        if not tensor_updates:
            return batch
        return replace(batch, **tensor_updates)

    @staticmethod
    def _batch_to_model_inputs(batch: Batch) -> dict[str, torch.Tensor]:
        inputs: dict[str, torch.Tensor] = {"input_ids": batch.input_ids}
        if batch.attention_mask is not None:
            inputs["attention_mask"] = batch.attention_mask
        return inputs
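A sketch of the unified entry point, reusing the factory from the previous file; the artifact paths are again illustrative.

```python
from src.inference import create_inference_pipeline

pipeline, _ = create_inference_pipeline(
    "artifacts/checkpoints/best.pt", "artifacts/labels.json", device="cpu"
)

results = pipeline.batch_predict(
    ["The central bank raised rates again.", "What a wonderful surprise!"]
)
print(results["summaries"])          # list[str]
print(results["emotion"][1].labels)  # emotions scoring above the 0.5 default threshold
print(results["topic"][0].label, results["topic"][0].confidence)
```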
src/inference/postprocessing.py
ADDED
@@ -0,0 +1,6 @@
"""Output cleaning helpers."""
from typing import List


def strip_whitespace(texts: List[str]) -> List[str]:
    return [text.strip() for text in texts]
src/models/factory.py
ADDED
@@ -0,0 +1,105 @@
"""Factory helpers to assemble multitask models for inference/training."""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from ..data.tokenization import Tokenizer
from ..utils.config import load_yaml
from .decoder import TransformerDecoder
from .encoder import TransformerEncoder
from .heads import ClassificationHead, LMHead
from .multitask import MultiTaskModel


@dataclass(slots=True)
class ModelConfig:
    """Configuration describing the transformer architecture."""

    d_model: int = 512
    num_encoder_layers: int = 6
    num_decoder_layers: int = 6
    num_attention_heads: int = 8
    ffn_dim: int = 2048
    dropout: float = 0.1

    def __post_init__(self):
        if self.d_model % self.num_attention_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_attention_heads ({self.num_attention_heads})"
            )
        if not 0 <= self.dropout <= 1:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
        if self.d_model <= 0 or self.num_encoder_layers <= 0 or self.num_decoder_layers <= 0:
            raise ValueError("Model dimensions must be positive")
        if self.num_attention_heads <= 0 or self.ffn_dim <= 0:
            raise ValueError("Model dimensions must be positive")


def load_model_config(path: Optional[str | Path]) -> ModelConfig:
    """Load a model configuration from YAML with sane defaults."""

    if path is None:
        return ModelConfig()

    data = load_yaml(str(path)).data
    return ModelConfig(
        d_model=int(data.get("d_model", 512)),
        num_encoder_layers=int(data.get("num_encoder_layers", 6)),
        num_decoder_layers=int(data.get("num_decoder_layers", 6)),
        num_attention_heads=int(data.get("num_attention_heads", 8)),
        ffn_dim=int(data.get("ffn_dim", 2048)),
        dropout=float(data.get("dropout", 0.1)),
    )


def build_multitask_model(
    tokenizer: Tokenizer,
    *,
    num_emotions: int,
    num_topics: int,
    config: ModelConfig | None = None,
) -> MultiTaskModel:
    """Construct the multitask transformer with heads for the three tasks."""

    cfg = config or ModelConfig()
    if not isinstance(num_emotions, int) or num_emotions <= 0:
        raise ValueError("num_emotions must be a positive integer")
    if not isinstance(num_topics, int) or num_topics <= 0:
        raise ValueError("num_topics must be a positive integer")
    encoder = TransformerEncoder(
        vocab_size=tokenizer.vocab_size,
        d_model=cfg.d_model,
        num_layers=cfg.num_encoder_layers,
        num_heads=cfg.num_attention_heads,
        d_ff=cfg.ffn_dim,
        dropout=cfg.dropout,
        max_len=tokenizer.config.max_length,
        pad_token_id=tokenizer.pad_token_id,
    )
    decoder = TransformerDecoder(
        vocab_size=tokenizer.vocab_size,
        d_model=cfg.d_model,
        num_layers=cfg.num_decoder_layers,
        num_heads=cfg.num_attention_heads,
        d_ff=cfg.ffn_dim,
        dropout=cfg.dropout,
        max_len=tokenizer.config.max_length,
        pad_token_id=tokenizer.pad_token_id,
    )

    model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
    model.add_head(
        "summarization",
        LMHead(d_model=cfg.d_model, vocab_size=tokenizer.vocab_size, tie_embedding=decoder.embedding),
    )
    model.add_head(
        "emotion",
        ClassificationHead(d_model=cfg.d_model, num_labels=num_emotions, pooler="mean", dropout=cfg.dropout),
    )
    model.add_head(
        "topic",
        ClassificationHead(d_model=cfg.d_model, num_labels=num_topics, pooler="mean", dropout=cfg.dropout),
    )
    return model
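A sketch of assembling the model directly. The head sizes are illustrative (AG News has four topics; the emotion count depends on the dataset actually used), and the default tokenizer download needs network access.

```python
from src.data.tokenization import Tokenizer
from src.models.factory import ModelConfig, build_multitask_model

tokenizer = Tokenizer()  # defaults to facebook/bart-base
model = build_multitask_model(
    tokenizer,
    num_emotions=6,  # illustrative
    num_topics=4,    # e.g. AG News
    config=ModelConfig(d_model=512, num_attention_heads=8),
)
print(sorted(model.heads))  # ['emotion', 'summarization', 'topic']
```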
src/models/multitask.py
CHANGED
@@ -39,17 +39,28 @@ class MultiTaskModel(nn.Module):
     mt = MultiTaskModel(encoder=enc, decoder=dec)
     mt.add_head("summarize", LMHead(...))
     logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
+
+    Args:
+        encoder: optional encoder backbone.
+        decoder: optional decoder backbone.
+        decoder_outputs_logits: set True when ``decoder.forward`` already returns vocabulary logits;
+            set False if the decoder produces hidden states that must be projected by the LM head.
     """
 
     def __init__(
         self,
         encoder: Optional[TransformerEncoder] = None,
         decoder: Optional[TransformerDecoder] = None,
+        *,
+        decoder_outputs_logits: bool = True,
     ):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
         self.heads: Dict[str, nn.Module] = {}
+        # When True, decoder.forward(...) is expected to return logits already projected to the vocabulary space.
+        # When False, decoder outputs hidden states that must be passed through the registered LM head.
+        self.decoder_outputs_logits = decoder_outputs_logits
 
     def add_head(self, name: str, module: nn.Module) -> None:
         """Register a head under a task name."""
@@ -99,9 +110,15 @@ class MultiTaskModel(nn.Module):
             raise RuntimeError("Encoder is required for encoder-side heads")
         # accept either input_ids or embeddings
         if "input_ids" in inputs:
-            enc_out = self.encoder(inputs["input_ids"])
+            encoder_mask = None
+            if "attention_mask" in inputs:
+                encoder_mask = self._expand_attention_mask(inputs["attention_mask"], inputs["input_ids"].device)
+            enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask)
         elif "embeddings" in inputs:
-            enc_out = self.encoder(inputs["embeddings"])
+            encoder_mask = inputs.get("attention_mask")
+            if encoder_mask is not None:
+                encoder_mask = self._expand_attention_mask(encoder_mask, inputs["embeddings"].device)
+            enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask)
         else:
             raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
         logits = head(enc_out)
@@ -120,10 +137,20 @@ class MultiTaskModel(nn.Module):
             raise RuntimeError("Both encoder and decoder are required for LM-style heads")
 
         # Build encoder memory
+        src_mask = inputs.get("src_mask")
+        if src_mask is None:
+            src_mask = inputs.get("attention_mask")
+        encoder_mask = None
+        reference_tensor = inputs.get("src_ids")
+        if reference_tensor is None:
+            reference_tensor = inputs.get("src_embeddings")
+        if src_mask is not None and reference_tensor is not None:
+            encoder_mask = self._expand_attention_mask(src_mask, reference_tensor.device)
+
         if "src_ids" in inputs:
-            memory = self.encoder(inputs["src_ids"])
+            memory = self.encoder(inputs["src_ids"], mask=encoder_mask)
         elif "src_embeddings" in inputs:
-            memory = self.encoder(inputs["src_embeddings"])
+            memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask)
         else:
             raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")
 
@@ -137,12 +164,13 @@ class MultiTaskModel(nn.Module):
         # Here we don't attempt to generate when labels not provided.
             raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")
 
-        # Run decoder. Decoder returns logits shaped (B, T, vocab) in this codebase.
         decoder_out = self.decoder(decoder_inputs, memory)
 
+        if self.decoder_outputs_logits:
+            if not isinstance(decoder_out, torch.Tensor):
+                raise TypeError(
+                    "Decoder is configured to return logits, but forward returned a non-tensor value."
+                )
             logits = decoder_out
         else:
             logits = head(decoder_out)
@@ -195,4 +223,15 @@ class MultiTaskModel(nn.Module):
             return F.cross_entropy(logits, labels.long())
 
         # If we can't determine, raise
         raise RuntimeError("Cannot compute loss for unknown head type")
+
+    @staticmethod
+    def _expand_attention_mask(mask: torch.Tensor, device: torch.device) -> torch.Tensor:
+        if mask is None:
+            return None  # type: ignore[return-value]
+        bool_mask = mask.to(device=device, dtype=torch.bool)
+        if bool_mask.dim() == 2:
+            return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2)
+        if bool_mask.dim() in (3, 4):
+            return bool_mask
+        raise ValueError("Attention mask must be 2D, 3D, or 4D tensor")
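The 2D-to-3D expansion in `_expand_attention_mask` is worth seeing concretely: a per-token padding mask becomes a pairwise (query, key) visibility mask via broadcasting. Plain PyTorch, no repo imports:

```python
import torch

mask = torch.tensor([[True, True, False]])        # one sequence, last position is padding
pairwise = mask.unsqueeze(1) & mask.unsqueeze(2)  # broadcast to shape (1, 3, 3)
print(pairwise[0].int())
# tensor([[1, 1, 0],
#         [1, 1, 0],
#         [0, 0, 0]])
```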
src/training/__init__.py
CHANGED
@@ -0,0 +1 @@
"""Training utilities for LexiMind."""
src/training/callbacks.py
ADDED

```python
"""Callback hooks for training."""

from pathlib import Path
from typing import Any, Dict, Optional

import torch


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    output_path: str,
    *,
    metrics: Optional[Dict[str, Any]] = None,
) -> None:
    """Persist model and optimizer state for resuming training."""

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": int(epoch),
    }
    if metrics:
        checkpoint["metrics"] = metrics

    # Write atomically: serialize to a sibling temp file, then rename over
    # the target so a crash never leaves a half-written checkpoint behind.
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    temp_path = target.parent / f"{target.name}.tmp"
    try:
        torch.save(checkpoint, temp_path)
        temp_path.replace(target)
    finally:
        # A failed save can leave the temp file behind; remove it.
        if temp_path.exists():
            temp_path.unlink(missing_ok=True)
```
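A usage sketch for the checkpoint helper; the toy model, optimizer, epoch, and paths below are placeholders, and the import assumes the repository root is on `PYTHONPATH`:

```python
import torch
from src.training.callbacks import save_checkpoint

model = torch.nn.Linear(8, 2)  # stand-in for the real multi-task model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

save_checkpoint(
    model,
    optimizer,
    epoch=3,
    output_path="checkpoints/epoch_3.pt",
    metrics={"val_loss": 1.234},
)

# Resuming later: the checkpoint is a plain dict with the keys written above.
state = torch.load("checkpoints/epoch_3.pt")
model.load_state_dict(state["model_state_dict"])
optimizer.load_state_dict(state["optimizer_state_dict"])
start_epoch = state["epoch"] + 1
```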
src/training/losses.py
ADDED

```python
"""Loss helpers."""
import torch


def multitask_loss(losses: dict[str, torch.Tensor]) -> torch.Tensor:
    """Average a dict of per-task losses into a single scalar objective."""
    iterator = iter(losses.values())
    try:
        total = next(iterator).clone()
    except StopIteration:
        raise ValueError("losses is empty") from None
    for value in iterator:
        total = total + value
    return total / len(losses)
```
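Note that this helper weights every task uniformly, so any learnable task weighting has to be applied before the losses reach it. A toy invocation, with made-up task names and values:

```python
import torch
from src.training.losses import multitask_loss

losses = {
    "summarization": torch.tensor(2.1, requires_grad=True),
    "emotion": torch.tensor(0.7, requires_grad=True),
}
total = multitask_loss(losses)  # (2.1 + 0.7) / 2 = 1.4
total.backward()                # gradients flow back to every task's loss
```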
src/training/metrics.py
ADDED

```python
"""Metric helpers used during training and evaluation."""
from __future__ import annotations

from typing import Sequence

import torch


def accuracy(predictions: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of exact matches; empty inputs score 0.0."""
    matches = sum(int(pred == target) for pred, target in zip(predictions, targets))
    return matches / max(1, len(predictions))


def multilabel_f1(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Sample-averaged F1 over binary {0, 1} label matrices of shape (batch, num_labels)."""
    preds = predictions.float()
    gold = targets.float()
    true_positive = (preds * gold).sum(dim=1)
    precision = true_positive / preds.sum(dim=1).clamp(min=1.0)
    recall = true_positive / gold.sum(dim=1).clamp(min=1.0)
    f1 = (2 * precision * recall) / (precision + recall).clamp(min=1e-8)
    return float(f1.mean().item())


def rouge_like(predictions: Sequence[str], references: Sequence[str]) -> float:
    """Cheap ROUGE-1-recall proxy: unigram overlap divided by reference length."""
    if not predictions or not references:
        return 0.0
    scores = []
    for pred, ref in zip(predictions, references):
        pred_tokens = pred.split()
        ref_tokens = ref.split()
        if not ref_tokens:
            scores.append(0.0)
            continue
        overlap = len(set(pred_tokens) & set(ref_tokens))
        scores.append(overlap / len(ref_tokens))
    return sum(scores) / len(scores)
```
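Worked values for the three metrics, using small hand-checkable inputs (everything here refers only to the functions above):

```python
import torch
from src.training.metrics import accuracy, multilabel_f1, rouge_like

print(accuracy([1, 0, 2], [1, 1, 2]))  # 2 of 3 match -> 0.666...

preds = torch.tensor([[1, 0, 1], [0, 1, 0]])
gold = torch.tensor([[1, 1, 0], [0, 1, 0]])
# Sample 1: tp=1, precision=0.5, recall=0.5, F1=0.5; sample 2: F1=1.0
print(multilabel_f1(preds, gold))  # mean -> 0.75

# 3 of the 4 reference unigrams appear in the prediction -> 0.75
print(rouge_like(["the cat sat"], ["the cat sat down"]))
```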