Commit 32ff477

ControlNetXS fixes. (huggingface#6228)

update

Parent: 288ceeb

2 files changed: +67 -23 lines

src/diffusers/models/controlnetxs.py

Lines changed: 60 additions & 21 deletions
@@ -23,9 +23,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import (
-    AttentionProcessor,
-)
+from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
 from .autoencoders import AutoencoderKL
 from .lora import LoRACompatibleConv
 from .modeling_utils import ModelMixin
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
     norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
     norm_kwargs["num_channels"] += by  # surgery done here
     # conv1
-    conv1_args = (
-        "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
-    )
+    conv1_args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        conv1_args.append("lora_layer")
+
     for a in conv1_args:
         assert hasattr(old_conv1, a)
+
     conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
     conv1_kwargs["bias"] = "bias" in conv1_kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     conv1_kwargs["in_channels"] += by  # surgery done here
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
     }
     # swap old with new modules
     unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
-    unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
-    unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
+        nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
+    )
+    unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
+        nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    )
     unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by  # surgery done here
 
 
 def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
     """Increase channels sizes to allow for additional concatted information from base model"""
     old_down = unet.down_blocks[block_no].downsamplers[0].conv
-    # conv1
-    args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
-        " "
-    )
+
+    args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        args.append("lora_layer")
+
     for a in args:
         assert hasattr(old_down, a)
     kwargs = {a: getattr(old_down, a) for a in args}
     kwargs["bias"] = "bias" in kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     kwargs["in_channels"] += by  # surgery done here
     # swap old with new modules
-    unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
+    unet.down_blocks[block_no].downsamplers[0].conv = (
+        nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
+    )
     unet.down_blocks[block_no].downsamplers[0].channels += by  # surgery done here
 
 
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
         assert hasattr(old_norm1, a)
     norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
     norm_kwargs["num_channels"] += by  # surgery done here
-    # conv1
-    conv1_args = (
-        "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
-    )
-    for a in conv1_args:
-        assert hasattr(old_conv1, a)
+    conv1_args = [
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "bias",
+        "padding_mode",
+    ]
+    if not USE_PEFT_BACKEND:
+        conv1_args.append("lora_layer")
+
     conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
     conv1_kwargs["bias"] = "bias" in conv1_kwargs  # as param, bias is a boolean, but as attr, it's a tensor.
     conv1_kwargs["in_channels"] += by  # surgery done here
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
     }
     # swap old with new modules
     unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
-    unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
-    unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    unet.mid_block.resnets[0].conv1 = (
+        nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
+    )
+    unet.mid_block.resnets[0].conv_shortcut = (
+        nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
+    )
     unet.mid_block.resnets[0].in_channels += by  # surgery done here
 

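Note on the pattern above: every change in this file performs the same channel "surgery": read an existing conv layer's constructor arguments, grow in_channels by the number of extra concatenated channels, and rebuild the layer, using a plain nn.Conv2d when USE_PEFT_BACKEND is set (LoRA is then handled by PEFT adapters) and LoRACompatibleConv otherwise. Below is a minimal standalone sketch of that surgery, not diffusers code: the helper name widen_conv_in_channels is made up for illustration, and the backend flag is reduced to a local boolean.

    # Standalone sketch of the channel-widening surgery; not diffusers code.
    # USE_PEFT_BACKEND here stands in for the flag imported in the diff above.
    import torch
    import torch.nn as nn

    USE_PEFT_BACKEND = True

    def widen_conv_in_channels(old_conv: nn.Conv2d, by: int) -> nn.Conv2d:
        args = [
            "in_channels", "out_channels", "kernel_size", "stride",
            "padding", "dilation", "groups", "bias", "padding_mode",
        ]
        kwargs = {a: getattr(old_conv, a) for a in args}
        kwargs["bias"] = old_conv.bias is not None  # attr is a tensor, param is a bool
        kwargs["in_channels"] += by  # surgery done here
        # diffusers instantiates LoRACompatibleConv instead when USE_PEFT_BACKEND is False
        return nn.Conv2d(**kwargs)

    conv = widen_conv_in_channels(nn.Conv2d(4, 8, kernel_size=3, padding=1), by=2)
    assert conv.in_channels == 6
    out = conv(torch.randn(1, 6, 16, 16))  # the widened conv now accepts 6 channels

Like the functions in the commit, the sketch only fixes shapes: the rebuilt layer starts with freshly initialized weights, and loading the real weights is a separate step.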
tests/pipelines/controlnetxs/test_controlnetxs.py

Lines changed: 7 additions & 2 deletions
@@ -34,6 +34,7 @@
     enable_full_determinism,
     load_image,
     load_numpy,
+    numpy_cosine_similarity_distance,
     require_python39_or_higher,
     require_torch_2,
     require_torch_gpu,
@@ -273,7 +274,9 @@ def test_canny(self):
 
         original_image = image[-3:, -3:, -1].flatten()
         expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701])
-        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+        max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
+        assert max_diff < 1e-4
 
     def test_depth(self):
         controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth")
@@ -298,7 +301,9 @@ def test_depth(self):
 
         original_image = image[-3:, -3:, -1].flatten()
         expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703])
-        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+        max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
+        assert max_diff < 1e-4
 
     @require_python39_or_higher
     @require_torch_2

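On the test change: the strict elementwise np.allclose check is replaced with a cosine-similarity distance between the produced image slice and the reference, which tolerates small uniform drift across hardware while still flagging structural differences. A rough standalone reimplementation follows; the assumption (not verified against the diffusers test utilities here) is that numpy_cosine_similarity_distance computes one minus the cosine similarity of the flattened arrays.

    # Illustrative reimplementation of the distance used by the updated tests.
    import numpy as np

    def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
        a = a.flatten().astype(np.float64)
        b = b.flatten().astype(np.float64)
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return float(1.0 - cos)

    ref = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701])
    assert cosine_similarity_distance(ref, ref * 1.001) < 1e-4  # uniform scaling barely registers
    assert cosine_similarity_distance(ref, ref[::-1]) > 1e-4    # reordering the values does

Because cosine similarity is scale-invariant, a small global shift in output intensity passes the new check, whereas it could trip a tight per-element atol.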