
Commit cda13f0

Added code to load and save models and continue from last epoch
1 parent 072e0a4

File tree: 1 file changed (+58 −13)

tiny_transformer.py
@@ -1,17 +1,24 @@
 import torch
 import torch.nn as nn
 from datasets import load_dataset
+import os
 
+# File path for saving the model
+MODEL_PATH = "tiny_transformer.pth"
+
+# Config
 device = torch.device("mps")
+torch.set_default_dtype(torch.bfloat16)
+#
 # Network Parameters
-batch_size = 64
-learning_rate = 1e-3
-context_size = 128
 num_epochs = 10000
-embedding_dim = 128
-num_heads = 4
-dropout_rate = 0.1
-num_blocks = 4
+batch_size = 64
+learning_rate = 3e-4
+dropout_rate = 0.2
+context_size = 256
+embedding_dim = 384
+num_heads = 6
+num_blocks = 6
 
 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
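A note on the new default dtype: torch.set_default_dtype(torch.bfloat16) makes newly created floating-point tensors, and the parameters of freshly constructed modules, default to bfloat16, so the whole model trains in reduced precision without per-layer casts. A quick standalone check (illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    torch.set_default_dtype(torch.bfloat16)
    print(torch.randn(2, 2).dtype)       # torch.bfloat16
    print(nn.Linear(4, 4).weight.dtype)  # torch.bfloat16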
@@ -65,6 +72,7 @@ def __init__(self, head_size):
         self.key = nn.Linear(embedding_dim, head_size)
         self.value = nn.Linear(embedding_dim, head_size)
         self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size).to(device)))
+        self.dropout = nn.Dropout(dropout_rate)
 
     def forward(self, x):
         _, T, _ = x.shape
@@ -80,6 +88,8 @@ def forward(self, x):
         attention = attention.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
         # softmax is applied to the attention weights to get the final attention weights.
         attention = torch.nn.functional.softmax(attention, dim=-1)
+        # Dropout is applied to the attention weights
+        attention = self.dropout(attention)
         # The attention weights are multiplied with the value vector to get the final output.
         output = torch.matmul(attention, v)
         return output
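The Dropout added after the softmax randomly zeroes a fraction of the attention weights during training (rescaling the survivors by 1/(1-p), so rows no longer sum to 1) and is a no-op in eval mode. A small standalone demonstration of that behavior (illustrative values, not from the commit):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    drop = nn.Dropout(0.2)
    attn = torch.softmax(torch.randn(1, 4, 4), dim=-1)

    drop.train()
    print(drop(attn))                     # some weights zeroed, the rest scaled by 1/0.8
    drop.eval()
    print(torch.equal(drop(attn), attn))  # True: identity at eval time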
@@ -157,22 +167,56 @@ def forward(self, x):
     return x
 
 
-# Function to train the model
+# Function to save the model
+def save_model(model, optimizer, epoch, loss):
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'epoch': epoch,
+        'loss': loss
+    }
+    torch.save(checkpoint, MODEL_PATH)
+    print(f"Model saved at epoch {epoch} with loss {loss:.4f}")
+
+# Function to load the model
+def load_model(model, optimizer):
+    if os.path.exists(MODEL_PATH):
+        checkpoint = torch.load(MODEL_PATH, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        start_epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Model loaded from epoch {start_epoch} with loss {loss:.4f}")
+        return model, optimizer, start_epoch, loss
+    else:
+        print("No saved model found. Training from scratch.")
+        return model, optimizer, 0, None
+
+# Function to train the model and save it
 def train(model):
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
     loss_fn = nn.CrossEntropyLoss()
+
+    # Load model if saved
+    model, optimizer, start_epoch, _ = load_model(model, optimizer)
+
     print("Training started")
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):  # Continue training from last saved epoch
         x, y = get_batch_data()
-        x = x.to(device)
-        y = y.to(device)
+        x, y = x.to(device), y.to(device)
+
         output = model(x)
         loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
+
+        # Log every 100 epochs, checkpoint every 1000
         if epoch % 100 == 0:
-            print(f"Epoch: {epoch}, Loss: {loss.item()}")
+            print(f"Epoch: {epoch}, Loss: {loss.item():.4f}")
+        if epoch % 1000 == 0:
+            save_model(model, optimizer, epoch, loss.item())
 
 # Function to validate the model
 def validate(model):
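The checkpoint is a plain dict, so save_model/load_model round-trip cleanly through torch.save/torch.load. A minimal standalone sketch of that round trip, using a stand-in nn.Linear model and a hypothetical demo.pth path (the dict keys mirror the commit's):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': 123,      # stand-in values
        'loss': 0.4567,
    }, "demo.pth")

    checkpoint = torch.load("demo.pth", map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(checkpoint['epoch'], checkpoint['loss'])  # 123 0.4567

One subtlety in the resume logic as written: because the saved epoch is fed straight into range(start_epoch, num_epochs), the checkpointed epoch is run once more after a restart; saving epoch + 1 would avoid the repeat.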
@@ -186,6 +230,7 @@ def validate(model):
 
 # Function to generate text from the model
 def generate(model, start_text, num_chars):
+    print(start_text, end="")
     chars = torch.tensor(encode(start_text)).to(device)
     chars = chars.view(1, len(chars))
     for i in range(num_chars):
@@ -204,4 +249,4 @@ def generate(model, start_text, num_chars):
 train(model)
 with torch.no_grad():
     validate(model)
-    generate(model, "Sherlock Holmes", 100)
+    generate(model, "Sherlock Holmes", 1000)
