
Commit 531e719

sayakpaul and yiyixuxu authored
[LoRA] use the PyTorch classes wherever needed and start deprecation cycles (huggingface#7204)

* fix PyTorch classes and start deprecation cycles.
* remove args crafting for accommodating scale.
* remove scale check in feedforward.
* assert against nn.Linear and not CompatibleLinear.
* remove conv_cls and linear_cls.
* remove scale
* 👋 scale.
* fix: unet2dcondition
* fix attention.py
* fix: attention.py again
* fix: unet_2d_blocks.
* fix-copies.
* more fixes.
* fix: resnet.py
* more fixes
* fix i2vgenxl unet.
* deprecate scale gently.
* fix-copies
* Apply suggestions from code review (Co-authored-by: YiYi Xu <[email protected]>)
* quality
* throw warning when scale is passed to the BasicTransformerBlock class.
* remove scale from signature.
* cross_attention_kwargs, very nice catch by Yiyi
* fix: logger.warn
* make deprecation message clearer.
* address final comments.
* maintain same deprecation message and also add it to activations.
* address yiyi
* fix copies
* Apply suggestions from code review (Co-authored-by: YiYi Xu <[email protected]>)
* more deprecation
* fix-copies

Co-authored-by: YiYi Xu <[email protected]>
1 parent 4fbd310 commit 531e719
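The deprecation message introduced in this commit points callers at `cross_attention_kwargs`. For context, here is a minimal sketch of the intended call pattern at the pipeline level; the LoRA repository name below is a placeholder, the rest is standard diffusers API:

# Illustrative sketch (not part of this commit): the LoRA scale is no longer
# threaded through module-level `scale` arguments; it is passed once at the
# pipeline call site via `cross_attention_kwargs`.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repo id

image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},  # LoRA scale, handled by the PEFT backend
).images[0]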

File tree

17 files changed: +403 −351 lines changed


src/diffusers/models/activations.py

Lines changed: 8 additions & 8 deletions

@@ -17,8 +17,7 @@
 import torch.nn.functional as F
 from torch import nn

-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleLinear
+from ..utils import deprecate


 ACTIVATION_FUNCTIONS = {

@@ -87,19 +86,20 @@ class GEGLU(nn.Module):

     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
-        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate)
         # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

-    def forward(self, hidden_states, scale: float = 1.0):
-        args = () if USE_PEFT_BACKEND else (scale,)
-        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+    def forward(self, hidden_states, *args, **kwargs):
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
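The new `GEGLU.forward` accepts and discards the legacy argument instead of changing behavior. Below is a standalone sketch of the same shim pattern, approximating `diffusers.utils.deprecate` with the standard `warnings` module; the class name and tensor shapes are illustrative:

# Standalone sketch of the deprecation shim used above.
import warnings

import torch
import torch.nn.functional as F
from torch import nn


class GEGLUSketch(nn.Module):
    """Minimal GEGLU that ignores a legacy `scale` argument with a warning."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def forward(self, hidden_states, *args, **kwargs):
        # Catch positional or keyword `scale` passed by old callers and warn once.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            warnings.warn(
                "The `scale` argument is deprecated and will be ignored.",
                FutureWarning,
            )
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)


x = torch.randn(2, 8, 16)
out = GEGLUSketch(16, 32)(x, scale=1.0)  # warns, then runs the unscaled path
print(out.shape)  # torch.Size([2, 8, 32])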

src/diffusers/models/attention.py

Lines changed: 23 additions & 33 deletions

@@ -17,37 +17,29 @@
 import torch.nn.functional as F
 from torch import nn

-from ..utils import USE_PEFT_BACKEND
+from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
 from .attention_processor import Attention
 from .embeddings import SinusoidalPositionalEmbedding
-from .lora import LoRACompatibleLinear
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm


-def _chunked_feed_forward(
-    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
             f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         )

     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    if lora_scale is None:
-        ff_output = torch.cat(
-            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-    else:
-        # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
-        ff_output = torch.cat(
-            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
     return ff_output

@@ -299,6 +291,10 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]

@@ -326,10 +322,7 @@ def forward(
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)

-        # 1. Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
-        # 2. Prepare GLIGEN inputs
+        # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

@@ -348,7 +341,7 @@ def forward(
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)

-        # 2.5 GLIGEN Control
+        # 1.2 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

@@ -394,11 +387,9 @@ def forward(

         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(
-                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
-            )
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+            ff_output = self.ff(norm_hidden_states)

         if self.norm_type == "ada_norm_zero":
             ff_output = gate_mlp.unsqueeze(1) * ff_output

@@ -643,7 +634,7 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        linear_cls = nn.Linear

         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)

@@ -665,11 +656,10 @@ def __init__(
         if final_dropout:
             self.net.append(nn.Dropout(dropout))

-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states
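With these changes, `FeedForward` (and the blocks built on it) are called without a `scale` argument; a stray `scale` only triggers the deprecation and is ignored. A quick sketch of the resulting call pattern follows, assuming the `FeedForward` constructor shown in the diff above; the dimensions are chosen arbitrarily:

# Sketch of the post-change call pattern; `FeedForward` lives in
# src/diffusers/models/attention.py as shown in the diff above.
import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=64, activation_fn="geglu")
hidden_states = torch.randn(1, 16, 64)

out = ff(hidden_states)                 # new-style call: no `scale`
legacy = ff(hidden_states, scale=1.0)   # emits the deprecation, then ignores `scale`
print(out.shape, legacy.shape)          # torch.Size([1, 16, 64]) torch.Size([1, 16, 64])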
