@@ -1,14 +1,13 @@
 import logging
 import math
+from inspect import isfunction
 from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
-from ..util import exists, default
-
 
 
 logger = logging.getLogger(__name__)
@@ -59,11 +58,25 @@
 from .diffusionmodules.util import checkpoint
 
 
-def uniq(arr): # TODO: this seems unused
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
     return {el: True for el in arr}.keys()
 
 
-def init_(tensor): # TODO: this seems unused
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
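Note: the helpers this file previously imported from ..util (exists, default) are now defined locally, together with max_neg_value; the new inspect.isfunction import lets default() take either a plain fallback value or a lazily evaluated callable. A minimal usage sketch of the inlined helpers (illustrative only; it assumes the definitions above plus import torch, and the variable names are made up):

    # Illustrative only -- assumes the helper definitions above and "import torch"
    assert exists(0) and not exists(None)        # exists(): simple "is not None" check
    dim_head = default(None, 64)                 # default(): plain fallback value -> 64
    heads = default(None, lambda: 8)             # callables are evaluated lazily -> 8
    neg = max_neg_value(torch.zeros(1))          # most negative finite float32 value,
                                                 # used to mask out attention logits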
@@ -243,6 +256,23 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
+        ## old
+        """
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        """
+        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
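Note: in forward(), the hand-rolled einsum attention is kept only as the commented-out block above, and the computation is routed through torch.nn.functional.scaled_dot_product_attention inside the sdp_kernel(**BACKEND_MAP[...]) context. A hedged sanity-check sketch of the equivalence, assuming self.scale is the usual dim_head ** -0.5 (shapes and tolerance below are made up, not part of the commit):

    # Standalone equivalence check (illustrative shapes; not part of the commit)
    import torch
    import torch.nn.functional as F
    from torch import einsum

    b_h, n, d = 8, 16, 64                        # (batch * heads, tokens, dim_head)
    q, k, v = (torch.randn(b_h, n, d) for _ in range(3))

    # old path: explicit softmax(q @ k.T * d ** -0.5) @ v
    sim = einsum("b i d, b j d -> b i j", q, k) * d ** -0.5
    out_old = einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)

    # new path: fused kernel; its default scale is also d ** -0.5
    out_new = F.scaled_dot_product_attention(q, k, v)

    assert torch.allclose(out_old, out_new, atol=1e-5)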