 import torch.nn as nn
 from datasets import load_dataset

-
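+# select the Apple MPS (Metal) backend; assumes an Apple-silicon Mac (use "cuda" or "cpu" elsewhere)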
+device = torch.device("mps")
 # Network Parameters
-batch_size = 1
+batch_size = 64
 learning_rate = 1e-3
 context_size = 128
+num_epochs = 2000
+embedding_dim = 10

 ds = load_dataset("minnbanya/nlp-a2-sherlock")
 # Concatenate all the text in the train and validation sets

 vocabulary = sorted(set(list(train_data + validation_data)))
 print(f"Vocabulary size: {len(vocabulary)}")
 print(f"Vocabulary : {vocabulary}")
+vocab_size = len(vocabulary)


 # Create the character to token mapping.
@@ -51,4 +54,59 @@ def get_batch_data(data_type="train"):
     x, y = torch.stack(x), torch.stack(y)
     return x, y

-x,y = get_batch_data()
+
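+# NOTE: despite its name, this model is currently just a character embedding followed by a
+# linear projection to vocabulary logits; no attention layers are included yet.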
+class Transformer(nn.Module):
+
+    def __init__(self, vocab_size, context_size):
+        super(Transformer, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        self.linear = nn.Linear(embedding_dim, vocab_size)
+
+    def forward(self, x):
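+        # x: (batch_size, context_size) token ids -> (batch_size, context_size, vocab_size) logits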
+        x = self.embedding(x)
+        x = self.linear(x)
+        return x
+
+
+def train(model):
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    loss_fn = nn.CrossEntropyLoss()
+    print("Training started")
+    for epoch in range(num_epochs):
+        x, y = get_batch_data()
+        x = x.to(device)
+        y = y.to(device)
+        output = model(x)
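+        # flatten batch and context dims: CrossEntropyLoss expects (N, C) logits and (N,) targets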
+        loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        if epoch % 100 == 0:
+            print(f"Epoch: {epoch}, Loss: {loss.item()}")
+
+def validate(model):
+    loss_fn = nn.CrossEntropyLoss()
+    x, y = get_batch_data("validation")
+    x = x.to(device)
+    y = y.to(device)
+    output = model(x)
+    loss = loss_fn(output.view(-1, vocab_size), y.view(-1))
+    print(f"Validation Loss: {loss.item()}")
+
+def generate(model, start_text, num_chars):
+    chars = encode(start_text)
+    for i in range(num_chars):
+        x = torch.tensor(chars[-context_size:]).unsqueeze(0).to(device)
+        output = model(x)
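+        # sample the next character from the softmax over the last position's logits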
+        prob = torch.nn.functional.softmax(output[0, -1], dim=0)
+        idx = torch.multinomial(prob, num_samples=1)
+        chars.append(idx.item())  # feed the sampled character back into the context
+        print(decode(idx.cpu().numpy()), end="")
+
+
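+# build the model, train it, then evaluate and sample without tracking gradients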
+model = Transformer(len(vocabulary), context_size)
+model = model.to(device)
+model.train()
+train(model)
+with torch.no_grad():
+    validate(model)
+    generate(model, "Sherlock Holmes", 100)