Commit 072e0a4

Implemented Transformer Blocks
1 parent ac6281f commit 072e0a4

File tree: 1 file changed, +34 -5 lines changed

tiny_transformer.py

Lines changed: 34 additions & 5 deletions
@@ -11,6 +11,7 @@
 embedding_dim = 128
 num_heads = 4
 dropout_rate = 0.1
+num_blocks = 4

 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
@@ -88,11 +89,17 @@ def forward(self, x):
 # Each head outputs a vector of size embedding_dim // num_heads
 class MultiAttentionHead(nn.Module):
     def __init__(self, head_size):
-        super().__init__()
+        super().__init__()
+
+        # Multiple Self Attention Heads
         self.attention_heads = nn.ModuleList([SelfAttention(embedding_dim // num_heads) for _ in range(num_heads)])
+        # Projection layer for additional processing of the output of the attention heads.
+        self.projection = nn.Linear(embedding_dim, embedding_dim)
+        self.dropout = nn.Dropout(dropout_rate)

     def forward(self, x):
         x = torch.cat([head(x) for head in self.attention_heads], dim=2)
+        x = self.dropout(self.projection(x))
         return x

 # Feed Forward Network to process the output of the attention heads.
@@ -109,22 +116,44 @@ def __init__(self):
     def forward(self, x):
         return self.ffwd(x)

+# Each Transformer Block consists of a Multi Attention Head and a Feed Forward Network.
+# It also has Layer Normalization layers to normalize the output of the Multi Attention Head and the Feed Forward Network.
+# Residual connections are used to add the output of the Multi Attention Head and the Feed Forward Network to the input.
+class TransformerBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
+        self.feed_forward = FeedForward()
+        self.norm1 = nn.LayerNorm(embedding_dim)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+    def forward(self, x):
+        # Residual connections add the input back to the output of the Multi Attention Head and the Feed Forward Network.
+        x = self.norm1(x + self.attention_heads(x))
+        x = self.norm2(x + self.feed_forward(x))
+        return x
+
+# The Transformer model consists of an Embedding layer, a Positional Embedding layer, Transformer Blocks, and a Linear layer.
 class Transformer(nn.Module):
     def __init__(self, vocab_size, context_size):
         super().__init__()
+        # Embedding Layer to convert the tokens to vectors.
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        # Positional Embedding Layer to add the position of the tokens to the vectors.
         self.positional_embedding = nn.Embedding(context_size, embedding_dim)
-        self.attention_heads = MultiAttentionHead(embedding_dim // num_heads)
-        self.feed_forward = FeedForward()
+        # Transformer Blocks to process the context of the input data.
+        self.block = nn.Sequential(*[TransformerBlock() for _ in range(num_blocks)])
+        self.layer_norm = nn.LayerNorm(embedding_dim)
         self.linear = nn.Linear(embedding_dim, vocab_size)

     def forward(self, x):
         _, T = x.shape
         token_embed = self.embedding(x)
+        # Calculate the positional embedding for the input data.
         position_emb = self.positional_embedding(torch.arange(T, device=device))
         x = token_embed + position_emb
-        x = self.attention_heads(x)
-        x = self.linear(x)
+        x = self.block(x)
+        x = self.linear(self.layer_norm(x))
         return x

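Aside: the core of this commit is the post-norm residual wiring inside TransformerBlock (add each sub-layer's output back to its input, then apply LayerNorm) and the stacking of num_blocks such blocks with nn.Sequential. A small, self-contained sketch of just that wiring follows; the nn.Linear sub-layers are stand-ins for MultiAttentionHead and FeedForward, so it demonstrates only the residual/normalization structure and shape preservation, not the attention itself:

import torch
import torch.nn as nn

embedding_dim, num_blocks = 128, 4

class BlockSketch(nn.Module):
    # Post-norm residual wiring, mirroring TransformerBlock (sub-layers are stand-ins).
    def __init__(self):
        super().__init__()
        self.sub1 = nn.Linear(embedding_dim, embedding_dim)   # stand-in for MultiAttentionHead
        self.sub2 = nn.Linear(embedding_dim, embedding_dim)   # stand-in for FeedForward
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        x = self.norm1(x + self.sub1(x))    # residual add, then LayerNorm
        x = self.norm2(x + self.sub2(x))
        return x

blocks = nn.Sequential(*[BlockSketch() for _ in range(num_blocks)])
x = torch.randn(2, 16, embedding_dim)       # (batch, tokens, embedding_dim)
print(blocks(x).shape)                      # torch.Size([2, 16, 128])

Because every block maps (batch, tokens, embedding_dim) to the same shape, any number of blocks can be stacked before the final LayerNorm and the vocabulary projection in Transformer.forward.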