|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import gc |
16 | 17 | import unittest |
17 | 18 |
|
18 | 19 | import numpy as np |
|
27 | 28 | UNet2DConditionModel, |
28 | 29 | ) |
29 | 30 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel |
30 | | -from diffusers.utils import randn_tensor, torch_device |
| 31 | +from diffusers.utils import load_image, randn_tensor, torch_device |
31 | 32 | from diffusers.utils.import_utils import is_xformers_available |
32 | | -from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu |
| 33 | +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow |
33 | 34 |
|
34 | 35 | from ..pipeline_params import ( |
35 | 36 | IMAGE_TO_IMAGE_IMAGE_PARAMS, |
@@ -678,3 +679,81 @@ def test_xformers_attention_forwardGenerator_pass(self): |
678 | 679 |
|
def test_inference_batch_single_identical(self):
    """Run `_test_inference_batch_single_identical` with `expected_max_diff=2e-3`."""
    # Threshold forwarded unchanged to the shared helper defined outside this block.
    max_allowed_diff = 2e-3
    self._test_inference_batch_single_identical(expected_max_diff=max_allowed_diff)
| 682 | + |
| 683 | + |
@slow
@require_torch_gpu
class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
    """Checks of `StableDiffusionXLControlNetPipeline` against pretrained checkpoints.

    Each test loads a ControlNet checkpoint together with the
    "stabilityai/stable-diffusion-xl-base-1.0" base model, runs three inference
    steps from a CPU generator seeded with 0, and compares the output shape and
    a 3x3 corner slice of the first image with hard-coded reference values.
    """

    def tearDown(self):
        """Run garbage collection and empty the CUDA cache after every test."""
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def _build_pipeline(self, controlnet_id, lora=None):
        """Create the SDXL ControlNet pipeline used by the tests.

        Args:
            controlnet_id: Identifier passed to `ControlNetModel.from_pretrained`.
            lora: Optional `(repo_id, weight_name)` pair. When given, the weights
                are loaded with `load_lora_weights` before offloading is
                enabled, matching the order the LoRA test has always used.

        Returns:
            The pipeline with sequential CPU offload enabled and the progress
            bar configured with `disable=None`.
        """
        controlnet = ControlNetModel.from_pretrained(controlnet_id)
        pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
        )
        if lora is not None:
            repo_id, weight_name = lora
            pipe.load_lora_weights(repo_id, weight_name=weight_name)
        pipe.enable_sequential_cpu_offload()
        # Applied to every test, including the LoRA one, so all three share one configuration.
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def _check_generation(self, pipe, prompt, image_url, expected_shape, expected_slice):
        """Run `pipe` once and compare its first output image with reference values.

        Args:
            pipe: Pipeline returned by `_build_pipeline`.
            prompt: Text prompt passed to the pipeline.
            image_url: URL of the conditioning image, fetched with `load_image`.
            expected_shape: Expected shape of the first output image.
            expected_slice: Nine reference values for the last three rows and
                last three columns of the last channel of the first image.
        """
        generator = torch.Generator(device="cpu").manual_seed(0)
        image = load_image(image_url)

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        actual_shape = images[0].shape
        assert actual_shape == expected_shape, f"unexpected output shape {actual_shape}, expected {expected_shape}"

        actual_slice = images[0, -3:, -3:, -1].flatten()
        reference = np.array(expected_slice)
        assert np.allclose(
            actual_slice, reference, atol=1e-04
        ), f"output slice {actual_slice} differs from reference {reference}"

    def test_canny(self):
        """Canny ControlNet checkpoint with the prompt "bird"."""
        pipe = self._build_pipeline("diffusers/controlnet-canny-sdxl-1.0")
        self._check_generation(
            pipe,
            prompt="bird",
            image_url="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png",
            expected_shape=(768, 512, 3),
            expected_slice=[0.4185, 0.4127, 0.4089, 0.4046, 0.4115, 0.4096, 0.4081, 0.4112, 0.3913],
        )

    def test_depth(self):
        """Depth ControlNet checkpoint with the prompt "Stormtrooper's lecture"."""
        pipe = self._build_pipeline("diffusers/controlnet-depth-sdxl-1.0")
        self._check_generation(
            pipe,
            prompt="Stormtrooper's lecture",
            image_url="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png",
            expected_shape=(512, 512, 3),
            expected_slice=[0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853],
        )

    def test_canny_lora(self):
        """Canny ControlNet checkpoint with the "nerijs/pixel-art-xl" LoRA weights loaded."""
        pipe = self._build_pipeline(
            "diffusers/controlnet-canny-sdxl-1.0",
            lora=("nerijs/pixel-art-xl", "pixel-art-xl.safetensors"),
        )
        self._check_generation(
            pipe,
            prompt="corgi",
            image_url="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png",
            expected_shape=(768, 512, 3),
            expected_slice=[0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333],
        )
0 commit comments