Commit c53a850

[Batched Generators] This PR adds generators that are useful to make batched generation fully reproducible (huggingface#1718)
* [Batched Generators] all batched generators * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * hey * up again * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * correct tests Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 086c7f9 commit c53a850

33 files changed: +571 -183 lines changed

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,8 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
302302

303303
### Tweak prompts reusing seeds and latents
304304

305-
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
306-
307-
308-
For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
309-
and have a look into the [release notes](https://github.com/huggingface/diffusers/releases/tag/v0.2.0).
305+
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked.
306+
Please have a look at [Reusing seeds for deterministic generation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/reusing_seeds).
310307

311308
## Fine-Tuning Stable Diffusion
312309

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@
2828
title: "Text-Guided Image-Inpainting"
2929
- local: using-diffusers/depth2img
3030
title: "Text-Guided Depth-to-Image"
31+
- local: using-diffusers/reusing_seeds
32+
title: "Reusing seeds for deterministic generation"
3133
- local: using-diffusers/custom_pipeline_examples
3234
title: "Community Pipelines"
3335
- local: using-diffusers/contribute_pipeline
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
1+
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Re-using seeds for fast prompt engineering
14+
15+
A common use case when generating images is to generate a batch of images, select one image and improve it with a better, more detailed prompt in a second run.
16+
To do this, one needs to make each generated image of the batch deterministic.
17+
Images are generated by denoising Gaussian random noise, which can be made reproducible by passing a [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator) to the pipeline.
18+
19+
Now, for batched generation, we need to make sure that every single generated image in the batch is tied exactly to one seed. In 🧨 Diffusers, this can be achieved by not passing one `generator`, but a list
20+
of `generators` to the pipeline.
21+
22+
Let's go through an example using [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5).
23+
We want to generate several versions of the prompt:
24+
25+
```py
26+
prompt = "Labrador in the style of Vermeer"
27+
```
28+
29+
Let's load the pipeline:
30+
31+
```python
32+
>>> import torch
>>> from diffusers import DiffusionPipeline
33+
34+
>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
35+
>>> pipe = pipe.to("cuda")
36+
```
37+
38+
Now, let's define 4 different generators, since we would like to reproduce a certain image. We'll use seeds `0` to `3` to create our generators.
39+
40+
```python
41+
>>> import torch
42+
43+
>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
44+
```
45+
46+
Let's generate 4 images:
47+
48+
```python
49+
>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
50+
>>> images
51+
```
52+
53+
![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg)
54+
55+
OK, the last image has double eyes, but the first image looks good!
56+
Let's try to make the prompt a bit better **while keeping the first seed**
57+
so that the images are similar to the first image.
58+
59+
```python
60+
prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
61+
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
62+
```
63+
64+
We create 4 generators with seed `0`, which is the first seed we used before.
65+
66+
Let's run the pipeline again.
67+
68+
```python
69+
>>> images = pipe(prompt, generator=generator).images
70+
>>> images
71+
```
72+
73+
![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg)
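
The reproducibility property this guide relies on can be checked without a GPU or a model. The following is a minimal sketch, assuming CPU generators and an arbitrary latent shape of `(1, 4, 8, 8)`; `sample_batch` is a hypothetical helper and not part of Diffusers. It draws one latent per generator and then compares a batch seeded `0` to `3` with a batch seeded `0` four times.

```python
import torch


def sample_batch(seeds, shape=(1, 4, 8, 8)):
    """Draw one latent per seed and stack them into a batch (hypothetical helper)."""
    generators = [torch.Generator(device="cpu").manual_seed(s) for s in seeds]
    # Each sample gets its own generator, so it depends only on its own seed.
    latents = [torch.randn(shape, generator=g) for g in generators]
    return torch.cat(latents, dim=0)


first_run = sample_batch([0, 1, 2, 3])
second_run = sample_batch([0, 0, 0, 0])

# Every row of the second batch should equal the first row of the first batch.
same_as_first = all(torch.equal(second_run[i], first_run[0]) for i in range(4))
# Rows drawn from different seeds should differ.
differs = not torch.equal(first_run[0], first_run[1])
print(same_as_first, differs)
```

With identical starting noise, differences between the images of the second run should come from the prompt variations rather than from the random draw.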

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 19 additions & 7 deletions
@@ -379,12 +379,24 @@ def check_inputs(self, prompt, height, width, callback_steps):
379379

380380
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
381381
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
382+
if isinstance(generator, list) and len(generator) != batch_size:
383+
raise ValueError(
384+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
385+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
386+
)
387+
382388
if latents is None:
383-
if device.type == "mps":
384-
# randn does not work reproducibly on mps
385-
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
389+
rand_device = "cpu" if device.type == "mps" else device
390+
391+
if isinstance(generator, list):
392+
shape = (1,) + shape[1:]
393+
latents = [
394+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
395+
for i in range(batch_size)
396+
]
397+
latents = torch.cat(latents, dim=0).to(device)
386398
else:
387-
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
399+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
388400
else:
389401
if latents.shape != shape:
390402
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -405,7 +417,7 @@ def __call__(
405417
negative_prompt: Optional[Union[str, List[str]]] = None,
406418
num_images_per_prompt: Optional[int] = 1,
407419
eta: float = 0.0,
408-
generator: Optional[torch.Generator] = None,
420+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
409421
latents: Optional[torch.FloatTensor] = None,
410422
output_type: Optional[str] = "pil",
411423
return_dict: bool = True,
@@ -440,8 +452,8 @@ def __call__(
440452
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
441453
[`schedulers.DDIMScheduler`], will be ignored for others.
442454
generator (`torch.Generator`, *optional*):
443-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
444-
deterministic.
455+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
456+
to make generation deterministic.
445457
latents (`torch.FloatTensor`, *optional*):
446458
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
447459
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
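
For readers who want to try the new sampling branch of `prepare_latents` in isolation, here is a rough standalone sketch. It assumes a free function named `sample_latents` (a hypothetical name, not in the commit) and defaults to CPU; the length check, the CPU fallback for `mps`, and the one-sample-per-generator loop mirror the hunk above, while the error message is shortened.

```python
from typing import List, Optional, Tuple, Union

import torch


def sample_latents(
    shape: Tuple[int, ...],
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    device: Union[str, torch.device] = "cpu",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Standalone sketch of the sampling branch (hypothetical helper)."""
    device = torch.device(device)
    batch_size = shape[0]
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"Got {len(generator)} generators for an effective batch size of {batch_size}."
        )

    # Sample on CPU when the target is MPS, then move the result over.
    rand_device = torch.device("cpu") if device.type == "mps" else device

    if isinstance(generator, list):
        # One sample per generator, concatenated along the batch dimension.
        single = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(single, generator=g, device=rand_device, dtype=dtype)
            for g in generator
        ]
        return torch.cat(latents, dim=0).to(device)
    return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)


gens = [torch.Generator().manual_seed(i) for i in range(2)]
batch = sample_latents((2, 4, 8, 8), generator=gens)

# A generator list whose length differs from the batch size is rejected.
try:
    sample_latents((3, 4, 8, 8), generator=gens)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```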

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 32 additions & 10 deletions
@@ -396,8 +396,22 @@ def get_timesteps(self, num_inference_steps, strength, device):
396396

397397
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
398398
image = image.to(device=device, dtype=dtype)
399-
init_latent_dist = self.vae.encode(image).latent_dist
400-
init_latents = init_latent_dist.sample(generator=generator)
399+
400+
batch_size = batch_size * num_images_per_prompt
401+
if isinstance(generator, list) and len(generator) != batch_size:
402+
raise ValueError(
403+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
404+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
405+
)
406+
407+
if isinstance(generator, list):
408+
init_latents = [
409+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
410+
]
411+
init_latents = torch.cat(init_latents, dim=0)
412+
else:
413+
init_latents = self.vae.encode(image).latent_dist.sample(generator)
414+
401415
init_latents = 0.18215 * init_latents
402416

403417
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
@@ -410,16 +424,24 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
410424
)
411425
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
412426
additional_image_per_prompt = batch_size // init_latents.shape[0]
413-
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
427+
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
414428
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
415429
raise ValueError(
416430
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
417431
)
418432
else:
419-
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
420-
421-
# add noise to latents using the timesteps
422-
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
433+
init_latents = torch.cat([init_latents], dim=0)
434+
435+
rand_device = "cpu" if device.type == "mps" else device
436+
shape = init_latents.shape
437+
if isinstance(generator, list):
438+
shape = (1,) + shape[1:]
439+
noise = [
440+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
441+
]
442+
noise = torch.cat(noise, dim=0).to(device)
443+
else:
444+
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
423445

424446
# get latents
425447
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -438,7 +460,7 @@ def __call__(
438460
negative_prompt: Optional[Union[str, List[str]]] = None,
439461
num_images_per_prompt: Optional[int] = 1,
440462
eta: Optional[float] = 0.0,
441-
generator: Optional[torch.Generator] = None,
463+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
442464
output_type: Optional[str] = "pil",
443465
return_dict: bool = True,
444466
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -478,8 +500,8 @@ def __call__(
478500
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
479501
[`schedulers.DDIMScheduler`], will be ignored for others.
480502
generator (`torch.Generator`, *optional*):
481-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
482-
deterministic.
503+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
504+
to make generation deterministic.
483505
output_type (`str`, *optional*, defaults to `"pil"`):
484506
The output format of the generated image. Choose between
485507
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
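
The length check in this file compares the generator list against `batch_size * num_images_per_prompt`, so the number of generators to pass depends on both the prompt and `num_images_per_prompt`. A small sketch of that arithmetic follows; `effective_batch_size` is a hypothetical helper, and treating a single string as one prompt is an assumption about how the pipeline derives `batch_size`.

```python
from typing import List, Union


def effective_batch_size(prompt: Union[str, List[str]], num_images_per_prompt: int = 1) -> int:
    """Number of images one call produces, i.e. the generator-list length to pass."""
    num_prompts = 1 if isinstance(prompt, str) else len(prompt)
    return num_prompts * num_images_per_prompt


base = "Labrador in the style of Vermeer"
variants = [base + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]

# One prompt with four images per prompt, then four prompts with one image each.
single_prompt_count = effective_batch_size(base, num_images_per_prompt=4)
prompt_list_count = effective_batch_size(variants)
print(single_prompt_count, prompt_list_count)  # 4 4
```

Both calls in the documentation example therefore need exactly four generators.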

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 21 additions & 7 deletions
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Optional, Tuple, Union
16+
from typing import List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -45,7 +45,7 @@ def __call__(
4545
self,
4646
batch_size: int = 1,
4747
num_inference_steps: int = 100,
48-
generator: Optional[torch.Generator] = None,
48+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
4949
audio_length_in_s: Optional[float] = None,
5050
return_dict: bool = True,
5151
) -> Union[AudioPipelineOutput, Tuple]:
@@ -57,8 +57,8 @@ def __call__(
5757
The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
5858
the expense of slower inference.
5959
generator (`torch.Generator`, *optional*):
60-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
61-
deterministic.
60+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
61+
to make generation deterministic.
6262
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
6363
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
6464
`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
@@ -94,9 +94,23 @@ def __call__(
9494
sample_size = int(sample_size)
9595

9696
dtype = next(iter(self.unet.parameters())).dtype
97-
audio = torch.randn(
98-
(batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
99-
)
97+
shape = (batch_size, self.unet.in_channels, sample_size)
98+
if isinstance(generator, list) and len(generator) != batch_size:
99+
raise ValueError(
100+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
101+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
102+
)
103+
104+
rand_device = "cpu" if self.device.type == "mps" else self.device
105+
if isinstance(generator, list):
106+
shape = (1,) + shape[1:]
107+
audio = [
108+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
109+
for i in range(batch_size)
110+
]
111+
audio = torch.cat(audio, dim=0).to(self.device)
112+
else:
113+
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
100114

101115
# set step values
102116
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 26 additions & 10 deletions
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional, Tuple, Union
15+
from typing import List, Optional, Tuple, Union
1616

1717
import torch
1818

@@ -40,7 +40,7 @@ def __init__(self, unet, scheduler):
4040
def __call__(
4141
self,
4242
batch_size: int = 1,
43-
generator: Optional[torch.Generator] = None,
43+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
4444
eta: float = 0.0,
4545
num_inference_steps: int = 50,
4646
use_clipped_model_output: Optional[bool] = None,
@@ -52,8 +52,8 @@ def __call__(
5252
batch_size (`int`, *optional*, defaults to 1):
5353
The number of images to generate.
5454
generator (`torch.Generator`, *optional*):
55-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
56-
deterministic.
55+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
56+
to make generation deterministic.
5757
eta (`float`, *optional*, defaults to 0.0):
5858
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
5959
num_inference_steps (`int`, *optional*, defaults to 50):
@@ -74,7 +74,12 @@ def __call__(
7474
generated images.
7575
"""
7676

77-
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
77+
if (
78+
generator is not None
79+
and isinstance(generator, torch.Generator)
80+
and generator.device.type != self.device.type
81+
and self.device.type != "mps"
82+
):
7883
message = (
7984
f"The `generator` device is `{generator.device}` and does not match the pipeline "
8085
f"device `{self.device}`, so the `generator` will be ignored. "
@@ -93,12 +98,23 @@ def __call__(
9398
else:
9499
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
95100

96-
if self.device.type == "mps":
97-
# randn does not work reproducibly on mps
98-
image = torch.randn(image_shape, generator=generator, dtype=self.unet.dtype)
99-
image = image.to(self.device)
101+
if isinstance(generator, list) and len(generator) != batch_size:
102+
raise ValueError(
103+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
104+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
105+
)
106+
107+
rand_device = "cpu" if self.device.type == "mps" else self.device
108+
if isinstance(generator, list):
109+
shape = (1,) + image_shape[1:]
110+
image = [
111+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
112+
for i in range(batch_size)
113+
]
114+
image = torch.cat(image, dim=0).to(self.device)
100115
else:
101-
image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
116+
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
117+
image = image.to(self.device)
102118

103119
# set step values
104120
self.scheduler.set_timesteps(num_inference_steps)
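
The reworked condition in this file only reads `generator.device` when `generator` is a single `torch.Generator`, since a Python list has no `.device` attribute. A minimal sketch of that guard, assuming a hypothetical free function `generator_device_mismatch` rather than the pipeline method:

```python
import torch


def generator_device_mismatch(generator, pipeline_device: torch.device) -> bool:
    """Return True when a single generator sits on a different device type than the pipeline."""
    return (
        generator is not None
        and isinstance(generator, torch.Generator)  # lists have no `.device`
        and generator.device.type != pipeline_device.type
        and pipeline_device.type != "mps"
    )


cpu_generator = torch.Generator(device="cpu").manual_seed(0)

# torch.device("cuda") is only a descriptor here; no GPU is needed to build it.
single_mismatch = generator_device_mismatch(cpu_generator, torch.device("cuda"))
list_mismatch = generator_device_mismatch([cpu_generator], torch.device("cuda"))
print(single_mismatch, list_mismatch)
```

As far as this hunk shows, a list of generators never satisfies the condition, so the device check applies to single generators only.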

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Optional, Tuple, Union
16+
from typing import List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -42,7 +42,7 @@ def __init__(self, unet, scheduler):
4242
def __call__(
4343
self,
4444
batch_size: int = 1,
45-
generator: Optional[torch.Generator] = None,
45+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
4646
num_inference_steps: int = 1000,
4747
output_type: Optional[str] = "pil",
4848
return_dict: bool = True,
@@ -53,8 +53,8 @@ def __call__(
5353
batch_size (`int`, *optional*, defaults to 1):
5454
The number of images to generate.
5555
generator (`torch.Generator`, *optional*):
56-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
57-
deterministic.
56+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
57+
to make generation deterministic.
5858
num_inference_steps (`int`, *optional*, defaults to 1000):
5959
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
6060
expense of slower inference.
