|
7 | 7 | batch_size = 64 |
8 | 8 | learning_rate = 1e-3 |
9 | 9 | context_size = 128 |
10 | | -num_epochs = 2000 |
11 | | -embedding_dim = 10 |
| 10 | +num_epochs = 10000 |
| 11 | +embedding_dim = 128 |
12 | 12 |
|
13 | 13 | ds = load_dataset("minnbanya/nlp-a2-sherlock") |
14 | 14 | # Concatenate all the text in the train and validation sets |
@@ -54,16 +54,36 @@ def get_batch_data(data_type="train"): |
54 | 54 |     x, y = torch.stack(x), torch.stack(y) |
55 | 55 |     return x, y |
56 | 56 |
|
| 57 | +class SelfAttention(nn.Module): |
| 58 | +    def __init__(self, embedding_dim): |
| 59 | +        super().__init__() |
| 60 | +        self.query = nn.Linear(embedding_dim, embedding_dim) |
| 61 | +        self.key = nn.Linear(embedding_dim, embedding_dim) |
| 62 | +        self.value = nn.Linear(embedding_dim, embedding_dim) |
| 63 | +        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size))) |
| 64 | + |
| 65 | +    def forward(self, x): |
| 66 | +        _, T, _ = x.shape |
| 67 | +        q = self.query(x) |
| 68 | +        k = self.key(x) |
| 69 | +        v = self.value(x) |
| 70 | +        attention = torch.matmul(q, k.transpose(-2, -1)) * k.shape[-1] ** -0.5 |
| 71 | +        attention = attention.masked_fill(self.tril[:T, :T] == 0, float("-inf")) |
| 72 | +        attention = torch.nn.functional.softmax(attention, dim=-1) |
| 73 | +        output = torch.matmul(attention, v) |
| 74 | +        return output |
57 | 75 |
|
58 | 76 | class Transformer(nn.Module): |
59 | 77 |
|
60 | 78 |     def __init__(self, vocab_size, context_size): |
61 | 79 |         super(Transformer, self).__init__() |
62 | 80 |         self.embedding = nn.Embedding(vocab_size, embedding_dim) |
| 81 | +        self.attention_head = SelfAttention(embedding_dim) |
63 | 82 |         self.linear = nn.Linear(embedding_dim, vocab_size) |
64 | 83 |
|
65 | 84 |     def forward(self, x): |
66 | 85 |         x = self.embedding(x) |
| 86 | +        x = self.attention_head(x) |
67 | 87 |         x = self.linear(x) |
68 | 88 |         return x |
69 | 89 |
|
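A quick way to exercise the new SelfAttention head outside the training script is a shape and causality check. This is a minimal sketch, not part of the commit: it assumes the SelfAttention class and the hyperparameters above (embedding_dim = 128, context_size = 128) are in scope, and the batch size and sequence length below are arbitrary.

```python
import torch

# Illustrative check only: assumes SelfAttention, embedding_dim = 128 and
# context_size = 128 from the diff above are in scope.
head = SelfAttention(embedding_dim)
x = torch.randn(4, 16, embedding_dim)   # (batch, time, embedding_dim), with time <= context_size
out = head(x)
print(out.shape)                        # torch.Size([4, 16, 128]): attention preserves the shape

# The causal mask means position 0 attends only to itself, so changing
# later tokens must not change the output at position 0.
x2 = x.clone()
x2[:, 1:, :] = torch.randn_like(x2[:, 1:, :])
print(torch.allclose(head(x)[:, 0], head(x2)[:, 0]))  # True
```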
@@ -94,13 +114,16 @@ def validate(model): |
94 | 114 |     print(f"Validation Loss: {loss.item()}") |
95 | 115 |
|
96 | 116 | def generate(model, start_text, num_chars): |
97 | | -    chars = encode(start_text) |
| 117 | +    chars = torch.tensor(encode(start_text)).to(device) |
| 118 | +    chars = chars.view(1, len(chars)) |
98 | 119 |     for i in range(num_chars): |
99 | | -        x = torch.tensor(chars[-context_size:]).unsqueeze(0).to(device) |
100 | | -        output = model(x) |
| 120 | +        output = model(chars) |
101 | 121 |         prob = torch.nn.functional.softmax(output[0, -1], dim=0) |
102 | 122 |         idx = torch.multinomial(prob, num_samples=1) |
103 | | -        print(decode(idx.cpu().numpy()), end="") |
| 123 | +        char = decode(idx.cpu().numpy()) |
| 124 | +        print(char, end="") |
| 125 | +        chars = torch.cat([chars, idx.view(1, 1)], dim=1) |
| 126 | +        chars = chars[:, -context_size:] |
104 | 127 |
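The rewritten loop keeps a rolling window of at most context_size tokens (chars[:, -context_size:]), so the model input always fits the causal mask registered in SelfAttention. A minimal usage sketch, assuming the trained model, encode, decode, and device from this script; the prompt text and character count are arbitrary:

```python
# Illustrative only: `model` is assumed to be trained and on `device`;
# the prompt text and character count are arbitrary.
model.eval()
with torch.no_grad():
    generate(model, start_text="Sherlock Holmes ", num_chars=200)
```

One caveat: a prompt longer than context_size would exceed the registered mask on the first forward pass, so trimming start_text to its last context_size characters before calling generate is a safe precaution.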
|
105 | 128 |
|
106 | 129 | model = Transformer(len(vocabulary), context_size) |
|