Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
---
"""
+from typing import Optional

import torch
from torch import nn

from labml_nn.utils import clone_module_list


-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps

    def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+
+        return x_repeat * x_rolled
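
For reference (not part of this commit): a minimal sketch of what the DPFP projection above does, assuming the imports and the DPFP class from this file. For an input with d_k features per head, dpfp returns 2 * nu * d_k non-negative features, and __call__ normalizes them to sum to approximately one. The shapes below are placeholders.

    import torch

    sigma = DPFP(nu=3)
    x = torch.randn(10, 4, 8, 16)               # placeholder shape: [seq, batch, heads, d_k]
    phi = sigma(x)
    assert phi.shape == (10, 4, 8, 2 * 3 * 16)  # 2 * nu * d_k features
    assert (phi >= 0).all()                     # products of ReLU outputs are non-negative
    assert torch.allclose(phi.sum(dim=-1), torch.ones(10, 4, 8), atol=1e-4)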


class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
        super().__init__()

        # Number of features per head
@@ -42,18 +54,21 @@ def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        query = self.sigma(self.query(x))
        key = self.sigma(self.key(x))
        value = self.value(x)

+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        beta = self.gate(x)
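
Not part of the diff: the rest of this method is unchanged and cut off by the hunk boundary. Following the paper's delta update rule, the part after beta = self.gate(x) roughly amounts to the sketch below (variable names mirror the method's locals; this is not the file's exact code).

    # write: add the gated correction as an outer product with the key features
    weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)
    # read: query the updated fast weight memory
    x = torch.einsum('bhvk,bhk->bhv', weights, query)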

@@ -87,7 +102,7 @@ def __init__(self, *,
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        attn, weights = self.attn(x, weights)
        # Add the self attention results
        x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ def __call__(self, x_seq: torch.Tensor):
        # List to store the outputs
        res = []
        # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]

        for x in x_seq:
            # Run through each layer
            for i, layer in enumerate(self.layers):
                # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])

            res.append(x)
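
With this change, callers build the attention with an explicit DPFP instance and pass None as the initial fast weight memory. A usage sketch with placeholder sizes (not code from this repository), assuming the classes defined in this file:

    import torch

    attn = FastWeightAttention(heads=8, d_model=512, dropout_prob=0.1, sigma=DPFP(nu=1))

    x = torch.randn(4, 512)            # one step: [batch, d_model]
    out, weights = attn(x, None)       # first step: the memory is created lazily from zeros
    out, weights = attn(x, weights)    # later steps reuse and update the same memory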