 )


+class ResnetBlockCondNorm2D(nn.Module):
+    r"""
+    A Resnet block that uses a normalization layer incorporating conditioning information.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first conv2d layer. If `None`, same as `in_channels`.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
+        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
+        groups_out (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the second normalization layer. If set to `None`, same as `groups`.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+        non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use.
+        time_embedding_norm (`str`, *optional*, defaults to `"ada_group"`):
+            The normalization layer for the time embedding `temb`. Currently only supports "ada_group" or "spatial".
+        output_scale_factor (`float`, *optional*, defaults to `1.0`): The scale factor to use for the output.
+        use_in_shortcut (`bool`, *optional*, defaults to `None`):
+            If `True`, add a 1x1 nn.Conv2d layer for the skip connection. If `None`, it is enabled when
+            `in_channels != conv_2d_out_channels`.
+        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
+        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
+        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
+            `conv_shortcut` output.
+        conv_2d_out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output.
+            If `None`, same as `out_channels`.
+    """
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        conv_shortcut: bool = False,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        groups_out: Optional[int] = None,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        time_embedding_norm: str = "ada_group",  # ada_group, spatial
+        output_scale_factor: float = 1.0,
+        use_in_shortcut: Optional[bool] = None,
+        up: bool = False,
+        down: bool = False,
+        conv_shortcut_bias: bool = True,
+        conv_2d_out_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+        self.up = up
+        self.down = down
+        self.output_scale_factor = output_scale_factor
+        self.time_embedding_norm = time_embedding_norm
+
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
+        if groups_out is None:
+            groups_out = groups
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm1 = SpatialNorm(in_channels, temb_channels)
+        else:
+            raise ValueError(f"unsupported time_embedding_norm: {self.time_embedding_norm}")
+
+        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        if self.time_embedding_norm == "ada_group":
+            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm2 = SpatialNorm(out_channels, temb_channels)
+        else:
+            raise ValueError(f"unsupported time_embedding_norm: {self.time_embedding_norm}")
+
+        self.dropout = torch.nn.Dropout(dropout)
+
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
+        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.upsample = self.downsample = None
+        if self.up:
+            self.upsample = Upsample2D(in_channels, use_conv=False)
+        elif self.down:
+            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = conv_cls(
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
+            )
+
+    def forward(
+        self,
+        input_tensor: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states, temb)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        if self.upsample is not None:
+            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+            if hidden_states.shape[0] >= 64:
+                input_tensor = input_tensor.contiguous()
+                hidden_states = hidden_states.contiguous()
+            input_tensor = self.upsample(input_tensor, scale=scale)
+            hidden_states = self.upsample(hidden_states, scale=scale)
+
+        elif self.downsample is not None:
+            input_tensor = self.downsample(input_tensor, scale=scale)
+            hidden_states = self.downsample(hidden_states, scale=scale)
+
+        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states, temb)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = (
+                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+            )
+
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+        return output_tensor
+
+
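For readers skimming the diff, here is a minimal usage sketch of the new block. The import path and tensor shapes are assumptions for illustration; the key point is that with `"ada_group"` (or `"spatial"`) the conditioning embedding is consumed by `norm1`/`norm2` themselves rather than projected and added to the hidden states:

```python
import torch
from diffusers.models.resnet import ResnetBlockCondNorm2D  # import path assumed

# Conditioning enters through the norm layers (AdaGroupNorm here),
# not through an additive time_emb_proj as in ResnetBlock2D.
block = ResnetBlockCondNorm2D(
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    time_embedding_norm="ada_group",
)
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width), illustrative sizes
temb = torch.randn(2, 512)      # conditioning embedding passed to norm1/norm2
out = block(x, temb)
print(out.shape)                # torch.Size([2, 128, 32, 32])
```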
 class ResnetBlock2D(nn.Module):
     r"""
     A Resnet block.
@@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module):
         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
-            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
-            "ada_group" for a stronger conditioning with scale and shift.
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
+            for a stronger conditioning with scale and shift.
         kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
@@ -87,7 +237,7 @@ def __init__(
         eps: float = 1e-6,
         non_linearity: str = "swish",
         skip_time_act: bool = False,
-        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
+        time_embedding_norm: str = "default",  # default, scale_shift
         kernel: Optional[torch.FloatTensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
@@ -97,7 +247,15 @@ def __init__(
         conv_2d_out_channels: Optional[int] = None,
     ):
         super().__init__()
-        self.pre_norm = pre_norm
+        if time_embedding_norm == "ada_group":
+            raise ValueError(
+                "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
+            )
+        if time_embedding_norm == "spatial":
+            raise ValueError(
+                "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
+            )
+
         self.pre_norm = True
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
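The two guards above turn a previously valid-looking constructor call into an immediate, descriptive error. A quick sketch of what callers now see (argument values are illustrative):

```python
from diffusers.models.resnet import ResnetBlock2D  # import path assumed

try:
    ResnetBlock2D(in_channels=64, temb_channels=512, time_embedding_norm="ada_group")
except ValueError as err:
    print(err)  # directs the caller to ResnetBlockCondNorm2D instead
```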
@@ -115,12 +273,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups
 
-        if self.time_embedding_norm == "ada_group":
-            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
-        elif self.time_embedding_norm == "spatial":
-            self.norm1 = SpatialNorm(in_channels, temb_channels)
-        else:
-            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
         self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
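This hunk works because, unlike `AdaGroupNorm` and `SpatialNorm`, a plain `torch.nn.GroupNorm` takes no conditioning argument; a one-liner to make the interface difference concrete (toy sizes):

```python
import torch
import torch.nn as nn

norm1 = nn.GroupNorm(num_groups=32, num_channels=64, eps=1e-6, affine=True)
x = torch.randn(2, 64, 16, 16)
y = norm1(x)  # unconditioned: no `temb` argument, which is why the
              # temb-dispatching branches in forward() can be dropped too
```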
@@ -129,19 +282,12 @@ def __init__(
                 self.time_emb_proj = linear_cls(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
                 self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
-            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-                self.time_emb_proj = None
             else:
                 raise ValueError(f"unknown time_embedding_norm: {self.time_embedding_norm}")
         else:
             self.time_emb_proj = None
 
-        if self.time_embedding_norm == "ada_group":
-            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
-        elif self.time_embedding_norm == "spatial":
-            self.norm2 = SpatialNorm(out_channels, temb_channels)
-        else:
-            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
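With the conditional-norm modes gone, the width of `time_emb_proj` alone encodes the remaining behavior: `out_channels` for `"default"` (the projected embedding is added channel-wise) versus `2 * out_channels` for `"scale_shift"` (chunked into scale and shift in `forward`). A standalone sketch with plain `nn.Linear` and toy sizes:

```python
import torch.nn as nn

temb_channels, out_channels = 512, 128

proj_default = nn.Linear(temb_channels, out_channels)          # added to the feature map
proj_scale_shift = nn.Linear(temb_channels, 2 * out_channels)  # split into (scale, shift)
```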
@@ -188,11 +334,7 @@ def forward(
     ) -> torch.FloatTensor:
         hidden_states = input_tensor
 
-        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-            hidden_states = self.norm1(hidden_states, temb)
-        else:
-            hidden_states = self.norm1(hidden_states)
-
+        hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
@@ -233,17 +375,20 @@ def forward(
                 else self.time_emb_proj(temb)[:, :, None, None]
             )
 
-        if temb is not None and self.time_embedding_norm == "default":
-            hidden_states = hidden_states + temb
-
-        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
-            hidden_states = self.norm2(hidden_states, temb)
-        else:
+        if self.time_embedding_norm == "default":
+            if temb is not None:
+                hidden_states = hidden_states + temb
             hidden_states = self.norm2(hidden_states)
-
-        if temb is not None and self.time_embedding_norm == "scale_shift":
+        elif self.time_embedding_norm == "scale_shift":
+            if temb is None:
+                raise ValueError(
+                    f"`temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
+                )
             scale, shift = torch.chunk(temb, 2, dim=1)
+            hidden_states = self.norm2(hidden_states)
             hidden_states = hidden_states * (1 + scale) + shift
+        else:
+            hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
 
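Note the reordering in the `"scale_shift"` branch: normalization now happens before the modulation, mirroring the `"default"` branch where `temb` is added before `norm2`. A dependency-free sketch of the new `scale_shift` math (toy shapes):

```python
import torch
import torch.nn as nn

B, C = 2, 8
hidden = torch.randn(B, C, 4, 4)
temb = torch.randn(B, 2 * C, 1, 1)          # shape after time_emb_proj(temb)[:, :, None, None]

scale, shift = torch.chunk(temb, 2, dim=1)  # split the channel dim into scale/shift halves
norm2 = nn.GroupNorm(num_groups=2, num_channels=C)
hidden = norm2(hidden)                      # normalize first (the new ordering)
hidden = hidden * (1 + scale) + shift       # FiLM-style modulation by the time embedding
```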