2020import torch
2121from transformers import CLIPTextConfig , CLIPTextModel , CLIPTextModelWithProjection , CLIPTokenizer
2222
23+ import diffusers
2324from diffusers import (
2425 AutoencoderKL ,
2526 EulerDiscreteScheduler ,
27+ MultiAdapter ,
2628 StableDiffusionXLAdapterPipeline ,
2729 T2IAdapter ,
2830 UNet2DConditionModel ,
2931)
30- from diffusers .utils .testing_utils import enable_full_determinism , floats_tensor
32+ from diffusers .utils import logging
33+ from diffusers .utils .testing_utils import enable_full_determinism , floats_tensor , torch_device
3134
3235from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS , TEXT_GUIDED_IMAGE_VARIATION_PARAMS
33- from ..test_pipelines_common import PipelineTesterMixin
36+ from ..test_pipelines_common import PipelineTesterMixin , assert_mean_pixel_difference
3437
3538
3639enable_full_determinism ()
@@ -41,7 +44,7 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
4144 params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
4245 batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
4346
44- def get_dummy_components (self ):
47+ def get_dummy_components (self , adapter_type = "full_adapter_xl" ):
4548 torch .manual_seed (0 )
4649 unet = UNet2DConditionModel (
4750 block_out_channels = (32 , 64 ),
@@ -97,13 +100,38 @@ def get_dummy_components(self):
97100
98101 text_encoder_2 = CLIPTextModelWithProjection (text_encoder_config )
99102 tokenizer_2 = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
100- adapter = T2IAdapter (
101- in_channels = 3 ,
102- channels = [32 , 64 ],
103- num_res_blocks = 2 ,
104- downscale_factor = 4 ,
105- adapter_type = "full_adapter_xl" ,
106- )
103+ if adapter_type == "full_adapter_xl" :
104+ adapter = T2IAdapter (
105+ in_channels = 3 ,
106+ channels = [32 , 64 ],
107+ num_res_blocks = 2 ,
108+ downscale_factor = 4 ,
109+ adapter_type = adapter_type ,
110+ )
111+ elif adapter_type == "multi_adapter" :
112+ adapter = MultiAdapter (
113+ [
114+ T2IAdapter (
115+ in_channels = 3 ,
116+ channels = [32 , 64 ],
117+ num_res_blocks = 2 ,
118+ downscale_factor = 4 ,
119+ adapter_type = "full_adapter_xl" ,
120+ ),
121+ T2IAdapter (
122+ in_channels = 3 ,
123+ channels = [32 , 64 ],
124+ num_res_blocks = 2 ,
125+ downscale_factor = 4 ,
126+ adapter_type = "full_adapter_xl" ,
127+ ),
128+ ]
129+ )
130+ else :
131+ raise ValueError (
132+ f"Unknown adapter type: { adapter_type } , must be one of 'full_adapter_xl', or 'multi_adapter''"
133+ )
134+
107135 components = {
108136 "adapter" : adapter ,
109137 "unet" : unet ,
@@ -118,8 +146,12 @@ def get_dummy_components(self):
118146 }
119147 return components
120148
121- def get_dummy_inputs (self , device , seed = 0 ):
122- image = floats_tensor ((1 , 3 , 64 , 64 ), rng = random .Random (seed )).to (device )
149+ def get_dummy_inputs (self , device , seed = 0 , num_images = 1 ):
150+ if num_images == 1 :
151+ image = floats_tensor ((1 , 3 , 64 , 64 ), rng = random .Random (seed )).to (device )
152+ else :
153+ image = [floats_tensor ((1 , 3 , 64 , 64 ), rng = random .Random (seed )).to (device ) for _ in range (num_images )]
154+
123155 if str (device ).startswith ("mps" ):
124156 generator = torch .manual_seed (seed )
125157 else :
@@ -150,3 +182,202 @@ def test_stable_diffusion_adapter_default_case(self):
150182 [0.5752919 , 0.6022097 , 0.4728038 , 0.49861962 , 0.57084894 , 0.4644975 , 0.5193715 , 0.5133664 , 0.4729858 ]
151183 )
152184 assert np .abs (image_slice .flatten () - expected_slice ).max () < 5e-3
185+
186+
class StableDiffusionXLMultiAdapterPipelineFastTests(
    StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
    """Fast tests for `StableDiffusionXLAdapterPipeline` built with the "multi_adapter" components.

    The dummy inputs carry one conditioning image per adapter (a list of two
    images) instead of a single tensor, so the batching tests are overridden
    here to repeat every per-adapter image `batch_size` times.
    """

    def get_dummy_components(self):
        """Return the parent's dummy components built with `adapter_type="multi_adapter"`."""
        return super().get_dummy_components("multi_adapter")

    def get_dummy_inputs(self, device, seed=0):
        """Return the parent's dummy inputs with two conditioning images and one scale per image."""
        inputs = super().get_dummy_inputs(device, seed, num_images=2)
        inputs["adapter_conditioning_scale"] = [0.5, 0.5]
        return inputs

    def _batchify_inputs(self, inputs, batch_size, per_sample_generators=False):
        """Build a batched copy of `inputs` for a batch of `batch_size` samples.

        Args:
            inputs: Mapping of pipeline call arguments, as returned by `get_dummy_inputs`.
            batch_size: Number of samples in the batch.
            per_sample_generators: When true, the "generator" entry is replaced by a
                list of `batch_size` generators obtained from `self.get_generator(i)`.

        Returns:
            A new dict; `inputs` itself is not modified.
        """
        batched_inputs = {}
        for name, value in inputs.items():
            if name in self.batch_params:
                if name == "prompt":
                    # prompt is a string: slice it to get prompts of unequal length
                    len_prompt = len(value)
                    prompts = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
                    # make the last prompt super long
                    prompts[-1] = 100 * "very long"
                    batched_inputs[name] = prompts
                elif name == "image":
                    # one entry per adapter, each repeated for every sample in the batch
                    batched_inputs[name] = [batch_size * [image] for image in value]
                else:
                    batched_inputs[name] = batch_size * [value]
            elif name == "batch_size":
                batched_inputs[name] = batch_size
            elif per_sample_generators and name == "generator":
                batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
            else:
                batched_inputs[name] = value
        return batched_inputs

    def test_stable_diffusion_adapter_default_case(self):
        """Check the bottom-right 3x3 slice of the last channel against reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        sd_pipe = StableDiffusionXLAdapterPipeline(**self.get_dummy_components())
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        image = sd_pipe(**self.get_dummy_inputs(device)).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.5813032, 0.60995954, 0.47563356, 0.5056669, 0.57199144, 0.4631841, 0.5176794, 0.51252556, 0.47183886]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3

    def test_inference_batch_consistent(
        self, batch_sizes=(2, 4, 13), additional_params_copy_to_batched_inputs=("num_inference_steps",)
    ):
        """Check that the pipeline returns exactly `batch_size` images for each batch size.

        The defaults are tuples rather than lists so that no mutable object is
        shared between calls; both arguments are only iterated over.
        """
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        logger = logging.get_logger(pipe.__module__)
        previous_level = logger.level
        logger.setLevel(level=diffusers.logging.FATAL)
        try:
            for batch_size in batch_sizes:
                batched_inputs = self._batchify_inputs(inputs, batch_size)

                for arg in additional_params_copy_to_batched_inputs:
                    batched_inputs[arg] = inputs[arg]

                batched_inputs["output_type"] = "np"

                # A single forward pass is enough: the original ran the very same
                # call twice with identical arguments.
                output = pipe(**batched_inputs)[0]

                assert output.shape[0] == batch_size
        finally:
            # Restore the logger even when an assertion above fails.
            logger.setLevel(level=previous_level)

    def test_num_images_per_prompt(self):
        """Check that the output holds `batch_size * num_images_per_prompt` images."""
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        batch_sizes = [1, 2]
        num_images_per_prompts = [1, 2]

        for batch_size in batch_sizes:
            for num_images_per_prompt in num_images_per_prompts:
                inputs = self.get_dummy_inputs(torch_device)

                # Only existing keys are reassigned, so iterating the dict is safe.
                for key in inputs:
                    if key not in self.batch_params:
                        continue
                    if key == "image":
                        # one entry per adapter, each repeated for every sample
                        inputs[key] = [batch_size * [image] for image in inputs[key]]
                    else:
                        inputs[key] = batch_size * [inputs[key]]

                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]

                assert images.shape[0] == batch_size * num_images_per_prompt

    def test_inference_batch_single_identical(
        self,
        batch_size=3,
        test_max_difference=None,
        test_mean_pixel_difference=None,
        relax_max_difference=False,
        expected_max_diff=2e-3,
        additional_params_copy_to_batched_inputs=("num_inference_steps",),
    ):
        """Check that the first image of a batched call matches a single, unbatched call.

        Args:
            batch_size: Number of samples in the batched call.
            test_max_difference: Whether to compare the maximum absolute difference;
                defaults to true unless `torch_device` is "mps".
            test_mean_pixel_difference: Whether to run `assert_mean_pixel_difference`;
                defaults to true unless `torch_device` is "mps".
            relax_max_difference: Use the median of the five largest differences
                instead of the single largest one.
            expected_max_diff: Exclusive upper bound for the measured difference.
            additional_params_copy_to_batched_inputs: Input names copied unchanged
                from the single inputs into the batched inputs.
        """
        if test_max_difference is None:
            # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
            # make sure that batched and non-batched is identical
            test_max_difference = torch_device != "mps"

        if test_mean_pixel_difference is None:
            # TODO same as above
            test_mean_pixel_difference = torch_device != "mps"

        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        logger = logging.get_logger(pipe.__module__)
        previous_level = logger.level
        logger.setLevel(level=diffusers.logging.FATAL)
        try:
            batched_inputs = self._batchify_inputs(inputs, batch_size, per_sample_generators=True)

            for arg in additional_params_copy_to_batched_inputs:
                batched_inputs[arg] = inputs[arg]

            output_batch = pipe(**batched_inputs)
            assert output_batch[0].shape[0] == batch_size

            # Same generator as the first sample of the batch.
            inputs["generator"] = self.get_generator(0)

            output = pipe(**inputs)
        finally:
            # Restore the logger even when the pipeline call or an assertion fails.
            logger.setLevel(level=previous_level)

        if test_max_difference:
            diff = np.abs(output_batch[0][0] - output[0][0])
            if relax_max_difference:
                # Taking the median of the largest <n> differences
                # is resilient to outliers
                diff = diff.flatten()
                diff.sort()
                max_diff = np.median(diff[-5:])
            else:
                max_diff = diff.max()
            assert max_diff < expected_max_diff

        if test_mean_pixel_difference:
            assert_mean_pixel_difference(output_batch[0][0], output[0][0])
0 commit comments