@@ -23,9 +23,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import (
-    AttentionProcessor,
-)
+from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
 from .autoencoders import AutoencoderKL
 from .lora import LoRACompatibleConv
 from .modeling_utils import ModelMixin
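The import change above is what the rest of the diff keys off: with the PEFT backend active, diffusers builds plain torch.nn.Conv2d layers that have no lora_layer attribute, so the surgery helpers below must both skip that attribute and construct plain convs. A minimal sketch of the selection pattern, assuming USE_PEFT_BACKEND is the boolean flag re-exported by diffusers.models.attention_processor as in the hunk above:

import torch.nn as nn
from diffusers.models.attention_processor import USE_PEFT_BACKEND
from diffusers.models.lora import LoRACompatibleConv

# LoRACompatibleConv subclasses nn.Conv2d and adds a `lora_layer` slot;
# plain nn.Conv2d has no such attribute, hence the two code paths below.
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv = conv_cls(in_channels=4, out_channels=8, kernel_size=3)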
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
     norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
     norm_kwargs["num_channels"] += by  # surgery done here
     # conv1
-    conv1_args = (
-        "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
-    )
+    conv1_args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        conv1_args.append("lora_layer")
+
     for a in conv1_args:
         assert hasattr(old_conv1, a)
+
     conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
     conv1_kwargs["bias"] = "bias" in conv1_kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     conv1_kwargs["in_channels"] += by  # surgery done here
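Why the boolean conversion in the last context line: on an instantiated conv, .bias is a Parameter (or None), while the constructor expects a bool. A standalone illustration of the same attribute-harvesting trick, using a slightly more defensive `is not None` conversion and made-up channel counts:

import torch.nn as nn

old = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, bias=True)
attrs = ["in_channels", "out_channels", "kernel_size", "stride",
         "padding", "dilation", "groups", "bias", "padding_mode"]
kwargs = {a: getattr(old, a) for a in attrs}
kwargs["bias"] = kwargs["bias"] is not None  # attr is a tensor/None; the constructor wants a bool
kwargs["in_channels"] += 32                  # widen the input side, as the surgery does with `by`
new = nn.Conv2d(**kwargs)
assert new.in_channels == 36 and new.out_channels == 8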
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
     }
     # swap old with new modules
     unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
-    unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
-    unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
+        nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
+    )
+    unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
+        nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    )
     unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by  # surgery done here


 def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
     """Increase channels sizes to allow for additional concatted information from base model"""
     old_down = unet.down_blocks[block_no].downsamplers[0].conv
-    # conv1
-    args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
-        " "
-    )
+
+    args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        args.append("lora_layer")
+
     for a in args:
         assert hasattr(old_down, a)
     kwargs = {a: getattr(old_down, a) for a in args}
     kwargs["bias"] = "bias" in kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     kwargs["in_channels"] += by  # surgery done here
     # swap old with new modules
-    unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
+    unet.down_blocks[block_no].downsamplers[0].conv = (
+        nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
+    )
     unet.down_blocks[block_no].downsamplers[0].channels += by  # surgery done here

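A hedged usage sketch for the two encoder helpers above, assuming they are importable from this module and that the base model contributes 32 extra channels at each concatenation point (all numbers here are made up for illustration):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
increase_block_input_in_encoder_resnet(unet, block_no=0, resnet_idx=0, by=32)
increase_block_input_in_encoder_downsampler(unet, block_no=0, by=32)

# After surgery, the widened norm and conv agree on the new channel count:
resnet = unet.down_blocks[0].resnets[0]
assert resnet.conv1.in_channels == resnet.norm1.num_channels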
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
         assert hasattr(old_norm1, a)
     norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
     norm_kwargs["num_channels"] += by  # surgery done here
-    # conv1
-    conv1_args = (
-        "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
-    )
-    for a in conv1_args:
-        assert hasattr(old_conv1, a)
+    conv1_args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        conv1_args.append("lora_layer")
+
     conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
     conv1_kwargs["bias"] = "bias" in conv1_kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     conv1_kwargs["in_channels"] += by  # surgery done here
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
     }
     # swap old with new modules
     unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
-    unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
-    unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    unet.mid_block.resnets[0].conv1 = (
+        nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
+    )
+    unet.mid_block.resnets[0].conv_shortcut = (
+        nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    )
     unet.mid_block.resnets[0].in_channels += by  # surgery done here

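Continuing the hypothetical example from above, a quick sanity check that the mid-block surgery leaves a working resnet, assuming an SD-1.5-shaped UNet (mid-block width 1280, time-embedding dim 1280) widened by by=32; shapes here are illustrative:

import torch

increase_block_input_in_mid_resnet(unet, by=32)
resnet = unet.mid_block.resnets[0]

h = torch.randn(1, 1280 + 32, 8, 8)  # widened input: 1280 original channels + 32 concatted
temb = torch.randn(1, 1280)          # time embedding the resnet expects
out = resnet(h, temb)
assert out.shape[1] == 1280          # output width is untouched; only the input side grew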