 import math
-from inspect import isfunction
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 
-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
@@ -82,55 +81,6 @@ def forward(self, hidden_states):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
-    def set_weight(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        if hasattr(attn_layer, "q"):
-            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-
-            self.query.bias.data = attn_layer.q.bias.data
-            self.key.bias.data = attn_layer.k.bias.data
-            self.value.bias.data = attn_layer.v.bias.data
-
-            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-        elif hasattr(attn_layer, "NIN_0"):
-            self.query.weight.data = attn_layer.NIN_0.W.data.T
-            self.key.weight.data = attn_layer.NIN_1.W.data.T
-            self.value.weight.data = attn_layer.NIN_2.W.data.T
-
-            self.query.bias.data = attn_layer.NIN_0.b.data
-            self.key.bias.data = attn_layer.NIN_1.b.data
-            self.value.bias.data = attn_layer.NIN_2.b.data
-
-            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-            self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-
-            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-        else:
-            qkv_weight = attn_layer.qkv.weight.data.reshape(
-                self.num_heads, 3 * self.channels // self.num_heads, self.channels
-            )
-            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-            self.query.weight.data = q_w.reshape(-1, self.channels)
-            self.key.weight.data = k_w.reshape(-1, self.channels)
-            self.value.weight.data = v_w.reshape(-1, self.channels)
-
-            self.query.bias.data = q_b.reshape(-1)
-            self.key.bias.data = k_b.reshape(-1)
-            self.value.bias.data = v_b.reshape(-1)
-
-            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-            self.proj_attn.bias.data = attn_layer.proj.bias.data
-
 
 class SpatialTransformer(nn.Module):
     """
@@ -170,12 +120,6 @@ def forward(self, x, context=None):
         x = self.proj_out(x)
         return x + x_in
 
-    def set_weight(self, layer):
-        self.norm = layer.norm
-        self.proj_in = layer.proj_in
-        self.transformer_blocks = layer.transformer_blocks
-        self.proj_out = layer.proj_out
-
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = context_dim if context_dim is not None else query_dim
 
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -234,7 +178,7 @@ def forward(self, x, context=None, mask=None):
         h = self.heads
 
         q = self.to_q(x)
-        context = default(context, x)
+        context = context if context is not None else x
         k = self.to_k(context)
         v = self.to_v(context)
 
@@ -244,7 +188,7 @@ def forward(self, x, context=None, mask=None):
 
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
-        if exists(mask):
+        if mask is not None:
             mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = mask[:, None, :].repeat(h, 1, 1)
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        dim_out = dim_out if dim_out is not None else dim
+        project_in = GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
@@ -280,155 +224,3 @@ def __init__(self, dim_in, dim_out):
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
-
-
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=None,
-        num_groups=32,
-        encoder_channels=None,
-        overwrite_qkv=False,
-        overwrite_linear=False,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = nn.Conv1d(channels, channels, 1)
-
-        self.overwrite_qkv = overwrite_qkv
-        self.overwrite_linear = overwrite_linear
-
-        if overwrite_qkv:
-            in_channels = channels
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.overwrite_linear:
-            num_groups = min(channels // 4, 32)
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.NIN_0 = NIN(channels, channels)
-            self.NIN_1 = NIN(channels, channels)
-            self.NIN_2 = NIN(channels, channels)
-            self.NIN_3 = NIN(channels, channels)
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-        else:
-            self.proj_out = nn.Conv1d(channels, channels, 1)
-            self.set_weights(self)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-                :, :, :, 0
-            ]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = nn.Conv1d(self.channels, self.channels, 1)
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-        elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat(
-                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-            )[:, :, None]
-            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-
-            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj.bias.data = self.NIN_3.b.data
-
-            self.norm.weight.data = self.GroupNorm_0.weight.data
-            self.norm.bias.data = self.GroupNorm_0.bias.data
-        else:
-            self.proj.weight.data = self.proj_out.weight.data
-            self.proj.bias.data = self.proj_out.bias.data
-
-    def forward(self, x, encoder_out=None):
-        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-
-        result = result / self.rescale_output_factor
-
-        return result
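As a side note on the dropped conversion path: the final branch of the removed AttentionBlock.set_weight maps a fused qkv Conv1d onto separate query/key/value Linear layers. Below is a minimal standalone sketch of that reshape/split, assuming illustrative channels/num_heads values and freshly constructed layers; none of these local names are part of the library API.

from torch import nn

channels, num_heads = 64, 4

# Old layout: one fused 1x1 Conv1d producing q, k and v together (weight shape: (3 * channels, channels, 1)).
qkv = nn.Conv1d(channels, channels * 3, 1)
# New layout: three separate Linear projections (weight shape: (channels, channels) each).
query = nn.Linear(channels, channels)
key = nn.Linear(channels, channels)
value = nn.Linear(channels, channels)

# Group the fused weight per head, split it into q/k/v chunks, then flatten back
# to (channels, channels) Linear weights, mirroring the deleted set_weight logic.
qkv_weight = qkv.weight.data.reshape(num_heads, 3 * channels // num_heads, channels)
qkv_bias = qkv.bias.data.reshape(num_heads, 3 * channels // num_heads)

q_w, k_w, v_w = qkv_weight.split(channels // num_heads, dim=1)
q_b, k_b, v_b = qkv_bias.split(channels // num_heads, dim=1)

query.weight.data = q_w.reshape(-1, channels)
key.weight.data = k_w.reshape(-1, channels)
value.weight.data = v_w.reshape(-1, channels)

query.bias.data = q_b.reshape(-1)
key.bias.data = k_b.reshape(-1)
value.bias.data = v_b.reshape(-1)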