1  | 1  | import logging
2  | 2  | import math
3  |    | -from inspect import isfunction
4  | 3  | from typing import Any, Optional
5  | 4  | 
6  | 5  | import torch
7  | 6  | import torch.nn.functional as F
8  | 7  | from einops import rearrange, repeat
9  | 8  | from packaging import version
10 | 9  | from torch import nn
   | 10 | +from ..util import exists, default
   | 11 | +
11 | 12 | 
12 | 13 | 
13 | 14 | logger = logging.getLogger(__name__)
58 | 59 | from .diffusionmodules.util import checkpoint
59 | 60 | 
60 | 61 | 
61 |    | -def exists(val):
62 |    | -    return val is not None
63 |    | -
64 |    | -
65 |    | -def uniq(arr):
   | 62 | +def uniq(arr): # TODO: this seems unused
66 | 63 |     return {el: True for el in arr}.keys()
67 | 64 | 
68 | 65 | 
69 |    | -def default(val, d):
70 |    | -    if exists(val):
71 |    | -        return val
72 |    | -    return d() if isfunction(d) else d
73 |    | -
74 |    | -
75 |    | -def max_neg_value(t):
76 |    | -    return -torch.finfo(t.dtype).max
77 |    | -
78 |    | -
79 |    | -def init_(tensor):
   | 66 | +def init_(tensor): # TODO: this seems unused
80 | 67 |     dim = tensor.shape[-1]
81 | 68 |     std = 1 / math.sqrt(dim)
82 | 69 |     tensor.uniform_(-std, std)
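For reference, the `exists` and `default` helpers deleted in this hunk are re-imported from `..util` in the hunk above, so their behaviour should be unchanged; `max_neg_value` gets no replacement import and was only used by the masking path removed further down. A minimal sketch of what `..util` presumably provides, assuming the definitions were moved over verbatim:

```python
# Hypothetical sketch of the relocated helpers in ..util, assuming they were
# moved unchanged from attention.py.
from inspect import isfunction


def exists(val):
    # True if the value is set (not None).
    return val is not None


def default(val, d):
    # Return val if set, otherwise the fallback d; a callable fallback is
    # evaluated lazily, which is why isfunction moves along with it.
    if exists(val):
        return val
    return d() if isfunction(d) else d
```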
@@ -256,23 +243,6 @@ def forward(
256 | 243 | 
257 | 244 |         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
258 | 245 | 
259 |     | -        ## old
260 |     | -        """
261 |     | -        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
262 |     | -        del q, k
263 |     | -
264 |     | -        if exists(mask):
265 |     | -            mask = rearrange(mask, 'b ... -> b (...)')
266 |     | -            max_neg_value = -torch.finfo(sim.dtype).max
267 |     | -            mask = repeat(mask, 'b j -> (b h) () j', h=h)
268 |     | -            sim.masked_fill_(~mask, max_neg_value)
269 |     | -
270 |     | -        # attention, what we cannot get enough of
271 |     | -        sim = sim.softmax(dim=-1)
272 |     | -
273 |     | -        out = einsum('b i j, b j d -> b i d', sim, v)
274 |     | -        """
275 |     | -        ## new
276 | 246 |         with sdp_kernel(**BACKEND_MAP[self.backend]):
277 | 247 |             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
278 | 248 |             out = F.scaled_dot_product_attention(
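The deleted commented-out block is the manual einsum attention that `F.scaled_dot_product_attention` (PyTorch 2.0+) replaces; the surrounding `sdp_kernel(**BACKEND_MAP[self.backend])` context only selects which fused kernel PyTorch may dispatch to. A small self-contained sketch of the equivalence on the unmasked path, with made-up shapes and the heads kept as a separate axis rather than folded into the batch as in the old code:

```python
# Minimal equivalence sketch (illustrative shapes, PyTorch >= 2.0 assumed).
import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 16, 32               # batch, heads, tokens, head dim
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
scale = d ** -0.5                        # corresponds to self.scale in the old code

# Old path: explicit similarity matrix, softmax, weighted sum over values.
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
attn = sim.softmax(dim=-1)
out_manual = torch.einsum("b h i j, b h j d -> b h i d", attn, v)

# New path: fused kernel; its default scaling is also 1/sqrt(head_dim).
out_fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_manual, out_fused, atol=1e-5)
```

For the masked case, passing a boolean `attn_mask=` to the fused call has the same effect as the old `masked_fill_(~mask, -torch.finfo(sim.dtype).max)`: positions that are `False` in the mask are excluded from the softmax.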