 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import unittest
 
 import numpy as np
 import torch
 
 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
+from ...test_pipelines_common import PipelineTesterMixin
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
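+        # Tiny, randomly initialized components so the fast tests run in seconds on CPU.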
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
@@ -40,25 +45,24 @@ def dummy_cond_unet(self):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
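+        # DDIM with the scaled-linear beta schedule commonly used for latent diffusion checkpoints.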
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ def dummy_text_encoder(self):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
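+        # On mps, device-specific generators are unsupported, so fall back to the default generator.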
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
 
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
     def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
 
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
 
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
 
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
 
-        image_slice = image[0, -3:, -3:, -1]
+@slow
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
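+        # Free GPU memory between tests so later pipelines do not hit out-of-memory errors.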
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
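+        # Fixed latents from a seeded NumPy RNG keep the output reproducible across devices.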
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
 
         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
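+        # Compare the full 256x256 output against a reference array stored in the diffusers test-arrays dataset.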
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3