OliverPerrin committed on
Commit ba4cb76 · 1 Parent(s): f3096ca

Reformatted Project Structure
.gitignore CHANGED
@@ -1,101 +1,62 @@
- # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
  *$py.class
-
- # C extensions
  *.so
-
- # Distribution / packaging
  .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
  *.egg-info/
- .installed.cfg
  *.egg

  # Virtual environments
- .env
- .venv
- env/
  venv/
  ENV/
- env.bak/
- venv.bak/
-
- # Jupyter Notebook checkpoints
- .ipynb_checkpoints
-
- # PyInstaller
- *.manifest
- *.spec
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
-
- # Pyre type checker
- .pyre/
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pylint
- pylint-report.txt
- pylint.log
-
- # TensorFlow / Keras / PyTorch training outputs
- *.h5
- *.hdf5
- *.ckpt
- *.pb
- *.tflite
- *.onnx
- *.pth
  *.pt
-
- # Model checkpoints and logs
- checkpoints/
  logs/
  runs/
-
- # Dataset and large files (you may want Git LFS for these)
- data/
- *.csv
- *.tsv
- *.json
- *.parquet
-
- # System files
  .DS_Store
  Thumbs.db
-
- # IDE / Editor settings
- .vscode/
- .idea/
- *.sublime-project
- *.sublime-workspace
-
- # Streamlit / FastAPI specific
- .streamlit/
+ # Python
  __pycache__/
  *.py[cod]
  *$py.class
  *.so
  .Python
  *.egg-info/
+ dist/
+ build/
  *.egg

  # Virtual environments
  venv/
+ env/
  ENV/

+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo

+ # Data
+ data/raw/
+ data/processed/
+ data/cache/
+ *.csv
+ *.json
+ *.txt
+ !requirements*.txt

+ # Models
+ checkpoints/
  *.pt
+ *.pth
+ *.ckpt

+ # Logs
  logs/
+ *.log
  runs/

+ # Outputs
+ outputs/
+ results/

+ # Jupyter
+ .ipynb_checkpoints/
+ *.ipynb
+
+ # OS - Windows specific
  .DS_Store
  Thumbs.db
+ desktop.ini
+ $RECYCLE.BIN/

+ # Windows thumbnail cache
+ ehthumbs.db
+ ehthumbs_vista.db

+ # Config overrides
+ configs/local/
README.md CHANGED
@@ -1,2 +1,175 @@
- # LexiMind
- Full NLP Pipeline for text summarization, emotion detection, and topic grouping.
+ # LexiMind: Multi-Task Transformer for Document Analysis
+
+ A PyTorch-based multi-task learning system that performs abstractive summarization, emotion classification, and topic clustering on textual data using a shared Transformer encoder architecture.
+
+ ## 🎯 Project Overview
+
+ LexiMind demonstrates multi-task learning (MTL) by training a single model to simultaneously:
+ 1. **Abstractive Summarization**: Generate concise summaries with user-defined compression levels
+ 2. **Emotion Classification**: Detect multiple emotions present in text (multi-label classification)
+ 3. **Topic Clustering**: Group documents by semantic similarity for topic discovery
+
+ ### Key Features
+ - Custom encoder-decoder Transformer architecture with shared representations
+ - Multi-task loss function with learnable task weighting
+ - Attention weight visualization for model interpretability
+ - Interactive web interface for real-time inference
+ - Trained on diverse corpora: news articles (CNN/DailyMail, BBC) and literary texts (Project Gutenberg)
+
+ ## 🏗️ Architecture
+
+ ```
+ Input Text
+      ↓
+ ┌─────────────────────┐
+ │   Shared Encoder    │  ← TransformerEncoder (6 layers)
+ │  (Multi-head Attn)  │
+ └─────────────────────┘
+      ↓          ↓          ↓
+ ┌─────────┐ ┌─────────┐ ┌─────────┐
+ │ Decoder │ │Classify │ │ Project │
+ │  Head   │ │  Head   │ │  Head   │
+ └─────────┘ └─────────┘ └─────────┘
+      ↓          ↓          ↓
+   Summary    Emotions   Embeddings
+                         (for clustering)
+ ```
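Purely as an illustration (not part of the committed README), here is a minimal PyTorch sketch of this fan-out, using the built-in `nn.TransformerEncoder` in place of the repository's custom encoder; the class and argument names are hypothetical:

```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    """Shared encoder feeding multiple task heads (illustrative sketch only)."""

    def __init__(self, vocab_size=32000, d_model=512, num_heads=8,
                 num_layers=6, num_emotions=27, embedding_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, num_heads, dim_feedforward=2048,
                                           dropout=0.1, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        # Task heads: the seq2seq decoder head is omitted for brevity; only the
        # classification and projection heads are shown here.
        self.emotion_head = nn.Linear(d_model, num_emotions)
        self.projection_head = nn.Linear(d_model, embedding_dim)

    def forward(self, input_ids, padding_mask=None):
        memory = self.encoder(self.embed(input_ids), src_key_padding_mask=padding_mask)
        pooled = memory.mean(dim=1)                     # mean pooling over tokens
        emotion_logits = self.emotion_head(pooled)      # multi-label logits
        embeddings = nn.functional.normalize(self.projection_head(pooled), dim=-1)
        return memory, emotion_logits, embeddings       # memory would feed the decoder head
```

The summarization decoder head would consume `memory` as its cross-attention context.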
+
+ ## 📊 Datasets
+
+ - **CNN/DailyMail**: 300k+ news articles with human-written summaries
+ - **BBC News**: 2,225 articles across 5 categories
+ - **Project Gutenberg**: Classic literature for long-form text analysis
+
+ ## 🚀 Quick Start
+
+ ### Installation
+ ```bash
+ git clone https://github.com/OliverPerrin/LexiMind.git
+ cd LexiMind
+ pip install -r requirements.txt
+ ```
+
+ ### Download Data
+ ```bash
+ python src/download_datasets.py
+ ```
+
+ ### Train Model
+ ```bash
+ python src/train.py --config configs/default.yaml
+ ```
+
+ ### Launch Interface
+ ```bash
+ python src/app.py
+ ```
+
+ ## 📁 Project Structure
+
+ ```
+ LexiMind/
+ ├── src/
+ │   ├── models/
+ │   │   ├── encoder.py            # Shared Transformer encoder
+ │   │   ├── summarization.py      # Seq2seq decoder head
+ │   │   ├── emotion.py            # Multi-label classification head
+ │   │   └── clustering.py         # Projection head for embeddings
+ │   ├── data/
+ │   │   ├── download_datasets.py  # Data acquisition
+ │   │   ├── preprocessing.py      # Text cleaning & tokenization
+ │   │   └── dataset.py            # PyTorch Dataset classes
+ │   ├── training/
+ │   │   ├── train.py              # Training loop
+ │   │   ├── losses.py             # Multi-task loss functions
+ │   │   └── metrics.py            # ROUGE, F1, silhouette scores
+ │   ├── inference/
+ │   │   └── pipeline.py           # End-to-end inference
+ │   ├── visualization/
+ │   │   └── attention.py          # Attention heatmap generation
+ │   └── app.py                    # Gradio/FastAPI interface
+ ├── configs/
+ │   └── default.yaml              # Model & training hyperparameters
+ ├── tests/
+ │   └── test_*.py                 # Unit tests
+ ├── notebooks/
+ │   └── exploratory.ipynb         # Data exploration & analysis
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ## 🧪 Evaluation Metrics
+
+ | Task | Metric | Score |
+ |------|--------|-------|
+ | Summarization | ROUGE-1 / ROUGE-L | TBD |
+ | Emotion Classification | Macro F1 | TBD |
+ | Topic Clustering | Silhouette Score | TBD |
+
+ ## 🔬 Technical Details
+
+ ### Model Specifications
+ - **Encoder**: 6-layer Transformer (d_model=512, 8 attention heads)
+ - **Decoder**: 6-layer autoregressive Transformer
+ - **Vocab Size**: 32,000 (SentencePiece tokenizer)
+ - **Parameters**: ~60M total
+
+ ### Training
+ - **Optimizer**: AdamW (lr=1e-4, weight_decay=0.01)
+ - **Scheduler**: Linear warmup (5000 steps) + cosine decay
+ - **Loss**: Weighted sum of cross-entropy (summarization), BCE (emotions), and triplet loss (clustering)
+ - **Hardware**: Trained on a single NVIDIA RTX 3090 (24 GB VRAM)
+ - **Time**: ~48 hours for 10 epochs
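As an annotation (not part of the committed README), the optimizer and warmup-plus-cosine schedule listed above can be sketched with standard PyTorch APIs; `total_steps` is an assumed placeholder:

```python
import math
import torch

def build_optimizer_and_scheduler(model, lr=1e-4, weight_decay=0.01,
                                  warmup_steps=5000, total_steps=100_000):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step):
        if step < warmup_steps:                       # linear warmup
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```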
+
+ ### Multi-Task Learning Strategy
+ Uses uncertainty weighting ([Kendall et al., 2018](https://arxiv.org/abs/1705.07115)) to automatically balance task losses:
+
+ ```
+ L_total = Σ_i ( L_i / (2σ_i²) + log σ_i )
+ ```
+
+ where σ_i are learnable parameters representing task uncertainty.
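A minimal sketch of how this weighting is commonly implemented (learning log σ² for numerical stability); this is illustrative, not the repository's `losses.py`:

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Kendall et al. (2018)-style task weighting: L_i / (2σ_i²) + log σ_i."""

    def __init__(self, num_tasks=3):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))  # log(σ_i²)

    def forward(self, task_losses):
        total = 0.0
        for i, loss in enumerate(task_losses):
            precision = torch.exp(-self.log_vars[i])          # 1 / σ_i²
            total = total + 0.5 * precision * loss + 0.5 * self.log_vars[i]
        return total
```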
+
+ ## 🎨 Interface Preview
+
+ The web interface provides:
+ - Text input with real-time token count
+ - Compression level slider (20%-80%)
+ - Side-by-side original/summary comparison
+ - Emotion probability bars with color coding
+ - Interactive attention heatmap (click tokens to highlight attention)
+ - Downloadable results (JSON/CSV)
+
+ ## 📈 Future Enhancements
+
+ - [ ] Add multilingual support (mBART)
+ - [ ] Implement beam search for better summaries
+ - [ ] Fine-tune on domain-specific corpora (medical, legal)
+ - [ ] Add semantic search across document embeddings
+ - [ ] Deploy as REST API with Docker
+ - [ ] Implement model distillation for mobile deployment
+
+ ## 📚 References
+
+ - Vaswani et al. (2017) - [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
+ - Lewis et al. (2019) - [BART: Denoising Sequence-to-Sequence Pre-training](https://arxiv.org/abs/1910.13461)
+ - Caruana (1997) - [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734)
+ - Demszky et al. (2020) - [GoEmotions Dataset](https://arxiv.org/abs/2005.00547)
+
+ ## 📄 License
+
+ GNU General Public License v3.0
+
+ ## 👤 Author
+
+ **Oliver Perrin**
+ - Portfolio: [oliverperrin.com](https://oliverperrin.com)
+ - LinkedIn: [linkedin.com/in/oliverperrin](https://linkedin.com/in/oliverperrin)
+ - Email: [email protected]
+
+ ---
src/app.py → configs/data/datasets.yaml RENAMED
File without changes
configs/model/base.yaml ADDED
@@ -0,0 +1,50 @@
+ model:
+   vocab_size: 32000
+   d_model: 512
+   num_encoder_layers: 6
+   num_decoder_layers: 6
+   num_heads: 8
+   d_ff: 2048
+   dropout: 0.1
+   max_seq_length: 512
+
+ tasks:
+   summarization:
+     enabled: true
+     decoder_layers: 6
+
+   emotion:
+     enabled: true
+     num_classes: 27
+     pool_strategy: "mean"  # Options: mean, max, cls, attention
+
+   clustering:
+     enabled: true
+     embedding_dim: 128
+     normalize: true
+
+ training:
+   batch_size: 16
+   gradient_accumulation_steps: 2  # Effective batch = 32
+   learning_rate: 1e-4
+   weight_decay: 0.01
+   num_epochs: 10
+   warmup_steps: 1000
+   max_grad_norm: 1.0
+
+   scheduler:
+     type: "cosine"  # Options: linear, cosine, polynomial
+
+   mixed_precision: true  # Use AMP for faster training
+
+ data:
+   max_length: 512
+   summary_max_length: 128
+   train_split: 0.8
+   val_split: 0.1
+   test_split: 0.1
+
+   preprocessing:
+     lowercase: true
+     remove_stopwords: false
+     min_token_length: 3
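As an annotation (not part of the commit), a minimal sketch of how `gradient_accumulation_steps`, `mixed_precision`, and `max_grad_norm` above typically interact in a PyTorch training step; it assumes a model that returns a scalar loss from `model(**batch)`:

```python
import torch

def train_steps(model, loader, optimizer, accum_steps=2, max_grad_norm=1.0):
    scaler = torch.cuda.amp.GradScaler()           # mixed_precision: true
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(**batch) / accum_steps    # spread loss over accumulated steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:          # effective batch = batch_size * accum_steps
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```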
src/emotion_classifier.py → configs/model/large.yaml RENAMED
File without changes
configs/model/small.yaml ADDED
@@ -0,0 +1,23 @@
+ # configs/model/small.yaml (for fast iteration)
+ model:
+   d_model: 256
+   num_encoder_layers: 4
+   num_decoder_layers: 4
+   num_heads: 8
+
+ training:
+   batch_size: 32  # ~4GB VRAM
+   gradient_accumulation_steps: 1
+   mixed_precision: true  # Essential!
+
+ # configs/model/base.yaml (production)
+ model:
+   d_model: 512
+   num_encoder_layers: 6
+   num_decoder_layers: 6
+   num_heads: 8
+
+ training:
+   batch_size: 8  # ~8GB VRAM
+   gradient_accumulation_steps: 4  # Effective batch = 32
+   mixed_precision: true
src/pipeline.py → configs/training/default.yaml RENAMED
File without changes
src/topic_model.py → configs/training/full.yaml RENAMED
File without changes
configs/training/quick_test.yaml ADDED
File without changes
data/.gitkeep ADDED
File without changes
data/external/.gitkeep ADDED
File without changes
docker/Dockerfile ADDED
File without changes
docker/docker-compose.yml ADDED
File without changes
docs/api.md ADDED
File without changes
docs/architecture.md ADDED
File without changes
docs/training.md ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,53 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "leximind"
+ version = "0.1.0"
+ description = "Multi-Task Transformer for Document Analysis"
+ authors = [{name = "Oliver Perrin", email = "[email protected]"}]
+ readme = "README.md"
+ requires-python = ">=3.9"
+ license = {text = "GPL-3.0"}
+
+ dependencies = [
+     "torch>=2.0.0",
+     "transformers>=4.30.0",
+     "datasets>=2.14.0",
+     "tokenizers>=0.13.0",
+     "numpy>=1.24.0",
+     "pandas>=2.0.0",
+     "scikit-learn>=1.3.0",
+     "matplotlib>=3.7.0",
+     "seaborn>=0.12.0",
+     "tqdm>=4.65.0",
+     "pyyaml>=6.0",
+     "omegaconf>=2.3.0",
+     "tensorboard>=2.13.0",
+     "gradio>=3.35.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.4.0",
+     "pytest-cov>=4.1.0",
+     "black>=23.7.0",
+     "isort>=5.12.0",
+     "flake8>=6.0.0",
+     "mypy>=1.4.0",
+     "jupyter>=1.0.0",
+     "ipywidgets>=8.0.0",
+ ]
+
+ [tool.black]
+ line-length = 100
+ target-version = ['py39']
+
+ [tool.isort]
+ profile = "black"
+ line_length = 100
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ python_files = "test_*.py"
requirements-dev.txt ADDED
@@ -0,0 +1,9 @@
+ # requirements-dev.txt
+ pytest>=7.4.0
+ pytest-cov>=4.1.0
+ black>=23.7.0
+ isort>=5.12.0
+ flake8>=6.0.0
+ mypy>=1.4.0
+ jupyter>=1.0.0
+ ipywidgets>=8.0.0
requirements.txt CHANGED
@@ -1,13 +1,18 @@
- torch>=1.9.0
- transformers>=4.20.0
- scikit-learn>=1.0.0
- nltk>=3.7
- numpy>=1.21.0
- pandas>=1.3.0
- tensorflow>=2.12
- kaggle>=1.6.17
  requests>=2.31.0
- sentencepiece>=0.1.99
- tf-keras==2.20.1
- keras>=2.7.0
- tensorflow>=2.7.0
+ # requirements.txt
+ torch>=2.0.0
+ transformers>=4.30.0
+ datasets>=2.14.0
+ tokenizers>=0.13.0
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scikit-learn>=1.3.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+ nltk>=3.8.0
+ tqdm>=4.65.0
+ pyyaml>=6.0
+ omegaconf>=2.3.0
+ tensorboard>=2.13.0
+ gradio>=3.35.0
  requests>=2.31.0
+ kagglehub>=0.2.0
scripts/test_gpu.py ADDED
@@ -0,0 +1,27 @@
+ # test_gpu.py
+ import torch
+
+ print("=" * 50)
+ print("GPU Information")
+ print("=" * 50)
+
+ if torch.cuda.is_available():
+     gpu_name = torch.cuda.get_device_name(0)
+     gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+
+     print(f"✅ GPU: {gpu_name}")
+     print(f"✅ Memory: {gpu_memory:.2f} GB")
+
+     # Test tensor creation
+     x = torch.randn(1000, 1000, device='cuda')
+     y = torch.randn(1000, 1000, device='cuda')
+     z = x @ y
+
+     print(f"✅ CUDA operations working!")
+     print(f"✅ Current memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
+     print(f"✅ Max memory allocated: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
+ else:
+     print("❌ CUDA not available!")
+     print("Using CPU - training will be slow!")
+
+ print("=" * 50)
scripts/train.py ADDED
@@ -0,0 +1,8 @@
+ # scripts/train.py
+ from src.training.trainer import Trainer
+ from src.utils.config import load_config
+
+ if __name__ == "__main__":
+     config = load_config("configs/training/default.yaml")
+     trainer = Trainer(config)
+     trainer.train()
setup.py ADDED
@@ -0,0 +1,19 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="leximind",
+     version="0.1.0",
+     packages=find_packages(where="src"),
+     package_dir={"": "src"},
+     install_requires=[
+         "torch>=2.0.0",
+         "transformers>=4.30.0",
+         # ... (or read from requirements.txt)
+     ],
+     entry_points={
+         "console_scripts": [
+             "leximind-train=scripts.train:main",
+             "leximind-infer=scripts.inference:main",
+         ],
+     },
+ )
src/__init__.py ADDED
File without changes
src/api/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/{download_datasets.py → data/download.py} RENAMED
File without changes
src/{preprocessing.py → data/preprocessing.py} RENAMED
File without changes
src/inference/__init__.py ADDED
File without changes
src/{summarizer.py → inference/baseline_summarizer.py} RENAMED
File without changes
src/models/__init__.py ADDED
File without changes
src/models/attention.py ADDED
@@ -0,0 +1,75 @@
+ """
+ Attention mechanisms for Transformer architecture.
+
+ This module implements the core attention mechanisms used in the Transformer model:
+ - ScaledDotProductAttention: Fundamental attention operation
+ - MultiHeadAttention: Parallel attention with learned projections
+
+ Author: Oliver Perrin
+ Date: 2025-10-23
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from typing import Optional, Tuple
+
+
+ class ScaledDotProductAttention(nn.Module):
+     """
+     Scaled Dot-Product Attention as described in "Attention Is All You Need".
+
+     Computes: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
+
+     The scaling factor (1/sqrt(d_k)) prevents the dot products from growing too large,
+     which would push the softmax into regions with extremely small gradients.
+
+     Args:
+         None - this module has no learnable parameters
+
+     Forward Args:
+         query: Query tensor of shape (batch, seq_len, d_k)
+         key: Key tensor of shape (batch, seq_len, d_k)
+         value: Value tensor of shape (batch, seq_len, d_v)
+         mask: Optional mask tensor of shape (batch, seq_len, seq_len)
+             True/1 values indicate positions to attend to, False/0 to mask
+
+     Returns:
+         output: Attention output of shape (batch, seq_len, d_v)
+         attention_weights: Attention probability matrix (batch, seq_len, seq_len)
+
+     TODO: Implement the forward method below
+     Research questions to answer:
+         1. Why divide by sqrt(d_k)? What happens without it?
+         2. How does masking work? When do we need it?
+         3. What's the computational complexity?
+     """
+
+     def __init__(self):
+         super().__init__()
+         # TODO: Do you need any parameters here?
+         pass
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         mask: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         TODO: Implement this method
+
+         Steps:
+         1. Compute attention scores: scores = query @ key.transpose(-2, -1)
+         2. Scale by sqrt(d_k)
+         3. Apply mask if provided (set masked positions to -inf before softmax)
+         4. Apply softmax to get attention weights
+         5. Compute output: output = attention_weights @ value
+         6. Return both output and attention_weights
+         """
+         pass
+
+
+ # TODO: After you implement ScaledDotProductAttention, we'll add MultiHeadAttention
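As an annotation (the TODO above is deliberately left as an exercise in this commit), here is one possible implementation of the six listed steps, kept separate from the committed file:

```python
import math
from typing import Optional, Tuple

import torch


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference sketch of the steps described in the module's TODO docstring."""
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
    if mask is not None:
        # mask convention from the docstring: True/1 = attend, False/0 = block
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    output = weights @ value                                  # (batch, seq_len, d_v)
    return output, weights
```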
src/training/__init__.py ADDED
File without changes
src/utils/__init__.py ADDED
File without changes
src/utils/config.py ADDED
@@ -0,0 +1,47 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional, Dict, Any
+ import yaml
+ from omegaconf import OmegaConf
+
+ @dataclass
+ class ModelConfig:
+     vocab_size: int
+     d_model: int
+     num_encoder_layers: int
+     num_decoder_layers: int
+     num_heads: int
+     d_ff: int
+     dropout: float
+     max_seq_length: int
+
+ @dataclass
+ class TrainingConfig:
+     batch_size: int
+     learning_rate: float
+     num_epochs: int
+     warmup_steps: int
+     max_grad_norm: float
+     mixed_precision: bool
+
+ @dataclass
+ class Config:
+     model: ModelConfig
+     training: TrainingConfig
+     data: Dict[str, Any]
+     tasks: Dict[str, Any]
+
+ def load_config(config_path: str) -> Config:
+     """Load config from YAML and convert to structured dataclass."""
+     cfg = OmegaConf.load(config_path)
+
+     # Convert to dataclass for type safety
+     model_cfg = ModelConfig(**cfg.model)
+     training_cfg = TrainingConfig(**cfg.training)
+
+     return Config(
+         model=model_cfg,
+         training=training_cfg,
+         data=dict(cfg.data),
+         tasks=dict(cfg.tasks)
+     )
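A hypothetical usage sketch for `load_config`; it assumes a config file whose `model:` and `training:` sections contain exactly the fields declared in `ModelConfig` and `TrainingConfig` (the path below is an assumption):

```python
# Hypothetical usage; the path and the exact field layout of the YAML are assumptions.
from src.utils.config import load_config

cfg = load_config("configs/training/default.yaml")
print(cfg.model.d_model)        # structured, type-checked access
print(cfg.training.batch_size)
```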
src/visualization/__init__.py ADDED
File without changes
tests/__init__.py ADDED
File without changes
tests/test_models/test_attention.py ADDED
@@ -0,0 +1,61 @@
+ """
+ Tests for attention mechanisms.
+
+ Run with: pytest tests/test_models/test_attention.py -v
+ """
+
+ import pytest
+ import torch
+ from src.models.attention import ScaledDotProductAttention
+
+
+ class TestScaledDotProductAttention:
+     """Test suite for ScaledDotProductAttention."""
+
+     def test_output_shape(self):
+         """Test that output shapes are correct."""
+         attention = ScaledDotProductAttention()
+         batch_size, seq_len, d_k = 2, 10, 64
+
+         Q = torch.randn(batch_size, seq_len, d_k)
+         K = torch.randn(batch_size, seq_len, d_k)
+         V = torch.randn(batch_size, seq_len, d_k)
+
+         output, weights = attention(Q, K, V)
+
+         assert output.shape == (batch_size, seq_len, d_k)
+         assert weights.shape == (batch_size, seq_len, seq_len)
+
+     def test_attention_weights_sum_to_one(self):
+         """Test that attention weights are a valid probability distribution."""
+         attention = ScaledDotProductAttention()
+         batch_size, seq_len, d_k = 2, 10, 64
+
+         Q = K = V = torch.randn(batch_size, seq_len, d_k)
+         _, weights = attention(Q, K, V)
+
+         # Each row should sum to 1 (probability distribution over keys)
+         row_sums = weights.sum(dim=-1)
+         assert torch.allclose(row_sums, torch.ones(batch_size, seq_len), atol=1e-6)
+
+     def test_masking(self):
+         """Test that masking properly zeros out attention to masked positions."""
+         attention = ScaledDotProductAttention()
+         batch_size, seq_len, d_k = 1, 5, 64
+
+         Q = K = V = torch.randn(batch_size, seq_len, d_k)
+
+         # Create mask: only attend to first 3 positions
+         mask = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.bool)
+         mask[:, :, :3] = True
+
+         _, weights = attention(Q, K, V, mask)
+
+         # Positions 3 and 4 should have zero attention weight
+         assert torch.allclose(weights[:, :, 3:], torch.zeros(batch_size, seq_len, 2), atol=1e-6)
+
+     # TODO: Add more tests as you understand the mechanism better
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__, "-v"])