 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import USE_PEFT_BACKEND
+from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
 from .attention_processor import Attention
 from .embeddings import SinusoidalPositionalEmbedding
-from .lora import LoRACompatibleLinear
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
 
 
-def _chunked_feed_forward(
-    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
             f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         )
 
     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    if lora_scale is None:
-        ff_output = torch.cat(
-            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-    else:
-        # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
-        ff_output = torch.cat(
-            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
     return ff_output
 
 
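For reference, a minimal, self-contained sketch of what the simplified helper now does (slice along the chunked dimension, run the feed-forward on each slice, concatenate). The module, shapes, and `chunk_size` below are illustrative, not taken from the diff:

```python
import torch
from torch import nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # Same idea as the helper above: trade a little speed for lower peak memory
    # by running the feed-forward on slices of one dimension (usually the sequence dim).
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(f"{hidden_states.shape[chunk_dim]} is not divisible by chunk size {chunk_size}.")
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )


ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 128, 64)  # (batch, seq_len, dim)

# Position-wise modules produce (numerically) the same result with and without chunking.
chunked = chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=32)
print(torch.allclose(chunked, ff(x), atol=1e-5))  # True
```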
@@ -299,6 +291,10 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]
@@ -326,10 +322,7 @@ def forward(
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)
 
-        # 1. Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
-        # 2. Prepare GLIGEN inputs
+        # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
@@ -348,7 +341,7 @@ def forward(
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
 
-        # 2.5 GLIGEN Control
+        # 1.2 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
 
@@ -394,11 +387,9 @@ def forward(
 
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(
-                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
-            )
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+            ff_output = self.ff(norm_hidden_states)
 
         if self.norm_type == "ada_norm_zero":
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -643,7 +634,7 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,10 @@ def __init__(
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states
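After this change the old `scale` keyword is still accepted but ignored, with a deprecation warning instead of LoRA scaling. A small usage sketch (the import path assumes this file is `diffusers/models/attention.py`, as the relative imports suggest):

```python
import torch

from diffusers.models.attention import FeedForward

ff = FeedForward(dim=64, activation_fn="geglu")
x = torch.randn(2, 16, 64)

out = ff(x)             # unchanged behaviour
out = ff(x, scale=0.5)  # still runs, but emits a deprecation warning and ignores `scale`
```

This mirrors the `cross_attention_kwargs` warning added to the transformer block's `forward` above: the `scale` value is no longer consumed inside these modules.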