 # limitations under the License.
 
 import gc
+import random
+import tempfile
 import unittest
 
 import numpy as np
     StableDiffusionPix2PixZeroPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import load_numpy, slow, torch_device
+from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -69,6 +71,7 @@ def get_dummy_components(self):
             cross_attention_dim=32,
         )
         scheduler = DDIMScheduler()
+        inverse_scheduler = DDIMInverseScheduler()
         torch.manual_seed(0)
         vae = AutoencoderKL(
             block_out_channels=[32, 64],
@@ -101,7 +104,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
-            "inverse_scheduler": None,
+            "inverse_scheduler": inverse_scheduler,
             "caption_generator": None,
             "caption_processor": None,
         }
@@ -122,6 +125,90 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def get_dummy_inversion_inputs(self, device, seed=0):
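+        # a fixed-seed batch of two random 3x32x32 images, one prompt each, shared by the
+        # single-image and batched inversion tests below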
+        dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+        generator = torch.manual_seed(seed)
+
+        inputs = {
+            "prompt": [
+                "A painting of a squirrel eating a burger",
+                "A painting of a burger eating a squirrel",
+            ],
+            "image": dummy_image.cpu(),
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "generator": generator,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_save_load_optional_components(self):
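+        # optional components set to None should survive a save/load round trip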
+        if not hasattr(self.pipeline_class, "_optional_components"):
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # set all optional components to None and update pipeline config accordingly
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+        pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0]
+
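+        # serialize and reload, then confirm the None components stayed None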
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_loaded = pipe_loaded(**inputs)[0]
+
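+        # outputs before and after the save/load round trip should match closely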
+        max_diff = np.abs(output - output_loaded).max()
+        self.assertLess(max_diff, 1e-4)
+
+    def test_stable_diffusion_pix2pix_zero_inversion(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
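+        # keep only the first image and prompt to exercise single-sample inversion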
+        inputs["image"] = inputs["image"][:1]
+        inputs["prompt"] = inputs["prompt"][:1]
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4833, 0.4696, 0.5574, 0.5194, 0.5248, 0.5638, 0.5040, 0.5423, 0.5072])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
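+        # invert the full two-image batch and check a slice of the second sample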
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[1, -3:, -3:, -1]
+        assert image.shape == (2, 32, 32, 3)
+        expected_slice = np.array([0.6672, 0.5203, 0.4908, 0.4376, 0.4517, 0.5544, 0.4605, 0.4826, 0.5007])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     def test_stable_diffusion_pix2pix_zero_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()