Skip to content

Commit 1586186

Browse files
pix2pix tests no write to fs (huggingface#2497)
* attend and excite batch test causing timeouts * pix2pix tests, no write to fs
1 parent 42beaf1 commit 1586186

File tree

2 files changed

+44
-39
lines changed

2 files changed

+44
-39
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
209209
return arry
210210

211211

212+
def load_pt(url: str):
213+
response = requests.get(url)
214+
response.raise_for_status()
215+
arry = torch.load(BytesIO(response.content))
216+
return arry
217+
218+
212219
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
213220
"""
214221
Args:

tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
import unittest
1818

1919
import numpy as np
20-
import requests
2120
import torch
22-
from PIL import Image
2321
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
2422

2523
from diffusers import (
@@ -33,24 +31,28 @@
3331
UNet2DConditionModel,
3432
)
3533
from diffusers.utils import load_numpy, slow, torch_device
36-
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
34+
from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
3735

3836
from ...test_pipelines_common import PipelineTesterMixin
3937

4038

4139
torch.backends.cuda.matmul.allow_tf32 = False
4240

4341

44-
def download_from_url(embedding_url, local_filepath):
45-
r = requests.get(embedding_url)
46-
with open(local_filepath, "wb") as f:
47-
f.write(r.content)
48-
49-
5042
@skip_mps
5143
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
5244
pipeline_class = StableDiffusionPix2PixZeroPipeline
5345

46+
@classmethod
47+
def setUpClass(cls):
48+
cls.source_embeds = load_pt(
49+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
50+
)
51+
52+
cls.target_embeds = load_pt(
53+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
54+
)
55+
5456
def get_dummy_components(self):
5557
torch.manual_seed(0)
5658
unet = UNet2DConditionModel(
@@ -103,15 +105,6 @@ def get_dummy_components(self):
103105
return components
104106

105107
def get_dummy_inputs(self, device, seed=0):
106-
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
107-
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
108-
109-
for url in [src_emb_url, tgt_emb_url]:
110-
download_from_url(url, url.split("/")[-1])
111-
112-
src_embeds = torch.load(src_emb_url.split("/")[-1])
113-
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
114-
115108
generator = torch.manual_seed(seed)
116109

117110
inputs = {
@@ -120,8 +113,8 @@ def get_dummy_inputs(self, device, seed=0):
120113
"num_inference_steps": 2,
121114
"guidance_scale": 6.0,
122115
"cross_attention_guidance_amount": 0.15,
123-
"source_embeds": src_embeds,
124-
"target_embeds": target_embeds,
116+
"source_embeds": self.source_embeds,
117+
"target_embeds": self.target_embeds,
125118
"output_type": "numpy",
126119
}
127120
return inputs
@@ -237,26 +230,27 @@ def tearDown(self):
237230
gc.collect()
238231
torch.cuda.empty_cache()
239232

240-
def get_inputs(self, seed=0):
241-
generator = torch.manual_seed(seed)
242-
243-
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
244-
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
233+
@classmethod
234+
def setUpClass(cls):
235+
cls.source_embeds = load_pt(
236+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
237+
)
245238

246-
for url in [src_emb_url, tgt_emb_url]:
247-
download_from_url(url, url.split("/")[-1])
239+
cls.target_embeds = load_pt(
240+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
241+
)
248242

249-
src_embeds = torch.load(src_emb_url.split("/")[-1])
250-
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
243+
def get_inputs(self, seed=0):
244+
generator = torch.manual_seed(seed)
251245

252246
inputs = {
253247
"prompt": "turn him into a cyborg",
254248
"generator": generator,
255249
"num_inference_steps": 3,
256250
"guidance_scale": 7.5,
257251
"cross_attention_guidance_amount": 0.15,
258-
"source_embeds": src_embeds,
259-
"target_embeds": target_embeds,
252+
"source_embeds": self.source_embeds,
253+
"target_embeds": self.target_embeds,
260254
"output_type": "numpy",
261255
}
262256
return inputs
@@ -364,10 +358,17 @@ def tearDown(self):
364358
gc.collect()
365359
torch.cuda.empty_cache()
366360

367-
def test_stable_diffusion_pix2pix_inversion(self):
368-
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
369-
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
361+
@classmethod
362+
def setUpClass(cls):
363+
raw_image = load_image(
364+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
365+
)
366+
367+
raw_image = raw_image.convert("RGB").resize((512, 512))
368+
369+
cls.raw_image = raw_image
370370

371+
def test_stable_diffusion_pix2pix_inversion(self):
371372
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
372373
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
373374
)
@@ -380,7 +381,7 @@ def test_stable_diffusion_pix2pix_inversion(self):
380381
pipe.set_progress_bar_config(disable=None)
381382

382383
generator = torch.manual_seed(0)
383-
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
384+
output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
384385
inv_latents = output[0]
385386

386387
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
@@ -391,9 +392,6 @@ def test_stable_diffusion_pix2pix_inversion(self):
391392
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
392393

393394
def test_stable_diffusion_pix2pix_full(self):
394-
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
395-
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
396-
397395
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
398396
expected_image = load_numpy(
399397
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
@@ -411,7 +409,7 @@ def test_stable_diffusion_pix2pix_full(self):
411409
pipe.set_progress_bar_config(disable=None)
412410

413411
generator = torch.manual_seed(0)
414-
output = pipe.invert(caption, image=raw_image, generator=generator)
412+
output = pipe.invert(caption, image=self.raw_image, generator=generator)
415413
inv_latents = output[0]
416414

417415
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]

0 commit comments

Comments
 (0)