Commit c1efda7

[Clean up] Clean unused code (huggingface#245)
* CleanResNet
* refactor more
* correct
1 parent 4789316 commit c1efda7

4 files changed: +39, -247 lines


src/diffusers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
         except EntryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
             )
         except HTTPError as err:
             raise EnvironmentError(
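For context, the fix above reports the module-level WEIGHTS_NAME constant instead of the local model_file variable when the weights file is missing. A minimal sketch of the pattern, with a hypothetical download_weights helper and a locally defined stand-in for the hub's EntryNotFoundError (both assumptions, not part of this diff):

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # assumed value of the diffusers constant


class EntryNotFoundError(Exception):
    """Local stand-in for the hub error raised when a requested file is missing."""


def download_weights(repo_id: str) -> str:
    # Hypothetical helper standing in for the real hub download call.
    raise EntryNotFoundError(repo_id)


def load_weights(pretrained_model_name_or_path: str) -> str:
    try:
        return download_weights(pretrained_model_name_or_path)
    except EntryNotFoundError:
        # Report the constant the loader actually asked for, not a local variable
        # whose value may not describe the request.
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
        )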

src/diffusers/models/attention.py

Lines changed: 6 additions & 214 deletions
@@ -1,12 +1,11 @@
 import math
-from inspect import isfunction

 import torch
 import torch.nn.functional as F
 from torch import nn


-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
@@ -82,55 +81,6 @@ def forward(self, hidden_states):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states

-    def set_weight(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        if hasattr(attn_layer, "q"):
-            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-
-            self.query.bias.data = attn_layer.q.bias.data
-            self.key.bias.data = attn_layer.k.bias.data
-            self.value.bias.data = attn_layer.v.bias.data
-
-            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-        elif hasattr(attn_layer, "NIN_0"):
-            self.query.weight.data = attn_layer.NIN_0.W.data.T
-            self.key.weight.data = attn_layer.NIN_1.W.data.T
-            self.value.weight.data = attn_layer.NIN_2.W.data.T
-
-            self.query.bias.data = attn_layer.NIN_0.b.data
-            self.key.bias.data = attn_layer.NIN_1.b.data
-            self.value.bias.data = attn_layer.NIN_2.b.data
-
-            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-            self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-
-            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-        else:
-            qkv_weight = attn_layer.qkv.weight.data.reshape(
-                self.num_heads, 3 * self.channels // self.num_heads, self.channels
-            )
-            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-            self.query.weight.data = q_w.reshape(-1, self.channels)
-            self.key.weight.data = k_w.reshape(-1, self.channels)
-            self.value.weight.data = v_w.reshape(-1, self.channels)
-
-            self.query.bias.data = q_b.reshape(-1)
-            self.key.bias.data = k_b.reshape(-1)
-            self.value.bias.data = v_b.reshape(-1)
-
-            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-            self.proj_attn.bias.data = attn_layer.proj.bias.data
-

 class SpatialTransformer(nn.Module):
     """
@@ -170,12 +120,6 @@ def forward(self, x, context=None):
         x = self.proj_out(x)
         return x + x_in

-    def set_weight(self, layer):
-        self.norm = layer.norm
-        self.proj_in = layer.proj_in
-        self.transformer_blocks = layer.transformer_blocks
-        self.proj_out = layer.proj_out
-

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
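The two set_weight methods removed above existed only to copy parameters from legacy checkpoint layouts (1x1 convolutions, NIN layers, fused qkv) into the refactored modules. A minimal standalone sketch of the core conversion the first one performed, with hypothetical layer sizes that are not taken from this diff:

import torch
from torch import nn

# Hypothetical sizes: a 1x1 Conv2d projection and the Linear layer replacing it.
conv = nn.Conv2d(64, 64, kernel_size=1)
linear = nn.Linear(64, 64)

with torch.no_grad():
    # A 1x1 conv weight has shape (out, in, 1, 1); dropping the trailing spatial
    # dims yields the (out, in) matrix a Linear expects -- the same slicing the
    # removed method applied via attn_layer.q.weight.data[:, :, 0, 0].
    linear.weight.copy_(conv.weight[:, :, 0, 0])
    linear.bias.copy_(conv.bias)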
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = context_dim if context_dim is not None else query_dim

         self.scale = dim_head**-0.5
         self.heads = heads
@@ -234,7 +178,7 @@ def forward(self, x, context=None, mask=None):
         h = self.heads

         q = self.to_q(x)
-        context = default(context, x)
+        context = context if context is not None else x
         k = self.to_k(context)
         v = self.to_v(context)
@@ -244,7 +188,7 @@ def forward(self, x, context=None, mask=None):
 
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

-        if exists(mask):
+        if mask is not None:
             mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = mask[:, None, :].repeat(h, 1, 1)
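The three hunks above inline the small default()/exists() helpers that are deleted further down in this file; the replacement is the plain conditional-expression pattern shown in this standalone sketch (the helper name here is illustrative only):

# Before (removed by this commit):
#     context_dim = default(context_dim, query_dim)
#     if exists(mask): ...
# After: the same behaviour written out inline.


def resolve_context_dim(query_dim, context_dim=None):
    # The real code inlines this expression directly at each call site.
    return context_dim if context_dim is not None else query_dim


assert resolve_context_dim(320) == 320
assert resolve_context_dim(320, 768) == 768

Note that the removed default() also accepted a callable fallback (d() if isfunction(d) else d); the call sites shown here pass plain values, which is why the from inspect import isfunction import could be dropped in the first hunk as well.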
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        dim_out = dim_out if dim_out is not None else dim
+        project_in = GEGLU(dim, inner_dim)

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
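With the hunk above, FeedForward always builds its input projection as GEGLU and the glu flag in the signature is effectively ignored; the plain Linear + GELU path is gone. A minimal standalone sketch of the resulting module, assuming GEGLU's projection doubles the width so its chunk(2, dim=-1) split works (that __init__ is not shown in this diff):

import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Assumed layout: one projection producing both the value and the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        self.net = nn.Sequential(GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


out = FeedForward(dim=320)(torch.randn(2, 77, 320))  # -> shape (2, 77, 320)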
@@ -280,155 +224,3 @@ def __init__(self, dim_in, dim_out):
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
-
-
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=None,
-        num_groups=32,
-        encoder_channels=None,
-        overwrite_qkv=False,
-        overwrite_linear=False,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = nn.Conv1d(channels, channels, 1)
-
-        self.overwrite_qkv = overwrite_qkv
-        self.overwrite_linear = overwrite_linear
-
-        if overwrite_qkv:
-            in_channels = channels
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.overwrite_linear:
-            num_groups = min(channels // 4, 32)
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.NIN_0 = NIN(channels, channels)
-            self.NIN_1 = NIN(channels, channels)
-            self.NIN_2 = NIN(channels, channels)
-            self.NIN_3 = NIN(channels, channels)
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-        else:
-            self.proj_out = nn.Conv1d(channels, channels, 1)
-            self.set_weights(self)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-                :, :, :, 0
-            ]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = nn.Conv1d(self.channels, self.channels, 1)
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-        elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat(
-                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-            )[:, :, None]
-            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-
-            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj.bias.data = self.NIN_3.b.data
-
-            self.norm.weight.data = self.GroupNorm_0.weight.data
-            self.norm.bias.data = self.GroupNorm_0.bias.data
-        else:
-            self.proj.weight.data = self.proj_out.weight.data
-            self.proj.bias.data = self.proj_out.bias.data
-
-    def forward(self, x, encoder_out=None):
-        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-
-        result = result / self.rescale_output_factor
-
-        return result
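One detail worth noting in the removed forward above: the 1/sqrt(d) attention scale is split as 1/sqrt(sqrt(d)) and applied to both q and k before the einsum, which (per the original comment) is more stable in fp16 than dividing the product afterwards. A standalone sketch of that computation with hypothetical shapes:

import math

import torch

bs, ch, length = 2, 64, 16  # hypothetical (batch * heads, head_dim, tokens)
q = torch.randn(bs, ch, length)
k = torch.randn(bs, ch, length)
v = torch.randn(bs, ch, length)

# Scaling each operand by sqrt of the factor equals scaling the product by 1/sqrt(ch),
# but keeps the pre-softmax logits in a safer half-precision range.
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)  # -> shape (2, 64, 16)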

src/diffusers/models/resnet.py

Lines changed: 1 addition & 1 deletion
@@ -248,7 +248,7 @@ def forward(self, x):
         return x


-class ResnetBlock(nn.Module):
+class ResnetBlock2D(nn.Module):
     def __init__(
         self,
         *,
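The rename is purely mechanical; downstream code only needs its references updated. A hypothetical usage sketch (the constructor keywords shown are assumptions, not taken from this diff):

# previously: from diffusers.models.resnet import ResnetBlock
from diffusers.models.resnet import ResnetBlock2D

# Keyword-only construction matches the `*,` in the signature above; the exact
# parameters used here (in_channels, out_channels, temb_channels) are assumed.
block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)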
