
Commit 9141c1f

[Core] enable lora for sdxl controlnets too and add slow tests. (huggingface#4666)
* enable lora for sdxl controlnets too.
* add: tests
* fix: assertion values.
1 parent f75b8aa commit 9141c1f
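
In practice, this change lets an SDXL LoRA checkpoint be loaded directly into StableDiffusionXLControlNetPipeline. Below is a minimal sketch mirroring the test_canny_lora test added in this commit; the model IDs, LoRA checkpoint, and control image URL are the ones used in the tests, while the 30-step count and output filename are illustrative only.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

# Models and control image taken from the slow tests added in this commit.
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
# New in this commit: LoRA weights can now be loaded into the SDXL ControlNet pipeline.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
pipe.enable_sequential_cpu_offload()

canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
generator = torch.Generator(device="cpu").manual_seed(0)
# num_inference_steps and the output filename are illustrative, not from the tests.
image = pipe("corgi", image=canny_image, generator=generator, num_inference_steps=30).images[0]
image.save("corgi_pixel_art.png")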

File tree

2 files changed: +155 −2 lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 74 additions & 0 deletions
@@ -14,6 +14,7 @@


 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -1169,3 +1170,76 @@ def __call__(
             return (image,)

         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Overrride to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
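
The save_lora_weights override above packs each group of layers under a prefix ("unet.", "text_encoder.", "text_encoder_2.") before handing the merged state dict to write_lora_layers. A hedged sketch of calling it follows; the layer dicts, key names, tensors, and output directory are placeholders standing in for LoRA weights collected from a training loop, not real checkpoint contents.

import torch
from diffusers import StableDiffusionXLControlNetPipeline

# Placeholder LoRA state dicts; in practice these would come from training.
unet_lora_layers = {"down_blocks.0.attn.to_q.lora.weight": torch.zeros(4, 4)}
text_encoder_lora_layers = {"layers.0.q_proj.lora.weight": torch.zeros(4, 4)}
text_encoder_2_lora_layers = {"layers.0.q_proj.lora.weight": torch.zeros(4, 4)}

# Keys end up prefixed as "unet.*", "text_encoder.*", and "text_encoder_2.*"
# in the serialized file (safetensors by default).
StableDiffusionXLControlNetPipeline.save_lora_weights(
    save_directory="sdxl-controlnet-lora",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
    text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)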

tests/pipelines/controlnet/test_controlnet_sdxl.py

Lines changed: 81 additions & 2 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import unittest

 import numpy as np
@@ -27,9 +28,9 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
-from diffusers.utils import randn_tensor, torch_device
+from diffusers.utils import load_image, randn_tensor, torch_device
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow

 from ..pipeline_params import (
     IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -678,3 +679,81 @@ def test_xformers_attention_forwardGenerator_pass(self):

     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+
+@slow
+@require_torch_gpu
+class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_canny(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.enable_sequential_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "bird"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (768, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.4185, 0.4127, 0.4089, 0.4046, 0.4115, 0.4096, 0.4081, 0.4112, 0.3913])
+        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+    def test_depth(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.enable_sequential_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "Stormtrooper's lecture"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (512, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
+        assert np.allclose(original_image, expected_image, atol=1e-04)
+
+    def test_canny_lora(self):
+        controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
+
+        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
+        )
+        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+        pipe.enable_sequential_cpu_offload()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "corgi"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (768, 512, 3)
+
+        original_image = images[0, -3:, -3:, -1].flatten()
+        expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333])
+        assert np.allclose(original_image, expected_image, atol=1e-04)
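
Note: the new tests are decorated with @slow and @require_torch_gpu, so they are skipped in a default test run; assuming the standard diffusers test setup, they are typically enabled by setting the RUN_SLOW environment variable, e.g. RUN_SLOW=1 python -m pytest tests/pipelines/controlnet/test_controlnet_sdxl.py -k ControlNetSDXLPipelineSlowTests.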
