 context_size = 128
 num_epochs = 10000
 embedding_dim = 128
+num_heads = 4
 
 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
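The hunk header below references a get_batch_data helper whose body is not part of this diff. Purely as a hypothetical sketch (the data dict, batch_size, and all other details here are assumptions, not the commit's code), a character-level batcher for this setup might look like:

# Hypothetical sketch, not part of the commit. Assumes the concatenated text has been
# encoded into a 1-D LongTensor per split, e.g. data = {"train": ..., "validation": ...},
# and reuses the context_size and device globals defined earlier in the file.
import torch

def get_batch_data(data_type="train", batch_size=32):
    source = data[data_type]
    # Sample random starting offsets, then slice out context_size-long windows.
    idx = torch.randint(len(source) - context_size - 1, (batch_size,))
    x = torch.stack([source[i:i + context_size] for i in idx])
    # Targets are the same windows shifted one position to the right.
    y = torch.stack([source[i + 1:i + context_size + 1] for i in idx])
    return x.to(device), y.to(device)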
@@ -56,12 +57,12 @@ def get_batch_data(data_type="train"):
 
 # Single Attention Head to process the context of input data
 class SelfAttention(nn.Module):
-    def __init__(self, embedding_dim):
+    def __init__(self, head_size):
         super().__init__()
-        self.query = nn.Linear(embedding_dim, embedding_dim)
-        self.key = nn.Linear(embedding_dim, embedding_dim)
-        self.value = nn.Linear(embedding_dim, embedding_dim)
-        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size)))
+        self.query = nn.Linear(embedding_dim, head_size)
+        self.key = nn.Linear(embedding_dim, head_size)
+        self.value = nn.Linear(embedding_dim, head_size)
+        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size)).to(device))
 
     def forward(self, x):
         _, T, _ = x.shape
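The next hunk skips over the rest of SelfAttention.forward. As a hedged sketch only, assuming the query/key/value projections and the "tril" mask buffer defined in __init__ above, the elided causal scaled-dot-product step typically looks something like this (not the commit's actual lines):

# Illustrative sketch of the body of SelfAttention.forward under the assumptions above.
import math
import torch
import torch.nn.functional as F

def forward(self, x):
    _, T, _ = x.shape
    q = self.query(x)                                            # (B, T, head_size)
    k = self.key(x)                                              # (B, T, head_size)
    v = self.value(x)                                            # (B, T, head_size)
    # Scaled dot-product scores between every pair of positions.
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))     # (B, T, T)
    # Mask out future positions so each token only attends to its past.
    scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
    attention = F.softmax(scores, dim=-1)
    output = torch.matmul(attention, v)                          # (B, T, head_size)
    return output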
@@ -81,17 +82,31 @@ def forward(self, x):
         output = torch.matmul(attention, v)
         return output
 
-class Transformer(nn.Module):
+# Multi Attention Head to process the context of input data.
+# Each head processes the context differently and the outputs are concatenated to get the final output.
+class MultiAttentionHead(nn.Module):
+    def __init__(self, head_size):
+        super().__init__()
+        self.attention_heads = nn.ModuleList([SelfAttention(embedding_dim // num_heads) for _ in range(num_heads)])
+
+    def forward(self, x):
+        x = torch.cat([head(x) for head in self.attention_heads], dim=2)
+        return x
 
+class Transformer(nn.Module):
     def __init__(self, vocab_size, context_size):
-        super(Transformer, self).__init__()
+        super().__init__()
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
-        self.attention_head = SelfAttention(embedding_dim)
+        self.positional_embedding = nn.Embedding(context_size, embedding_dim)
+        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
         self.linear = nn.Linear(embedding_dim, vocab_size)
 
     def forward(self, x):
-        x = self.embedding(x)
-        x = self.attention_head(x)
+        _, T = x.shape
+        token_embed = self.embedding(x)
+        position_emb = self.positional_embedding(torch.arange(T, device=device))
+        x = token_embed + position_emb
+        x = self.attention_heads(x)
         x = self.linear(x)
         return x
 
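After this change the model adds learned positional embeddings to the token embeddings, runs four 32-dimensional attention heads in parallel, concatenates them back to embedding_dim = 128, and projects to vocabulary logits. A small usage sketch under assumed values (the vocab_size and batch size here are illustrative, not from the commit):

# Hypothetical usage sketch, assuming the globals defined earlier in the file
# (device, context_size, embedding_dim, num_heads) and an illustrative vocab_size.
import torch

vocab_size = 65  # e.g. a small character-level vocabulary; assumed for illustration
model = Transformer(vocab_size, context_size).to(device)

x = torch.randint(vocab_size, (4, context_size), device=device)  # (batch, time) token indices
logits = model(x)                                                 # (4, context_size, vocab_size)

# Each head works on embedding_dim // num_heads = 32 dimensions, and concatenating the
# num_heads = 4 head outputs along the last axis restores embedding_dim = 128, which is
# why the final nn.Linear(embedding_dim, vocab_size) lines up.
assert logits.shape == (4, context_size, vocab_size)

Note that MultiAttentionHead receives a head_size argument but derives the per-head width internally from embedding_dim // num_heads, so the value Transformer passes in is effectively unused.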