@@ -27,7 +27,7 @@
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d


 def _setup_kernel(k):
@@ -184,37 +184,39 @@ def forward(self, x, y):


 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)

         return h


 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)

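The rename from in_ch/out_ch/with_conv to channels/out_channels/use_conv aligns these classes with the plain Upsample/Downsample modules, so either variant can be built through the same call sites below. The added bias line matters because _upsample_conv_2d and _conv_downsample_2d consume only Conv2d_0.weight; without it the layer's bias was silently dropped. A minimal usage sketch of the renamed API (shapes and values are illustrative, not from this commit):

    import torch

    # FIR upsampling with a learned 3x3 conv: spatial dims double, and the
    # conv bias is now applied explicitly after the FIR resampling.
    up = FirUpsample(channels=64, out_channels=64, use_conv=True, fir_kernel=(1, 3, 3, 1))
    x = torch.randn(1, 64, 32, 32)
    print(up(x).shape)    # expected: torch.Size([1, 64, 64, 64])

    # Pure FIR downsampling; no conv parameters are involved in this branch.
    down = FirDownsample(channels=64, use_conv=False)
    print(down(x).shape)  # expected: torch.Size([1, 64, 16, 16])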
@@ -228,13 +230,14 @@ def __init__(
         self,
         image_size=1024,
         num_channels=3,
+        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -252,12 +255,14 @@ def __init__(
         self.register_to_config(
             image_size=image_size,
             num_channels=num_channels,
+            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
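Registering centered and fir here is what makes them persist in the model's saved config and restores them at load time; it is also what lets forward() read self.config.centered below. A hedged sketch of the round trip (the class name NCSNpp is an assumption, since this diff does not show it):

    # Hypothetical instantiation; every kwarg passed to register_to_config
    # becomes an attribute of `model.config`.
    model = NCSNpp(centered=True, fir=False)
    assert model.config.centered is True
    assert model.config.fir is False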
@@ -307,24 +312,32 @@ def __init__(
         modules.append(Linear(nf * 4, nf * 4))

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
+            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
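With fir gating the choice of resampler, Up_sample and Down_sample become interchangeable factories, and the progressive "residual" branches stack a second functools.partial on top of the first. Keyword arguments supplied at call time override stored ones, which is what lets use_conv=True win over the use_conv=resamp_with_conv baked into Up_sample. A standalone illustration of that override semantics (the build function and values are made up):

    import functools

    def build(channels=None, out_channels=None, use_conv=False):
        return (channels, out_channels, use_conv)

    Up_sample = functools.partial(build, use_conv=False)
    pyramid_upsample = functools.partial(Up_sample, use_conv=True)  # overrides use_conv

    # Equivalent to build(channels=128, out_channels=256, use_conv=True)
    print(pyramid_upsample(channels=128, out_channels=256))  # (128, 256, True)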
@@ -361,7 +374,7 @@ def __init__(
                         in_ch *= 2

                 elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                     input_pyramid_ch = in_ch

                 hs_c.append(in_ch)
@@ -402,7 +415,7 @@ def __init__(
                     )
                     pyramid_ch = channels
                 elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name")
@@ -446,7 +459,8 @@ def forward(self, x, timesteps, sigmas=None):
             temb = None

         # If input data is in [0, 1]
-        x = 2 * x - 1.0
+        if not self.config.centered:
+            x = 2 * x - 1.0

         # Downsampling block
         input_pyramid = None
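Previously every input was assumed to lie in [0, 1] and was unconditionally shifted to [-1, 1]; the new centered flag lets already-centered data pass through untouched. A quick sketch of the two paths (tensor values are illustrative):

    import torch

    x = torch.tensor([0.0, 0.5, 1.0])  # data in [0, 1]
    print(2 * x - 1.0)                 # tensor([-1., 0., 1.]); applied when centered=False
    # With centered=True, x is assumed to be in [-1, 1] already and is used as-is.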