
Commit 1da2509

Revert "Dead code removal (Stability-AI#48)" (Stability-AI#62)
This reverts commit b5b5680.
1 parent a4ceca6 commit 1da2509

File tree

1 file changed: +34, -4 lines changed

sgm/modules/attention.py

Lines changed: 34 additions & 4 deletions
@@ -1,14 +1,13 @@
 import logging
 import math
+from inspect import isfunction
 from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
-from ..util import exists, default
-
 
 
 logger = logging.getLogger(__name__)
@@ -59,11 +58,25 @@
 from .diffusionmodules.util import checkpoint
 
 
-def uniq(arr): # TODO: this seems unused
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
     return {el: True for el in arr}.keys()
 
 
-def init_(tensor): # TODO: this seems unused
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
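As an aside (not part of the commit), here is a minimal sketch of how the helpers restored by the hunk above behave, assuming the sgm package from this repository is importable:

import torch

# Assumes sgm is installed from this repo; these are the helpers
# re-added by the revert above.
from sgm.modules.attention import default, exists, max_neg_value

x = None
y = torch.randn(2, 3)

assert not exists(x) and exists(y)           # exists: simple None check
assert default(x, 5) == 5                    # falls back when val is None
assert default(x, lambda: 7) == 7            # callables are evaluated lazily
assert torch.equal(default(y, 5), y)         # non-None values pass through
assert max_neg_value(y) == -torch.finfo(y.dtype).max  # fill value for masked attention logits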
@@ -243,6 +256,23 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
+        ## old
+        """
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        """
+        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
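For reference (not part of the commit), a rough equivalence check between the commented-out einsum path and the fused kernel it was replaced by; this sketch assumes the usual 1/sqrt(head_dim) scaling that self.scale uses for standard attention and ignores the optional mask:

import torch
import torch.nn.functional as F
from torch import einsum

# Illustrative shapes in the "(b h) n d" layout the old code used.
b, h, n, d = 2, 4, 16, 64
q, k, v = (torch.randn(b * h, n, d) for _ in range(3))
scale = d ** -0.5  # matches self.scale for standard attention

# old path: explicit similarity matrix, softmax, weighted sum
sim = einsum("b i d, b j d -> b i j", q, k) * scale
attn = sim.softmax(dim=-1)
out_old = einsum("b i j, b j d -> b i d", attn, v)

# new path: fused kernel; its default scaling is also 1/sqrt(d)
out_new = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_old, out_new, atol=1e-5)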
