Commit 4477038

Added single attention head
1 parent b1c3669 commit 4477038

1 file changed: +29, -6 lines


tiny_transformer.py

Lines changed: 29 additions & 6 deletions
@@ -7,8 +7,8 @@
 batch_size = 64
 learning_rate = 1e-3
 context_size = 128
-num_epochs = 2000
-embedding_dim = 10
+num_epochs = 10000
+embedding_dim = 128

 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
@@ -54,16 +54,36 @@ def get_batch_data(data_type="train"):
     x, y = torch.stack(x), torch.stack(y)
     return x, y

+class SelfAttention(nn.Module):
+    def __init__(self, embedding_dim):
+        super().__init__()
+        self.query = nn.Linear(embedding_dim, embedding_dim)
+        self.key = nn.Linear(embedding_dim, embedding_dim)
+        self.value = nn.Linear(embedding_dim, embedding_dim)
+        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size)))
+
+    def forward(self, x):
+        _, T, _ = x.shape
+        q = self.query(x)
+        k = self.key(x)
+        v = self.value(x)
+        attention = torch.matmul(q, k.transpose(-2, -1)) * context_size ** -0.5
+        attention = attention.masked_fill(self.tril[:T, :T]== 0, float("-inf"))
+        attention = torch.nn.functional.softmax(attention, dim=-1)
+        output = torch.matmul(attention, v)
+        return output

 class Transformer(nn.Module):

     def __init__(self, vocab_size, context_size):
         super(Transformer, self).__init__()
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        self.attention_head = SelfAttention(embedding_dim)
         self.linear = nn.Linear(embedding_dim, vocab_size)

     def forward(self, x):
         x = self.embedding(x)
+        x = self.attention_head(x)
         x = self.linear(x)
         return x
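The new SelfAttention module is a single causal scaled dot-product head: query/key/value projections, a lower-triangular mask from the tril buffer so each position only attends to earlier positions, a softmax over the scores, and a weighted sum of the values. The scores are scaled by context_size ** -0.5 rather than the conventional 1/sqrt(d_k) (i.e. embedding_dim here); with this commit's settings the two happen to coincide, since embedding_dim and context_size are both 128. As a rough sanity check, not part of the commit and assuming PyTorch >= 2.0, the head can be compared against the built-in torch.nn.functional.scaled_dot_product_attention:

# Sketch only, not in the diff: cross-check SelfAttention against PyTorch's
# built-in causal attention. Assumes SelfAttention, embedding_dim (128) and
# context_size (128) as defined in tiny_transformer.py, and torch >= 2.0.
import torch
import torch.nn.functional as F

head = SelfAttention(embedding_dim)
x = torch.randn(2, context_size, embedding_dim)       # (batch, time, channels)
q, k, v = head.query(x), head.key(x), head.value(x)   # reuse the head's projections

# The built-in applies the conventional 1/sqrt(d_k) scale plus a causal mask;
# here d_k == embedding_dim == context_size == 128, so the scalings match.
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(head(x), reference, atol=1e-5))  # should print True under these assumptions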

@@ -94,13 +114,16 @@ def validate(model):
     print(f"Validation Loss: {loss.item()}")

 def generate(model, start_text, num_chars):
-    chars = encode(start_text)
+    chars = torch.tensor(encode(start_text)).to(device)
+    chars = chars.view(1, len(chars))
     for i in range(num_chars):
-        x = torch.tensor(chars[-context_size:]).unsqueeze(0).to(device)
-        output = model(x)
+        output = model(chars)
         prob = torch.nn.functional.softmax(output[0, -1], dim=0)
         idx = torch.multinomial(prob, num_samples=1)
-        print(decode(idx.cpu().numpy()), end="")
+        char = decode(idx.cpu().numpy())
+        print(char, end="")
+        chars = torch.cat([chars, idx.view(1, 1)], dim=1)
+        chars = chars[:, -context_size:]


 model = Transformer(len(vocabulary), context_size)
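With this change, generate() keeps the running context as a (1, T) index tensor on the device, feeds the whole window to the model each step, samples the next character from the softmax of the final position's logits, appends it, and truncates to the last context_size tokens. A hypothetical call, assuming model, device, and the encode/decode helpers defined elsewhere in tiny_transformer.py, might look like:

# Hypothetical usage, not in the diff: sample 200 characters after training.
# Assumes model, device and generate() from tiny_transformer.py; the prompt text is made up.
model = model.to(device)
model.eval()
with torch.no_grad():
    generate(model, start_text="Sherlock Holmes ", num_chars=200)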
