@@ -26,11 +26,13 @@
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from packaging import version
+from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from diffusers import (
     AutoencoderKL,
     AutoPipelineForImage2Image,
+    AutoPipelineForText2Image,
     ControlNetModel,
     DDIMScheduler,
     DiffusionPipeline,
@@ -1745,6 +1747,40 @@ def test_load_unload_load_kohya_lora(self):
         self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
         release_memory(pipe)
 
+    def test_not_empty_state_dict(self):
+        # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
+        pipe = AutoPipelineForText2Image.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ).to("cuda")
+        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+        cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
+        lcm_lora = load_file(cached_file)
+
+        pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
+        self.assertTrue(lcm_lora != {})
+        release_memory(pipe)
+
+    def test_load_unload_load_state_dict(self):
+        # Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
+        pipe = AutoPipelineForText2Image.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ).to("cuda")
+        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+        cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
+        lcm_lora = load_file(cached_file)
+        previous_state_dict = lcm_lora.copy()
+
+        pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
+        self.assertDictEqual(lcm_lora, previous_state_dict)
+
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
+        self.assertDictEqual(lcm_lora, previous_state_dict)
+
+        release_memory(pipe)
+
 
 @slow
 @require_torch_gpu
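The new tests exercise passing an already-loaded state dict to load_lora_weights instead of a repository id, guarding against the regression in https://github.com/huggingface/diffusers/issues/7054 where the dict handed in was emptied in place. A minimal sketch of that usage, mirroring the test setup above (assumes a CUDA device; not part of the diff itself):

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoPipelineForText2Image, LCMScheduler

# Build the pipeline and switch to the LCM scheduler, as in the tests above.
pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Load the LoRA checkpoint into a plain state dict and pass it directly.
cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
lcm_lora = load_file(cached_file)
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")

# The caller's dict stays intact, so it can be reused after unloading.
pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")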