| |
| |
|
|
| |
| import matplotlib.pyplot as plt |
| import torch.nn.functional as F |
| import torch.nn as nn |
| import torch |
| import numpy as np |
| import math |
|
|
| |
|
|
| |
# Fix every RNG source up front so the random tensors below are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # no-op when CUDA is absent
|
|
| |
# Toy single-head setup: a batch of 5 tokens, 256-d embeddings, one 64-d head.
embdim = 256   # token embedding width
headdim = 64   # query/key head width
tokens = torch.randn(size=(1, 5, embdim))  # (batch, seq, embdim)
|
|
| |
# Random projection matrices, scaled by 1/sqrt(embdim) so activations stay O(1).
# Note the value projection keeps the full embedding width (embdim -> embdim).
def _proj(out_dim):
    return torch.randn(embdim, out_dim) / math.sqrt(embdim)

Wq = _proj(headdim)
Wk = _proj(headdim)
Wv = _proj(embdim)
|
|
| |
# Project every token into query/key (headdim) and value (embdim) spaces.
# Batched matmul over the embedding axis: (B, S, E) @ (E, out) -> (B, S, out).
qis = tokens @ Wq  # queries
kis = tokens @ Wk  # keys
vis = tokens @ Wv  # values
|
|
| |
# Sanity check: batched matmul contracts the trailing two dims, giving (2, 5, 5).
demo_a = torch.randn(2, 5, 4)
demo_b = torch.randn(2, 5, 4)
torch.matmul(demo_a, demo_b.transpose(1, 2))  # result discarded; shown for illustration
print(qis.shape)
print(kis.shape)
| |
| |
|
|
|
|
# Scaled dot-product attention written out by hand.
scoremat = torch.einsum("BSH,BTH->BST", qis, kis)          # raw q·k scores
attmat = F.softmax(scoremat / math.sqrt(headdim), dim=-1)  # rows sum to 1 (dim=-1 == dim=2 here)
zis = torch.bmm(attmat, vis)                               # attention-weighted value mix

# Cross-check against PyTorch's fused kernel (same 1/sqrt(headdim) scaling by default).
attn_torch = F.scaled_dot_product_attention(qis, kis, vis)
assert torch.allclose(attn_torch, zis, atol=1E-6, rtol=1E-6)
|
|
| |
# Multi-head setup at GPT-2-small sizes: 768 = 12 heads x 64 dims each.
embdim = 768
headcnt = 12
headdim = embdim // headcnt
assert headdim * headcnt == embdim
tokens = torch.randn(1, 5, embdim)

# One fused projection per role; its columns pack all heads side by side.
def _mh_proj():
    return torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim)

Wq = _mh_proj()
Wk = _mh_proj()
Wv = _mh_proj()

for W in (Wq, Wk, Wv):
    print(W.shape)

batch, token_num, _ = tokens.shape
| |
|
|
| |
# Project all heads in one matmul: (B, S, E) @ (E, headcnt*headdim).
qis = tokens @ Wq
kis = tokens @ Wk
vis = tokens @ Wv

# Unpack the fused head axis into (B, S, headcnt, headdim).
qis_mh, kis_mh, vis_mh = (
    t.view(batch, token_num, headcnt, headdim) for t in (qis, kis, vis)
)

# Per-head score matrices, heads moved ahead of the sequence axes: (B, headcnt, S, T).
scoremat_mh = torch.einsum("bshc,bthc->bhst", qis_mh, kis_mh)
print(scoremat_mh.shape)
|
|
| |
|
|
# Softmax over the key axis, then a per-head weighted sum of values;
# finally merge the heads back into one embdim-wide output.
attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1)   # (B, headcnt, S, T)
zis_mh = torch.einsum("bhst,bthd->bshd", attmat_mh, vis_mh)       # (B, S, headcnt, headdim)
zis = zis_mh.reshape(batch, token_num, headcnt * headdim)         # (B, S, embdim)
|
|
| |
|
|
| |
# Compare the manual computation against nn.MultiheadAttention with the same weights.
mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
print(mha.in_proj_weight.shape)
# in_proj_weight stacks W_q, W_k, W_v row-wise as (3*embdim, embdim); our matrices
# right-multiply the tokens, hence the transpose.
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([Wq, Wk, Wv], dim=1).T)
attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False)

# Per-head attention maps must agree with the hand-rolled version.
# (attn_out itself differs: mha applies an extra randomly-initialized out_proj.)
assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6)

print(attn_weights.shape)
print(attn_out.shape)
|
|
| |
| |
|
|
# Causal (autoregressive) masking: a large negative additive term above the
# diagonal drives attention to future positions to ~0 after the softmax.
attn_mask = -1E4 * torch.triu(torch.ones(token_num, token_num), 1)
print(attn_mask)
scoremat_mh_msk = torch.einsum("bsch,btch->bcst", qis_mh, kis_mh)
# NOTE(review): the mask is added *before* the 1/sqrt(headdim) scaling, so its
# effective magnitude is -1e4/8 — still large enough to zero out future slots.
scoremat_mh_msk += attn_mask  # broadcasts over batch and head axes
attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1)
zis_mh_msk = torch.einsum("bcst,btch->bsch", attmat_mh_msk, vis_mh)
zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim)

# Same comparison through nn.MultiheadAttention via its attn_mask argument.
attn_out_causal, attn_weights_causal = mha(
    tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask
)
|
|
| |
# Visualize each head's causal attention map (lower-triangular patterns).
# NOTE(review): the 3x4 subplot grid assumes headcnt == 12.
plt.figure()
for head in range(headcnt):
    ax = plt.subplot(3, 4, head + 1)
    ax.imshow(attn_weights_causal[0, head].detach().numpy())
    ax.set_title(f"head {head}")
    ax.axis("off")
plt.show()
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
class TransformerBlock(nn.Module):
    """Post-LN transformer block: masked self-attention + 4x-expansion MLP.

    Each sub-layer uses a residual connection followed by LayerNorm (the
    post-norm layout of the original "Attention Is All You Need" encoder).

    Args:
        embdim: token embedding width; must be divisible by ``headcnt``.
        headcnt: number of attention heads.
        dropout: dropout probability inside the feed-forward MLP (keyword-only).
    """

    def __init__(self, embdim: int, headcnt: int, *args, dropout: float = 0.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.ln1 = nn.LayerNorm(embdim)  # after the attention sub-layer
        self.ln2 = nn.LayerNorm(embdim)  # after the feed-forward sub-layer
        self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True)
        # Standard transformer MLP: expand 4x, GELU, project back.
        self.ffn = nn.Sequential(
            nn.Linear(embdim, 4 * embdim),
            nn.GELU(),
            nn.Linear(4 * embdim, embdim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        """Run one transformer block.

        Input is (B, S, E); token and positional embeddings are assumed to
        have been added already. Returns a tensor of the same shape.
        """
        batch, token_num, hidden_dim = x.shape
        if is_causal:
            # Additive mask: -1e4 above the diagonal suppresses attention to
            # future positions. Built on x's device/dtype so CUDA inputs work.
            attn_mask = -1E4 * torch.triu(
                torch.ones(token_num, token_num, device=x.device, dtype=x.dtype), 1
            )
        else:
            attn_mask = None

        # --- self-attention sub-layer (residual + post-LN) ---
        residue = x
        # BUG FIX: attn_mask was computed but never passed to self.attn,
        # so is_causal previously had no effect at all.
        attn_output, _attn_weights = self.attn(
            x, x, x, attn_mask=attn_mask, average_attn_weights=False
        )
        x = self.ln1(residue + attn_output)

        # --- feed-forward sub-layer (residual + post-LN) ---
        residue = x
        ffn_output = self.ffn(x)
        # BUG FIX: ln2 was defined in __init__ but never applied in forward.
        return self.ln2(residue + ffn_output)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward pass through the block at the sizes set above.
    print("Testing the Transformer Block")
    block = TransformerBlock(embdim, headcnt)
    demo_tokens = torch.randn(1, 5, embdim)
    print(block(demo_tokens).shape)
|
|