Commit 9a49ae4

fast weights experiment
1 parent 2cb1e3f commit 9a49ae4

2 files changed: +132 additions, -9 deletions

labml_nn/transformers/fast_weights/__init__.py

Lines changed: 24 additions & 9 deletions

@@ -6,6 +6,7 @@
 Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
+from typing import Optional
 
 import torch
 from torch import nn
@@ -16,16 +17,27 @@
 from labml_nn.utils import clone_module_list
 
 
-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
         super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps
 
     def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+
+        return x_repeat * x_rolled
 
 
 class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
         super().__init__()
 
         # Number of features per head
@@ -42,18 +54,21 @@ def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
         self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                   nn.Sigmoid())
 
-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma
 
         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)
 
+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+
         value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)
 
         beta = self.gate(x)
@@ -87,7 +102,7 @@ def __init__(self, *,
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ def __call__(self, x_seq: torch.Tensor):
         # List to store the outputs
         res = []
         # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]
 
         for x in x_seq:
             # Run through each layer
             for i, layer in enumerate(self.layers):
                 # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])
 
             res.append(x)
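The commit swaps the identity LinearAttentionFunction for DPFP, the deterministic parameter-free projection feature map, and feeds it into FastWeightAttention as sigma. A minimal sketch of what the new feature map does is below; it assumes DPFP is importable from labml_nn.transformers.fast_weights (as the experiment file in this commit does), and the nu value and tensor shapes are purely illustrative, not taken from the commit:

import torch

from labml_nn.transformers.fast_weights import DPFP

# DPFP concatenates [relu(x), relu(-x)] (2 * d features), multiplies it
# element-wise with nu rolled copies of itself, and normalizes the result to
# sum to one, so a d-dimensional head vector becomes 2 * nu * d non-negative
# features.
sigma = DPFP(nu=2)
x = torch.randn(4, 8, 16)                 # (batch, heads, d_k); shapes chosen only for illustration

phi = sigma(x)

assert phi.shape == (4, 8, 2 * 2 * 16)    # 2 * nu * d_k = 64 features per head
assert (phi >= 0).all()                   # products of ReLU outputs are non-negative
print(phi.sum(dim=-1))                    # each feature vector sums to roughly 1

The other change in this file threads an Optional fast weight memory through the stack: each layer's memory starts as None, is lazily initialized to zeros with key.new_zeros(...) on the first step, and each layer now returns its updated memory so the transformer can carry it to the next step.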

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
"""
---
title: Train Fast Weights Transformer
summary: This is training code with notes for a Fast Weights Transformer.
---
"""

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(Module):
    """
    ## Auto regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run it through the transformer
        res = self.transformer(x)
        # Generate logits of the next token
        return self.generator(res), None


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """

    model: AutoregressiveModel

    d_model: int = 512
    nu: int = 1
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6


@option(Configs.model)
def fast_weights_transformer(c: Configs):
    """
    Create [fast weights transformer](index.html).
    """
    from labml_nn.transformers.fast_weights import FastWeightAttentionTransformer, \
        FastWeightAttentionTransformerLayer, FastWeightAttention, FeedForward

    from labml_nn.transformers.fast_weights import DPFP
    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightAttentionTransformer(
            FastWeightAttentionTransformerLayer(d_model=c.d_model,
                                                attn=FastWeightAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                dropout_prob=c.dropout),
            c.n_layers)).to(c.device)


def main():
    # Create experiment
    experiment.create(name="fast_weights_transformer")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # Run the training loop
        conf.run()


if __name__ == '__main__':
    main()
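The main() function above shows the intended launch path: build Configs, override options through experiment.configs, then run inside experiment.start(). As a usage note, here is a hedged sketch of a variation that reuses the same Configs class with a larger DPFP expansion; the specific override keys and values are assumptions for illustration and are not part of the commit:

from labml import experiment

# Hypothetical variation of main(): same experiment, but with nu=3 (a wider
# DPFP feature map) and fewer epochs. Assumes plain config attributes such as
# 'nu' can be overridden through the dictionary like the options used above.
def run_with_larger_nu():
    experiment.create(name="fast_weights_transformer_nu3")
    conf = Configs()
    experiment.configs(conf, {'tokenizer': 'character',
                              'text': 'tiny_shakespeare',
                              'optimizer.learning_rate': 1.0,
                              'optimizer.optimizer': 'Noam',
                              'nu': 3,
                              'seq_len': 128,
                              'epochs': 32,
                              'batch_size': 16})
    with experiment.start():
        conf.run()

Because fast_weights_transformer is registered with @option(Configs.model), the overridden nu is picked up when the model config is first accessed, so the larger feature map flows into DPFP(nu=c.nu) without touching the model-building code.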
