Skip to content

Commit bfe37f3

Browse files
pcuencapatil-suraj
andauthored
Reproducible images by supplying latents to pipeline (huggingface#247)
* Accept latents as input for StableDiffusionPipeline. * Notebook to demonstrate reusable seeds (latents). * More accurate type annotation Co-authored-by: Suraj Patil <[email protected]> * Review comments: move to device, raise instead of assert. * Actually commit the test notebook. I had mistakenly pushed an empty file instead. * Adapt notebook to Colab. * Update examples readme. * Move notebook to personal repo. Co-authored-by: Suraj Patil <[email protected]>
1 parent 89793a9 commit bfe37f3

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

examples/inference/readme.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,8 @@ with autocast("cuda"):
4747

4848
images[0].save("fantasy_landscape.png")
4949
```
50-
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
50+
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
51+
52+
## Tweak prompts reusing seeds and latents
53+
54+
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __call__(
4646
guidance_scale: Optional[float] = 7.5,
4747
eta: Optional[float] = 0.0,
4848
generator: Optional[torch.Generator] = None,
49+
latents: Optional[torch.FloatTensor] = None,
4950
output_type: Optional[str] = "pil",
5051
**kwargs,
5152
):
@@ -98,12 +99,18 @@ def __call__(
9899
# to avoid doing two forward passes
99100
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
100101

101-
# get the intial random noise
102-
latents = torch.randn(
103-
(batch_size, self.unet.in_channels, height // 8, width // 8),
104-
generator=generator,
105-
device=self.device,
106-
)
102+
# get the initial random noise unless the user supplied it
103+
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
104+
if latents is None:
105+
latents = torch.randn(
106+
latents_shape,
107+
generator=generator,
108+
device=self.device,
109+
)
110+
else:
111+
if latents.shape != latents_shape:
112+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
113+
latents = latents.to(self.device)
107114

108115
# set timesteps
109116
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())

0 commit comments

Comments
 (0)