@@ -30,159 +30,6 @@ def _make_positions(self, tensor, pad_index: int):
         return torch.cumsum(masked, dim=1) * masked + pad_index
 
 
-class ResidualMLP(Module):
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dims: List[int],
-        dropout: float = 0.1,
-        activation=nn.GELU,
-        add_residual=True,
-    ):
-        super().__init__()
-        modules = []
-        for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
-            modules.extend([nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)])
-
-        last_dim = hidden_dims[-1] if hidden_dims else input_dim
-        modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])
-
-        self.mlp = nn.Sequential(*modules)
-        self.add_residual = add_residual
-        self.hidden_dim = hidden_dims[0] if hidden_dims else input_dim
-
-    def forward(self, input):
-        bias = self.mlp(input)
-        if not hasattr(self, "add_residual"):
-            self.add_residual = True
-        if self.add_residual:
-            return input + bias
-        else:
-            return bias
-
-
-class MultiheadSelfAttention(Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        scaling: Optional[float] = None,
-        dropout: float = 0.1,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-
-        expected_scaling = float(1 / math.sqrt(self.head_dim))
-
-        assert embed_dim % num_heads == 0, f"embed_dim={embed_dim} should be a multiple of num_heads={num_heads}"
-
-        if not scaling:
-            logger.warn(
-                f"""
-                Scaling not set. Please manually set scaling for transformers.
-                In this case the suggested value {expected_scaling} will be inferred,
-                or float(1 / math.sqrt(head_dim))
-                where head_dim = embed_dim // num_heads = {self.head_dim}
-                and embed_dim = {embed_dim} and num_heads = {num_heads}.
-                """
-            )
-            scaling = expected_scaling
-
-        self.scaling = scaling
-        self.dropout = nn.Dropout(dropout)
-        self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
-        self.output_projection = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        target_length, batch_size, embed_dim = query.size()
-        mask_batch_size, source_length = key_padding_mask.size()
-
-        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
-        torch._assert(
-            batch_size == mask_batch_size,
-            "query and key_padding_mask batch sizes differed",
-        )
-
-        projection = self.input_projection(query)
-        q, k, v = projection.chunk(3, dim=-1)
-        q = self.scaling * q
-
-        batch_heads = batch_size * self.num_heads
-
-        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-
-        torch._assert(k.size(1) == source_length, "key size should be equal to source length")
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        if attn_mask is not None:
-            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
-            torch._assert(
-                attn_mask.size(0) == target_length,
-                "attn_mask shape didn't match for target length {}".format(target_length),
-            )
-            torch._assert(
-                attn_mask.size(1) == source_length,
-                "attn_mask shape didn't match for source length {}".format(source_length),
-            )
-            torch._assert(
-                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
-                f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
-            )
-            if attn_mask.dtype == torch.bool:
-                new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
-                new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
-                attn_mask = new_attn_mask
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
-
-        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
-        torch._assert(
-            attn_weights.size(0) == batch_heads,
-            "attn_weights shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn_weights.size(1) == target_length,
-            "attn_weights shape didn't match for target length",
-        )
-        torch._assert(
-            attn_weights.size(2) == source_length,
-            "attn_weights shape didn't match for source length",
-        )
-
-        attn_weights = attn_weights.view(batch_size, self.num_heads, target_length, source_length)
-        attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
-        attn_weights = attn_weights.view(batch_heads, target_length, source_length)
-
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
-        attn_weights = self.dropout(attn_weights)
-
-        attn = torch.bmm(attn_weights, v)
-
-        torch._assert(
-            attn.dim() == 3,
-            "unexpected attn dim size",
-        )
-        torch._assert(
-            attn.size(0) == batch_heads,
-            "attn shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn.size(1) == target_length,
-            "attn shape didn't match for target length",
-        )
-        torch._assert(
-            attn.size(2) == self.head_dim,
-            "attn shape didn't match for head dim",
-        )
-        attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, self.head_dim * self.num_heads)
-        attn = self.output_projection(attn)
-
-        return attn
-
 
 class TransformerEncoderLayer(Module):
     def __init__(
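The commit does not state a replacement for the removed MultiheadSelfAttention. Purely as an illustrative sketch, assuming downstream callers migrate to the built-in torch.nn.MultiheadAttention (which uses the same (target_length, batch_size, embed_dim) input layout and the same key_padding_mask convention, where True marks padded positions), the equivalent self-attention call could look roughly like this; nothing in the diff confirms this migration path:

import torch
import torch.nn as nn

# Hypothetical stand-in for the removed MultiheadSelfAttention module.
embed_dim, num_heads = 64, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)

query = torch.randn(10, 2, embed_dim)                    # (target_length, batch_size, embed_dim)
key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # (batch_size, source_length), True = ignore

# Self-attention: query, key, and value are the same tensor.
attn_output, _ = mha(query, query, query, key_padding_mask=key_padding_mask)
print(attn_output.shape)  # torch.Size([10, 2, 64])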