 import unittest
 
 import numpy as np
-import requests
 import torch
-from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils import load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
+from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
 
 from ...test_pipelines_common import PipelineTesterMixin
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def download_from_url(embedding_url, local_filepath):
-    r = requests.get(embedding_url)
-    with open(local_filepath, "wb") as f:
-        f.write(r.content)
-
-
 @skip_mps
 class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionPix2PixZeroPipeline
 
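+    # NOTE: setUpClass runs once per test class, so the dummy source/target
+    # embeddings are downloaded a single time and shared by every test,
+    # instead of being re-fetched in each get_dummy_inputs call as the
+    # removed download_from_url helper did.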
+    @classmethod
+    def setUpClass(cls):
+        cls.source_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
+        )
+
+        cls.target_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
+        )
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -103,15 +105,6 @@ def get_dummy_components(self):
         return components
 
     def get_dummy_inputs(self, device, seed=0):
-        src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
-        tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
-
-        for url in [src_emb_url, tgt_emb_url]:
-            download_from_url(url, url.split("/")[-1])
-
-        src_embeds = torch.load(src_emb_url.split("/")[-1])
-        target_embeds = torch.load(tgt_emb_url.split("/")[-1])
-
         generator = torch.manual_seed(seed)
 
         inputs = {
@@ -120,8 +113,8 @@ def get_dummy_inputs(self, device, seed=0):
120113 "num_inference_steps" : 2 ,
121114 "guidance_scale" : 6.0 ,
122115 "cross_attention_guidance_amount" : 0.15 ,
123- "source_embeds" : src_embeds ,
124- "target_embeds" : target_embeds ,
116+ "source_embeds" : self . source_embeds ,
117+ "target_embeds" : self . target_embeds ,
125118 "output_type" : "numpy" ,
126119 }
127120 return inputs
@@ -237,26 +230,27 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_inputs(self, seed=0):
-        generator = torch.manual_seed(seed)
-
-        src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
-        tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
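+    # cat.pt / dog.pt are assumed to hold precomputed text embeddings for the
+    # source ("cat") and target ("dog") concepts; pix2pix-zero derives the
+    # cat -> dog edit direction from the difference between the two sets.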
+    @classmethod
+    def setUpClass(cls):
+        cls.source_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
+        )
 
-        for url in [src_emb_url, tgt_emb_url]:
-            download_from_url(url, url.split("/")[-1])
+        cls.target_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
+        )
 
-        src_embeds = torch.load(src_emb_url.split("/")[-1])
-        target_embeds = torch.load(tgt_emb_url.split("/")[-1])
+    def get_inputs(self, seed=0):
+        generator = torch.manual_seed(seed)
 
         inputs = {
             "prompt": "turn him into a cyborg",
             "generator": generator,
             "num_inference_steps": 3,
             "guidance_scale": 7.5,
             "cross_attention_guidance_amount": 0.15,
-            "source_embeds": src_embeds,
-            "target_embeds": target_embeds,
+            "source_embeds": self.source_embeds,
+            "target_embeds": self.target_embeds,
             "output_type": "numpy",
         }
         return inputs
@@ -364,10 +358,17 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_stable_diffusion_pix2pix_inversion(self):
-        img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
-        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
+    @classmethod
+    def setUpClass(cls):
+        raw_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
+        )
+
+        raw_image = raw_image.convert("RGB").resize((512, 512))
+
+        cls.raw_image = raw_image
 
+    def test_stable_diffusion_pix2pix_inversion(self):
         pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
         )
@@ -380,7 +381,7 @@ def test_stable_diffusion_pix2pix_inversion(self):
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
+        output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
         inv_latents = output[0]
 
         image_slice = inv_latents[0, -3:, -3:, -1].flatten()
@@ -391,9 +392,6 @@ def test_stable_diffusion_pix2pix_inversion(self):
         assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
 
     def test_stable_diffusion_pix2pix_full(self):
-        img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
-        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
-
         # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
         expected_image = load_numpy(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
@@ -411,7 +409,7 @@ def test_stable_diffusion_pix2pix_full(self):
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        output = pipe.invert(caption, image=raw_image, generator=generator)
+        output = pipe.invert(caption, image=self.raw_image, generator=generator)
         inv_latents = output[0]
 
         source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
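
For reference, load_pt and load_image are imported from diffusers.utils.testing_utils. A minimal sketch of what they are assumed to do, so the tests above can be read without chasing the helpers (the real implementations may differ, e.g. in caching and error handling):

    import io

    import requests
    import torch
    from PIL import Image


    def load_pt(url):
        # Fetch a serialized .pt file and deserialize it in memory,
        # instead of writing it to disk as the removed download_from_url did.
        response = requests.get(url)
        response.raise_for_status()
        return torch.load(io.BytesIO(response.content))


    def load_image(url):
        # Fetch an image over HTTP and decode it with PIL; callers still
        # convert/resize it themselves, as setUpClass does above.
        response = requests.get(url, stream=True)
        response.raise_for_status()
        return Image.open(response.raw)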