@@ -1,17 +1,24 @@
 import torch
 import torch.nn as nn
 from datasets import load_dataset
+import os
 
+# File path for saving the model
+MODEL_PATH = "tiny_transformer.pth"
+
+# Config
 device = torch.device("mps")
+torch.set_default_dtype(torch.bfloat16)
+#
 # Network Parameters
-batch_size = 64
-learning_rate = 1e-3
-context_size = 128
 num_epochs = 10000
-embedding_dim = 128
-num_heads = 4
-dropout_rate = 0.1
-num_blocks = 4
+batch_size = 64
+learning_rate = 3e-4
+dropout_rate = 0.2
+context_size = 256
+embedding_dim = 384
+num_heads = 6
+num_blocks = 6
 
 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets
@@ -65,6 +72,7 @@ def __init__(self, head_size):
         self.key = nn.Linear(embedding_dim, head_size)
         self.value = nn.Linear(embedding_dim, head_size)
         self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size).to(device)))
+        self.dropout = nn.Dropout(dropout_rate)
 
     def forward(self, x):
         _, T, _ = x.shape
@@ -80,6 +88,8 @@ def forward(self, x):
         attention = attention.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
         # softmax is applied to the attention weights to get the final attention weights.
         attention = torch.nn.functional.softmax(attention, dim=-1)
+        # Dropout is applied to the attention weights
+        attention = self.dropout(attention)
         # The attention weights are multiplied with the value vector to get the final output.
         output = torch.matmul(attention, v)
         return output
@@ -157,22 +167,56 @@ def forward(self, x):
         return x
 
 
-# Function to train the model
+# Function to save the model
+def save_model(model, optimizer, epoch, loss):
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'epoch': epoch,
+        'loss': loss
+    }
+    torch.save(checkpoint, MODEL_PATH)
+    print(f"Model saved at epoch {epoch} with loss {loss:.4f}")
+
+# Function to load the model
+def load_model(model, optimizer):
+    if os.path.exists(MODEL_PATH):
+        checkpoint = torch.load(MODEL_PATH, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        start_epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Model loaded from epoch {start_epoch} with loss {loss:.4f}")
+        return model, optimizer, start_epoch, loss
+    else:
+        print("No saved model found. Training from scratch.")
+        return model, optimizer, 0, None
+
+# Function to train the model and save it
 def train(model):
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
     loss_fn = nn.CrossEntropyLoss()
+
+    # Load model if saved
+    model, optimizer, start_epoch, _ = load_model(model, optimizer)
+
     print("Training started")
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):  # Continue training from last saved epoch
         x, y = get_batch_data()
-        x = x.to(device)
-        y = y.to(device)
+        x, y = x.to(device), y.to(device)
+
         output = model(x)
         loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
+
+        # Log the loss every 100 epochs and save a checkpoint every 1000 epochs
         if epoch % 100 == 0:
-            print(f"Epoch: {epoch}, Loss: {loss.item()}")
+            print(f"Epoch: {epoch}, Loss: {loss.item():.4f}")
+        if epoch % 1000 == 0:
+            save_model(model, optimizer, epoch, loss.item())
 
 # Function to validate the model
 def validate(model):
@@ -186,6 +230,7 @@ def validate(model):
 
 # Function to generate text from the model
 def generate(model, start_text, num_chars):
+    print(start_text, end="")
     chars = torch.tensor(encode(start_text)).to(device)
     chars = chars.view(1, len(chars))
     for i in range(num_chars):
@@ -204,4 +249,4 @@ def generate(model, start_text, num_chars):
 train(model)
 with torch.no_grad():
     validate(model)
-    generate(model, "Sherlock Holmes", 100)
+    generate(model, "Sherlock Holmes", 1000)
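
The checkpointing introduced in this commit stores the model weights, optimizer state, epoch, and last loss in a single dict, and train() resumes from the saved epoch via range(start_epoch, num_epochs). Below is a minimal, self-contained sketch of that save/load round-trip; it is not part of the commit, and DemoNet and DEMO_PATH are hypothetical stand-ins (CPU only, no dataset or MPS device needed):

import os
import torch
import torch.nn as nn

DEMO_PATH = "demo_checkpoint.pth"  # hypothetical path, not the commit's MODEL_PATH

class DemoNet(nn.Module):  # hypothetical tiny model standing in for the transformer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = DemoNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save: same dict layout as save_model() in the diff
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 123,
    'loss': 0.456,
}, DEMO_PATH)

# Load: mirrors load_model(), falling back to epoch 0 when no checkpoint exists
if os.path.exists(DEMO_PATH):
    checkpoint = torch.load(DEMO_PATH, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
else:
    start_epoch = 0

print(f"Training would resume from epoch {start_epoch}")  # prints 123 here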