
Commit b5b5680

Dead code removal (Stability-AI#48)
* Remove old commented-out attention code
* Mark two functions as likely unused
* Use exists() and default() from sgm.util
1 parent 6f6d3f8 commit b5b5680

File tree: 1 file changed (+4, -34 lines)


sgm/modules/attention.py

Lines changed: 4 additions & 34 deletions
@@ -1,13 +1,14 @@
 import logging
 import math
-from inspect import isfunction
 from typing import Any, Optional

 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
+from ..util import exists, default
+


 logger = logging.getLogger(__name__)
@@ -58,25 +59,11 @@
 from .diffusionmodules.util import checkpoint


-def exists(val):
-    return val is not None
-
-
-def uniq(arr):
+def uniq(arr):  # TODO: this seems unused
     return {el: True for el in arr}.keys()


-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
-    return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
+def init_(tensor):  # TODO: this seems unused
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
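
For context, the removed exists() and default() are small null-coalescing helpers; per the commit message, the module now imports equivalents via the "from ..util import exists, default" line added in the first hunk. Below is a minimal sketch of the pattern, with the definitions copied from the removed lines and a hypothetical usage (the variable names in the usage are illustrative, not taken from the module):

from inspect import isfunction


def exists(val):
    # True for any value that is not None
    return val is not None


def default(val, d):
    # Return val if it is set, otherwise fall back to d (calling d if it is a function).
    if exists(val):
        return val
    return d() if isfunction(d) else d


# Hypothetical usage: fall back to query_dim when no context dimension is given.
query_dim, context_dim = 320, None
context_dim = default(context_dim, query_dim)  # -> 320
assert exists(context_dim)
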
@@ -256,23 +243,6 @@ def forward(

         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

-        ## old
-        """
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', sim, v)
-        """
-        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
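
The deleted block is the pre-SDPA path: an explicit query-key similarity scaled by self.scale, optional masking with the dtype's most negative value, a softmax, and a final einsum against v. What remains dispatches to torch.nn.functional.scaled_dot_product_attention inside the module's sdp_kernel / BACKEND_MAP context. A standalone sketch of why the commented-out block was redundant, assuming no mask and the usual 1/sqrt(d) scale (shapes, names, and the tolerance below are illustrative, not taken from the module):

import torch
import torch.nn.functional as F
from torch import einsum

bh, n, d = 16, 32, 64  # (batch*heads, tokens, head dim); illustrative sizes
q, k, v = (torch.randn(bh, n, d) for _ in range(3))
scale = d ** -0.5  # plays the role of the removed "* self.scale"

# Old path, mirroring the deleted commented-out code (mask branch omitted).
sim = einsum("b i d, b j d -> b i j", q, k) * scale
out_old = einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)

# New path: fused kernel with the same default 1/sqrt(d) scaling.
out_new = F.scaled_dot_product_attention(q, k, v)

# Loose tolerance to allow for backend-dependent numerics.
assert torch.allclose(out_old, out_new, atol=1e-4)
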

0 commit comments
