Commit ccf3552

Implemented MultiAttentionHead
1 parent: 4ff66ca


tiny_transformer.py

Lines changed: 25 additions & 10 deletions
@@ -9,6 +9,7 @@
 context_size = 128
 num_epochs = 10000
 embedding_dim = 128
+num_heads = 4
 
 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
@@ -56,12 +57,12 @@ def get_batch_data(data_type="train"):
 
 # Single Attention Head to process the context of input data
 class SelfAttention(nn.Module):
-    def __init__(self, embedding_dim):
+    def __init__(self, head_size):
         super().__init__()
-        self.query = nn.Linear(embedding_dim, embedding_dim)
-        self.key = nn.Linear(embedding_dim, embedding_dim)
-        self.value = nn.Linear(embedding_dim, embedding_dim)
-        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size)))
+        self.query = nn.Linear(embedding_dim, head_size)
+        self.key = nn.Linear(embedding_dim, head_size)
+        self.value = nn.Linear(embedding_dim, head_size)
+        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size).to(device)))
 
     def forward(self, x):
         _, T, _ = x.shape
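
As a quick illustration of what the head_size change implies for the projection shapes (not part of the commit; a standalone sketch using the file's embedding_dim = 128 and num_heads = 4):

    import torch
    import torch.nn as nn

    embedding_dim = 128
    num_heads = 4
    head_size = embedding_dim // num_heads   # 32

    # Each per-head projection now maps embedding_dim -> head_size,
    # so a (batch, context, 128) input yields (batch, context, 32) per head.
    query = nn.Linear(embedding_dim, head_size)
    x = torch.randn(8, 128, embedding_dim)
    print(query(x).shape)                    # torch.Size([8, 128, 32])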
@@ -81,17 +82,31 @@ def forward(self, x):
         output = torch.matmul(attention, v)
         return output
 
-class Transformer(nn.Module):
+# Multi Attention Head to process the context of input data.
+# Each head processes the context differently and the outputs are concatenated to get the final output.
+class MultiAttentionHead(nn.Module):
+    def __init__(self, head_size):
+        super().__init__()
+        self.attention_heads = nn.ModuleList([SelfAttention(embedding_dim // num_heads) for _ in range(num_heads)])
+
+    def forward(self, x):
+        x = torch.cat([head(x) for head in self.attention_heads], dim=2)
+        return x
 
+class Transformer(nn.Module):
     def __init__(self, vocab_size, context_size):
-        super(Transformer, self).__init__()
+        super().__init__()
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
-        self.attention_head = SelfAttention(embedding_dim)
+        self.positional_embedding = nn.Embedding(context_size, embedding_dim)
+        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
         self.linear = nn.Linear(embedding_dim, vocab_size)
 
     def forward(self, x):
-        x = self.embedding(x)
-        x = self.attention_head(x)
+        _, T = x.shape
+        token_embed = self.embedding(x)
+        position_emb = self.positional_embedding(torch.arange(T, device=device))
+        x = token_embed + position_emb
+        x = self.attention_heads(x)
         x = self.linear(x)
         return x
 

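A minimal sketch of the shape bookkeeping the new classes rely on, using stand-in tensors rather than the file's SelfAttention (whose forward is only partially shown in this diff). Note that MultiAttentionHead ignores its head_size argument and recomputes embedding_dim // num_heads internally:

    import torch
    import torch.nn as nn

    embedding_dim, num_heads, context_size = 128, 4, 128
    head_size = embedding_dim // num_heads            # 32
    B, T = 8, context_size

    # Concatenating num_heads outputs of shape (B, T, head_size) along dim=2
    # restores (B, T, embedding_dim), which the final nn.Linear expects.
    head_outputs = [torch.randn(B, T, head_size) for _ in range(num_heads)]
    print(torch.cat(head_outputs, dim=2).shape)       # torch.Size([8, 128, 128])

    # The positional embedding added in Transformer.forward: a (T, embedding_dim)
    # lookup broadcasts against the (B, T, embedding_dim) token embeddings.
    positional_embedding = nn.Embedding(context_size, embedding_dim)
    position_emb = positional_embedding(torch.arange(T))
    token_embed = torch.randn(B, T, embedding_dim)
    print((token_embed + position_emb).shape)         # torch.Size([8, 128, 128])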