 embedding_dim = 128
 num_heads = 4
 dropout_rate = 0.1
+num_blocks = 4

 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
@@ -88,11 +89,17 @@ def forward(self, x):
 # Each head outputs a vector of size embedding_dim // num_heads
 class MultiAttentionHead(nn.Module):
     def __init__(self, head_size):
-        super().__init__()
+        super().__init__()
+
+        # Multiple Self Attention Heads
         self.attention_heads = nn.ModuleList([SelfAttention(embedding_dim // num_heads) for _ in range(num_heads)])
+        # Projection layer for additional processing of the output of the attention heads.
+        self.projection = nn.Linear(embedding_dim, embedding_dim)
+        self.dropout = nn.Dropout(dropout_rate)

     def forward(self, x):
         x = torch.cat([head(x) for head in self.attention_heads], dim=2)
+        x = self.dropout(self.projection(x))
         return x

 # Feed Forward Network to process the output of the attention heads.
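
The `SelfAttention` heads wrapped by `MultiAttentionHead` are defined earlier in the file and do not appear in these hunks. For reference, a minimal sketch of what a single head typically looks like, assuming standard scaled dot-product attention with a causal mask and the `embedding_dim` / `dropout_rate` globals from the top of the script; the layer names and details here are illustrative, not the author's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Illustrative single attention head (sketch, not the committed code)."""
    def __init__(self, head_size):
        super().__init__()
        # Project the embedding into key, query, and value spaces of size head_size.
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled dot-product attention scores.
        scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)  # (B, T, T)
        # Causal mask: each position attends only to itself and earlier positions.
        mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)  # (B, T, head_size)
```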
@@ -109,22 +116,44 @@ def __init__(self):
     def forward(self, x):
         return self.ffwd(x)

+# Each Transformer Block consists of a Multi Attention Head and a Feed Forward Network.
+# It also has Layer Normalization layers to normalize the outputs of the Multi Attention Head and the Feed Forward Network.
+# Residual connections add the input to the output of the Multi Attention Head and the Feed Forward Network.
+class TransformerBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
+        self.feed_forward = FeedForward()
+        self.norm1 = nn.LayerNorm(embedding_dim)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+    def forward(self, x):
+        # Residual connections: add the input to each sub-layer's output, then normalize.
+        x = self.norm1(x + self.attention_heads(x))
+        x = self.norm2(x + self.feed_forward(x))
+        return x
+
+# The Transformer model consists of an Embedding layer, a Positional Embedding layer, Transformer Blocks, and a Linear layer.
 class Transformer(nn.Module):
     def __init__(self, vocab_size, context_size):
         super().__init__()
+        # Embedding layer to convert the tokens to vectors.
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        # Positional Embedding layer to add the position of the tokens to the vectors.
         self.positional_embedding = nn.Embedding(context_size, embedding_dim)
-        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
-        self.feed_forward = FeedForward()
+        # Transformer Blocks to process the context of the input data.
+        self.block = nn.Sequential(*[TransformerBlock() for _ in range(num_blocks)])
+        self.layer_norm = nn.LayerNorm(embedding_dim)
         self.linear = nn.Linear(embedding_dim, vocab_size)

     def forward(self, x):
         _, T = x.shape
         token_embed = self.embedding(x)
+        # Calculate the positional embedding for the input data.
         position_emb = self.positional_embedding(torch.arange(T, device=device))
         x = token_embed + position_emb
-        x = self.attention_heads(x)
-        x = self.linear(x)
+        x = self.block(x)
+        x = self.linear(self.layer_norm(x))
         return x

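A quick way to sanity-check the stacked architecture after this change is to run a dummy batch through the model. The values below (`vocab_size`, `context_size`, and the batch shape) are placeholders rather than values from the script, and `device` is assumed to be defined earlier, as in the forward pass above.

```python
# Illustrative sanity check only: the sizes below are assumptions, not from the commit.
vocab_size = 5000    # assumed tokenizer vocabulary size
context_size = 64    # assumed maximum sequence length

model = Transformer(vocab_size, context_size).to(device)
tokens = torch.randint(0, vocab_size, (2, context_size), device=device)  # dummy batch (B=2, T=64)
logits = model(tokens)
print(logits.shape)  # expected: torch.Size([2, 64, 5000]) -- one logit per vocabulary entry per position
```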