
Commit c691bb2

Merge pull request huggingface#60 from huggingface/add-fir-back
fix unet sde for vp model.
2 parents abedfb0 + 4c293e0 commit c691bb2

2 files changed (+75, -33 lines)


src/diffusers/models/resnet.py

Lines changed: 36 additions & 8 deletions
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from functools import partial

 import numpy as np
 import torch
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -102,7 +110,10 @@ def forward(self, x):
         x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x

@@ -134,6 +145,8 @@ def __init__(self, channels, use_conv=False, dims=2, out_channels=None, padding=

         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv

@@ -145,6 +158,8 @@ def forward(self, x):

         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)

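The `name` argument threaded through `Upsample` and `Downsample` registers the same convolution under a different attribute, so checkpoints whose weights were saved under `Conv2d_0.*` keys load without remapping. A minimal sketch of the effect on the state dict, assuming this revision of `diffusers.models.resnet` is importable (the channel count is arbitrary):

from diffusers.models.resnet import Upsample

up_default = Upsample(channels=64, use_conv=True)                 # stores conv as self.conv
up_ckpt = Upsample(channels=64, use_conv=True, name="Conv2d_0")   # stores conv as self.Conv2d_0

print(sorted(up_default.state_dict()))  # ['conv.bias', 'conv.weight']
print(sorted(up_ckpt.state_dict()))     # ['Conv2d_0.bias', 'Conv2d_0.weight']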

@@ -390,6 +405,7 @@ def __init__(
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -400,8 +416,20 @@ def __init__(
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel

+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -424,11 +452,11 @@ def forward(self, x, temb=None):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
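With `fir` now a constructor flag, `ResnetBlockBigGANpp.forward` no longer calls `upsample_2d`/`downsample_2d` directly; the resampler is chosen once in `__init__` and pre-bound with `functools.partial`. A self-contained sketch of the non-FIR branch (shapes are illustrative):

from functools import partial

import torch
import torch.nn.functional as F

# fir=False branch: plain nearest-neighbour upsampling, keywords bound once up front
upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")

x = torch.randn(1, 4, 8, 8)
print(upsample(x).shape)  # torch.Size([1, 4, 16, 16])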

src/diffusers/models/unet_sde_score_estimation.py

Lines changed: 39 additions & 25 deletions
@@ -27,7 +27,7 @@
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d


 def _setup_kernel(k):
@@ -184,37 +184,39 @@ def forward(self, x, y):


 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)

         return h


 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
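The two added `bias.reshape(1, -1, 1, 1)` lines are the substance of the fix: `_upsample_conv_2d` and `_conv_downsample_2d` are called with only `Conv2d_0.weight`, so the convolution bias was previously dropped. Reshaping the bias to `(1, C, 1, 1)` broadcasts one value per channel over the batch and spatial dimensions, as in this sketch with illustrative shapes:

import torch

h = torch.randn(2, 16, 32, 32)     # fused conv + FIR output, bias not yet applied
bias = torch.randn(16)             # stands in for Conv2d_0.bias, one value per channel
h = h + bias.reshape(1, -1, 1, 1)  # broadcast over batch and spatial dims
print(h.shape)                     # torch.Size([2, 16, 32, 32])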

@@ -228,13 +230,14 @@ def __init__(
         self,
         image_size=1024,
         num_channels=3,
+        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -252,12 +255,14 @@ def __init__(
         self.register_to_config(
             image_size=image_size,
             num_channels=num_channels,
+            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
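Passing `centered` and `fir` through `register_to_config` is what makes them available as `self.config.centered` in `forward` below and keeps them in the serialized model config. A simplified stand-in for the mechanism (not diffusers' actual `ConfigMixin` implementation):

class ConfigStub:
    def register_to_config(self, **kwargs):
        # stash constructor arguments on a namespace, akin to self.config
        self.config = type("Config", (), kwargs)()

model = ConfigStub()
model.register_to_config(centered=False, fir=True)
print(model.config.centered)  # False -> forward() will rescale [0, 1] inputs
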
@@ -307,24 +312,32 @@ def __init__(
             modules.append(Linear(nf * 4, nf * 4))

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
+            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
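Note that `pyramid_upsample` and `pyramid_downsample` are partials layered on top of partials: the outer `functools.partial` only adds `use_conv=True` to the keywords already bound in `Up_sample`/`Down_sample`. A toy sketch of that layering, using a hypothetical `make` function purely to show the keyword merging:

import functools

def make(channels=None, use_conv=False, fir_kernel=None):
    return channels, use_conv, fir_kernel

Up_sample = functools.partial(make, fir_kernel=(1, 3, 3, 1))
pyramid_upsample = functools.partial(Up_sample, use_conv=True)

print(pyramid_upsample(channels=64))  # (64, True, (1, 3, 3, 1))
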
@@ -361,7 +374,7 @@ def __init__(
                    in_ch *= 2

                elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                    input_pyramid_ch = in_ch

            hs_c.append(in_ch)
@@ -402,7 +415,7 @@ def __init__(
                    )
                    pyramid_ch = channels
                elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                    pyramid_ch = in_ch
                else:
                    raise ValueError(f"{progressive} is not a valid name")
@@ -446,7 +459,8 @@ def forward(self, x, timesteps, sigmas=None):
         temb = None

         # If input data is in [0, 1]
-        x = 2 * x - 1.0
+        if not self.config.centered:
+            x = 2 * x - 1.0

         # Downsampling block
         input_pyramid = None
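The `centered` guard makes the `[0, 1] -> [-1, 1]` rescale conditional: datasets that are already centered skip it, while uncentered pixel data is mapped exactly as before. A quick sketch of the guard with an illustrative tensor:

import torch

x = torch.rand(1, 3, 8, 8)  # pixel data in [0, 1]
centered = False            # config flag introduced in this commit
if not centered:
    x = 2 * x - 1.0         # map [0, 1] onto [-1, 1]
print(float(x.min()) >= -1.0, float(x.max()) <= 1.0)  # True True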
