Skip to content

Commit 9b37ed3

Browse files
[SD Img2Img] resize source images to multiple of 8 instead of 32 (huggingface#1571)
* [Stable Diffusion Img2Img] resize source images to integer multiple of 8 instead of 32 * [Alt Diffusion Img2Img] resize source images to multiple of 8 instead of 32 * [Img2Img] fix AltDiffusion Img2Img resolution test * [Img2Img] add Stable Diffusion Img2Img resolution test * [Cycle Diffusion] round resolution to multiplies of 8 instead of 32 * [ONNX SD Img2Img] round resolution to multiplies of 64 instead of 32 * [SD Depth2Img] round resolution to multiplies of 8 instead of 32 * [Repaint] round resolution to multiplies of 8 instead of 32 * fix make style Co-authored-by: Patrick von Platen <[email protected]>
1 parent 135567f commit 9b37ed3

File tree

8 files changed

+80
-7
lines changed

8 files changed

+80
-7
lines changed

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def preprocess(image):
8080

8181
if isinstance(image[0], PIL.Image.Image):
8282
w, h = image[0].size
83-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
83+
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
8484

8585
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
8686
image = np.concatenate(image, axis=0)

src/diffusers/pipelines/repaint/pipeline_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
3838

3939
if isinstance(image[0], PIL.Image.Image):
4040
w, h = image[0].size
41-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
41+
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
4242

4343
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
4444
image = np.concatenate(image, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def preprocess(image):
4444

4545
if isinstance(image[0], PIL.Image.Image):
4646
w, h = image[0].size
47-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
47+
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
4848

4949
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
5050
image = np.concatenate(image, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3333

3434

35-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
35+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
3636
def preprocess(image):
3737
if isinstance(image, torch.Tensor):
3838
return image
@@ -41,7 +41,7 @@ def preprocess(image):
4141

4242
if isinstance(image[0], PIL.Image.Image):
4343
w, h = image[0].size
44-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
44+
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
4545

4646
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
4747
image = np.concatenate(image, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def preprocess(image):
4949

5050
if isinstance(image[0], PIL.Image.Image):
5151
w, h = image[0].size
52-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
52+
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
5353

5454
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
5555
image = np.concatenate(image, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def preprocess(image):
8484

8585
if isinstance(image[0], PIL.Image.Image):
8686
w, h = image[0].size
87-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
87+
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
8888

8989
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
9090
image = np.concatenate(image, axis=0)

tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,43 @@ def test_stable_diffusion_img2img_fp16(self):
207207

208208
assert image.shape == (1, 32, 32, 3)
209209

210+
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
211+
def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
212+
init_image = load_image(
213+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
214+
"/img2img/sketch-mountains-input.jpg"
215+
)
216+
# resize to resolution that is divisible by 8 but not 16 or 32
217+
init_image = init_image.resize((760, 504))
218+
219+
model_id = "BAAI/AltDiffusion"
220+
pipe = AltDiffusionImg2ImgPipeline.from_pretrained(
221+
model_id,
222+
safety_checker=None,
223+
)
224+
pipe.to(torch_device)
225+
pipe.set_progress_bar_config(disable=None)
226+
pipe.enable_attention_slicing()
227+
228+
prompt = "A fantasy landscape, trending on artstation"
229+
230+
generator = torch.Generator(device=torch_device).manual_seed(0)
231+
output = pipe(
232+
prompt=prompt,
233+
image=init_image,
234+
strength=0.75,
235+
guidance_scale=7.5,
236+
generator=generator,
237+
output_type="np",
238+
)
239+
image = output.images[0]
240+
241+
image_slice = image[255:258, 383:386, -1]
242+
243+
assert image.shape == (504, 760, 3)
244+
expected_slice = np.array([0.3252, 0.3340, 0.3418, 0.3263, 0.3346, 0.3300, 0.3163, 0.3470, 0.3427])
245+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
246+
210247

211248
@slow
212249
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,42 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
333333
# make sure that less than 2.2 GB is allocated
334334
assert mem_bytes < 2.2 * 10**9
335335

336+
def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
337+
init_image = load_image(
338+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
339+
"/img2img/sketch-mountains-input.jpg"
340+
)
341+
# resize to resolution that is divisible by 8 but not 16 or 32
342+
init_image = init_image.resize((760, 504))
343+
344+
model_id = "CompVis/stable-diffusion-v1-4"
345+
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
346+
model_id,
347+
safety_checker=None,
348+
)
349+
pipe.to(torch_device)
350+
pipe.set_progress_bar_config(disable=None)
351+
pipe.enable_attention_slicing()
352+
353+
prompt = "A fantasy landscape, trending on artstation"
354+
355+
generator = torch.Generator(device=torch_device).manual_seed(0)
356+
output = pipe(
357+
prompt=prompt,
358+
image=init_image,
359+
strength=0.75,
360+
guidance_scale=7.5,
361+
generator=generator,
362+
output_type="np",
363+
)
364+
image = output.images[0]
365+
366+
image_slice = image[255:258, 383:386, -1]
367+
368+
assert image.shape == (504, 760, 3)
369+
expected_slice = np.array([0.7124, 0.7105, 0.6993, 0.7140, 0.7106, 0.6945, 0.7198, 0.7172, 0.7031])
370+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
371+
336372

337373
@nightly
338374
@require_torch_gpu

0 commit comments

Comments
 (0)