
Commit b1c3669

Added training validation and generate code
1 parent 7c2c189 commit b1c3669

1 file changed: +61 -3 lines changed

tiny_transformer.py

Lines changed: 61 additions & 3 deletions
@@ -2,11 +2,13 @@
 import torch.nn as nn
 from datasets import load_dataset
 
-
+device = torch.device("mps")
 # Network Parameters
-batch_size = 1
+batch_size = 64
 learning_rate = 1e-3
 context_size = 128
+num_epochs = 2000
+embedding_dim = 10
 
 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
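
Note: the new device = torch.device("mps") line hard-codes Apple's Metal backend, so the script fails on machines without it. A minimal fallback sketch, not part of this commit, that prefers MPS, then CUDA, then CPU:

import torch

# Device-selection sketch (assumption, not in commit b1c3669):
# use MPS on Apple Silicon, CUDA if present, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
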
@@ -17,6 +19,7 @@
 vocabulary = sorted(set(list(train_data + validation_data)))
 print(f"Vocabulary size: {len(vocabulary)}")
 print(f"Vocabulary : {vocabulary}")
+vocab_size = len(vocabulary)
 
 
 # Create the character to token mapping.
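
The encode and decode functions used later in this diff are defined outside the changed hunks, next to the "character to token mapping" comment above. A sketch of what they presumably look like, built from vocabulary; the names stoi and itos are assumptions, not shown in this commit:

# Assumed shape of the mapping code the context line above refers to;
# the real definitions live outside the diff.
stoi = {ch: i for i, ch in enumerate(vocabulary)}  # character -> token id
itos = {i: ch for i, ch in enumerate(vocabulary)}  # token id -> character

def encode(text):
    return [stoi[ch] for ch in text]

def decode(ids):
    return "".join(itos[int(i)] for i in ids)
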
@@ -51,4 +54,59 @@ def get_batch_data(data_type="train"):
     x, y = torch.stack(x), torch.stack(y)
     return x, y
 
-x,y = get_batch_data()
+
+class Transformer(nn.Module):
+
+    def __init__(self, vocab_size, context_size):
+        super(Transformer, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        self.linear = nn.Linear(embedding_dim, vocab_size)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear(x)
+        return x
+
+
+def train(model):
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    loss_fn = nn.CrossEntropyLoss()
+    print("Training started")
+    for epoch in range(num_epochs):
+        x, y = get_batch_data()
+        x = x.to(device)
+        y = y.to(device)
+        output = model(x)
+        loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        if epoch % 100 == 0:
+            print(f"Epoch: {epoch}, Loss: {loss.item()}")
+
+def validate(model):
+    loss_fn = nn.CrossEntropyLoss()
+    x, y = get_batch_data("validation")
+    x = x.to(device)
+    y = y.to(device)
+    output = model(x)
+    loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+    print(f"Validation Loss: {loss.item()}")
+
+def generate(model, start_text, num_chars):
+    chars = encode(start_text)
+    for i in range(num_chars):
+        x = torch.tensor(chars[-context_size:]).unsqueeze(0).to(device)
+        output = model(x)
+        prob = torch.nn.functional.softmax(output[0, -1], dim=0)
+        idx = torch.multinomial(prob, num_samples=1)
+        print(decode(idx.cpu().numpy()), end="")
+
+
+model = Transformer(len(vocabulary), context_size)
+model = model.to(device)
+model.train()
+train(model)
+with torch.no_grad():
+    validate(model)
+generate(model, "Sherlock Holmes", 100)
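
A likely bug in the added code: generate never appends the sampled token to chars, so all num_chars samples condition on the unchanged start_text. The model also stays in training mode during generation, which runs outside the no_grad block that wraps validate. A corrected sketch under those observations, not part of this commit:

def generate(model, start_text, num_chars):
    # Fix sketch (assumption, not in commit b1c3669).
    model.eval()  # counterpart to the model.train() call before training
    chars = encode(start_text)
    with torch.no_grad():
        for _ in range(num_chars):
            # Condition on the last context_size tokens generated so far.
            x = torch.tensor(chars[-context_size:]).unsqueeze(0).to(device)
            output = model(x)
            prob = torch.nn.functional.softmax(output[0, -1], dim=0)
            idx = torch.multinomial(prob, num_samples=1).item()
            chars.append(idx)  # feed the sample back into the context
            print(decode([idx]), end="")
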
