Commit 6b33c11

Authored by sunhs, Aisuko, pcuenca, patrickvonplaten, and mauricio-repetto
add noise_sampler_seed to StableDiffusionKDiffusionPipeline.__call__ (huggingface#3911)
* add noise_sampler to StableDiffusionKDiffusionPipeline
* fix/docs: Fix the broken doc links (huggingface#3897)
* fix/docs: Fix the broken doc links
  Signed-off-by: GitHub <[email protected]>
* Update docs/source/en/using-diffusers/write_own_pipeline.mdx
  Co-authored-by: Pedro Cuenca <[email protected]>
---------
Signed-off-by: GitHub <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
* Add video img2img (huggingface#3900)
* Add image to image video
* Improve
* better naming
* make fix copies
* add docs
* finish tests
* trigger tests
* make style
* correct
* finish
* Fix more
* make style
* finish
* fix/doc-code: Updating to the latest version parameters (huggingface#3924)
  fix/doc-code: update to use the new parameter
  Signed-off-by: GitHub <[email protected]>
* fix/doc: no import torch issue (huggingface#3923)
  Ffix/doc: no import torch issue
  Signed-off-by: GitHub <[email protected]>
* Correct controlnet out of list error (huggingface#3928)
* Correct controlnet out of list error
* Apply suggestions from code review
* correct tests
* correct tests
* fix
* test all
* Apply suggestions from code review
* test all
* test all
* Apply suggestions from code review
* Apply suggestions from code review
* fix more tests
* Fix more
* Apply suggestions from code review
* finish
* Apply suggestions from code review
* Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
* finish
* Adding better way to define multiple concepts and also validation capabilities. (huggingface#3807)
* - Added validation parameters
  - Changed some parameter descriptions to better explain their use.
  - Fixed a few typos.
  - Added concept_list parameter for better management of multiple subjects
  - changed logic for image validation
* - Fixed bad logic for class data root directories
* Defaulting validation_steps to None for an easier logic
* Fixed multiple validation prompts
* Fixed bug on validation negative prompt
* Changed validation logic for tracker.
* Added uuid for validation image labeling
* Fix error when comparing validation prompts and validation negative prompts
* Improved error message when negative prompts for validation are more than the number of prompts
* - Changed image tracking number from epoch to global_step
  - Added Typing for functions
* Added some validations more when using concept_list parameter and the regular ones.
* Fixed error message
* Added more validations for validation parameters
* Improved messaging for errors
* Fixed validation error for parameters with default values
* - Added train step to image name for validation
  - reformatted code
* - Added train step to image's name for validation
  - reformatted code
* Updated README.md file.
* reverted back original script of train_dreambooth.py
* reverted back original script of train_dreambooth.py
* left one blank line at the eof
* reverted back setup.py
* reverted back setup.py
* added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter.
* Ran black formatter.
* fixed a few strings
* fixed import sort with isort and removed fstrings without placeholder
* fixed import order with ruff (since with isort wasn't ok)
---------
Co-authored-by: Patrick von Platen <[email protected]>
* [ldm3d] Update code to be functional with the new checkpoints (huggingface#3875)
* fixed typo
* updated doc to be consistent in naming
* make style/quality
* preprocessing for 4 channels and not 6
* make style
* test for 4c
* make style/quality
* fixed test on cpu
---------
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
* Improve memory text to video (huggingface#3930)
* Improve memory text to video
* Apply suggestions from code review
* add test
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
* finish test setup
---------
Co-authored-by: Pedro Cuenca <[email protected]>
* revert automatic chunking (huggingface#3934)
* revert automatic chunking
* Apply suggestions from code review
* revert automatic chunking
* avoid upcasting by assigning dtype to noise tensor (huggingface#3713)
* avoid upcasting by assigning dtype to noise tensor
* make style
* Update train_unconditional.py
* Update train_unconditional.py
* make style
* add unit test for pickle
* revert change
---------
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
* Fix failing np tests (huggingface#3942)
* Fix failing np tests
* Apply suggestions from code review
* Update tests/pipelines/test_pipelines_common.py
* Add `timestep_spacing` and `steps_offset` to schedulers (huggingface#3947)
* Add timestep_spacing to DDPM, LMSDiscrete, PNDM.
* Remove spurious line.
* More easy schedulers.
* Add `linspace` to DDIM
* Noise sigma for `trailing`.
* Add timestep_spacing to DEISMultistepScheduler. Not sure the range is the way it was intended.
* Fix: remove line used to debug.
* Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC
* Fix: convert to numpy.
* Use sched. defaults when instantiating from_config
  For params not present in the original configuration. This makes it possible to switch pipeline schedulers even if they use different timestep_spacing (or any other param).
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* Missing args in DPMSolverMultistep
* Test: default args not in config
* Style
* Fix scheduler name in test
* Remove duplicated entries
* Add test for solver_type
  This test currently fails in main. When switching from DEIS to UniPC, solver_type is "logrho" (the default value from DEIS), which gets translated to "bh1" by UniPC. This is different to the default value for UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171
* UniPC: use same default for solver_type
  Fixes a bug when switching from UniPC from another scheduler (i.e., DEIS) that uses a different solver type. The solver is now the same as if we had instantiated the scheduler directly.
* do not save use default values
* fix more
* fix all
* fix schedulers
* fix more
* finish for real
* finish for real
* flaky tests
* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py
* Default steps_offset to 0.
* Add missing docstrings
* Apply suggestions from code review
---------
Co-authored-by: Patrick von Platen <[email protected]>
* Add Consistency Models Pipeline (huggingface#3492)
* initial commit
* Improve consistency models sampling implementation.
* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
* Add Unet blocks for consistency models
* Add conversion script for Unet
* Fix bug in new unet blocks
* Fix attention weight loading
* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.
* make style
* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.
* Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.
* make style
* Change num_class_embeds to 1000 to better match the original consistency models implementation.
* Add support for distillation in pipeline_consistency_models.py.
* Improve consistency model tests:
  - Get small testing checkpoints from hub
  - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
  - Add onestep, multistep tests for distillation and distillation + class conditional
  - Add expected image slices for onestep tests
* make style
* Improve ConsistencyModelPipeline:
  - Add initial support for class-conditional generation
  - Fix initial sigma for onestep generation
  - Fix some sigma shape issues
* make style
* Improve ConsistencyModelPipeline:
  - add latents __call__ argument and prepare_latents method
  - add check_inputs method
  - add initial docstrings for ConsistencyModelPipeline.__call__
* make style
* Fix bug when randomly generating class labels for class-conditional generation.
* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.
* Remove some unused code and make style.
* Fix small bug in CMStochasticIterativeScheduler.
* Add expected slices for multistep sampling tests and make them pass.
* Work on consistency model fast tests:
  - in pipeline, call self.scheduler.scale_model_input before denoising
  - get expected slices for Euler and Heun scheduler tests
  - make Euler test pass
  - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
  - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas
* make style
* Refactor conversion script to make it easier to add more model architectures to convert in the future.
* Work on ConsistencyModelPipeline tests:
  - Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
  - Add slow tests for onestep and multistep sampling and make them pass
  - Refactor fast tests
  - Refactor ConsistencyModelPipeline.__init__
* make style
* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.
* Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler.
* Make fast tests from PipelineTesterMixin pass.
* make style
* Refactor consistency models pipeline and scheduler:
  - Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
  - Move sigma manipulation, input scaling, denoising from pipeline to scheduler
  - Make corresponding changes to tests and ensure they pass
* make style
* Add docstrings and further refactor pipeline and scheduler.
* make style
* Add initial version of the consistency models documentation.
* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.
* make style
* Convert current slow tests to use fp16 and flash attention.
* make style
* Add slow tests for normal attention on cuda device.
* make style
* Fix attention weights loading
* Update consistency model fast tests for new test checkpoints with attention fix.
* make style
* apply suggestions
* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).
* Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers.
* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).
* make style
* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.
* apply suggestions from review
* make style
* fix attention naming
* Add tests for CMStochasticIterativeScheduler.
* make style
* Make CMStochasticIterativeScheduler tests pass.
* make style
* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.
* make style
* rename some models
* Improve API
* rename some models
* Remove duplicated block
* Add docstring and make torch compile work
* More fixes
* Fixes
* Apply suggestions from code review
* Apply suggestions from code review
* add more docstring
* update consistency conversion script
---------
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
* add test case for StableDiffusionKDiffusionPipeline noise_sampler
---------
Signed-off-by: GitHub <[email protected]>
Co-authored-by: Aisuko <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Andrés Mauricio Repetto Ferrero <[email protected]>
Co-authored-by: estelleafl <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: ayushmangal <[email protected]>
Co-authored-by: Ayush Mangal <[email protected]>
1 parent 5729829 commit 6b33c11

File tree
2 files changed: +43 additions, -2 deletions

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 13 additions & 2 deletions
@@ -13,12 +13,13 @@
 # limitations under the License.

 import importlib
+import inspect
 import warnings
 from typing import Callable, List, Optional, Union

 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
-from k_diffusion.sampling import get_sigmas_karras
+from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras

 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
@@ -464,6 +465,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         use_karras_sigmas: Optional[bool] = False,
+        noise_sampler_seed: Optional[int] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -524,6 +526,8 @@ def __call__(
                 Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
                 `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
                 Karras`.
+            noise_sampler_seed (`int`, *optional*, defaults to `None`):
+                The random seed to use for the noise sampler. If `None`, a random seed will be generated.
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -608,7 +612,14 @@ def model_fn(x, t):
             return noise_pred

         # 8. Run k-diffusion solver
-        latents = self.sampler(model_fn, latents, sigmas)
+        sampler_kwargs = {}
+
+        if "noise_sampler" in inspect.signature(self.sampler).parameters:
+            min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
+            noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
+            sampler_kwargs["noise_sampler"] = noise_sampler
+
+        latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
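For context, a minimal sketch of how a caller might use the new argument end to end, assuming the `k-diffusion` package is installed and a CUDA device is available. The checkpoint id and sampler name mirror the new test below; the `torch_dtype` and device placement are illustrative choices, not part of this change.

```python
# Minimal usage sketch (assumes `k-diffusion` is installed and CUDA is available).
import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# `sample_dpmpp_sde` is one of the k-diffusion samplers whose signature accepts
# a `noise_sampler`, so the seed below actually takes effect.
pipe.set_scheduler("sample_dpmpp_sde")

image = pipe(
    "A painting of a squirrel eating a burger",
    generator=torch.manual_seed(0),  # seeds the initial latents
    noise_sampler_seed=0,            # seeds the BrownianTreeNoiseSampler added in this commit
    guidance_scale=9.0,
    num_inference_steps=20,
).images[0]
```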

tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py

Lines changed: 30 additions & 0 deletions
@@ -104,3 +104,33 @@ def test_stable_diffusion_karras_sigmas(self):
         )

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_noise_sampler_seed(self):
+        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        sd_pipe.set_scheduler("sample_dpmpp_sde")
+
+        prompt = "A painting of a squirrel eating a burger"
+        seed = 0
+        images1 = sd_pipe(
+            [prompt],
+            generator=torch.manual_seed(seed),
+            noise_sampler_seed=seed,
+            guidance_scale=9.0,
+            num_inference_steps=20,
+            output_type="np",
+        ).images
+        images2 = sd_pipe(
+            [prompt],
+            generator=torch.manual_seed(seed),
+            noise_sampler_seed=seed,
+            guidance_scale=9.0,
+            num_inference_steps=20,
+            output_type="np",
+        ).images
+
+        assert images1.shape == (1, 512, 512, 3)
+        assert images2.shape == (1, 512, 512, 3)
+        assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
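The test pins both sources of randomness: `generator` seeds the initial latents, while `noise_sampler_seed` seeds the Brownian-tree noise drawn during SDE sampling, so the two runs should produce nearly identical images. As a side note, here is a standalone sketch of the introspection pattern the pipeline change relies on to forward `noise_sampler` only to samplers that accept it; the sampler functions below are hypothetical stand-ins, not k-diffusion APIs.

```python
# Sketch of conditional kwarg forwarding via signature introspection.
import inspect


def sample_with_noise(model_fn, x, sigmas, noise_sampler=None):
    ...  # stand-in for an SDE-style sampler that draws noise from `noise_sampler`


def sample_deterministic(model_fn, x, sigmas):
    ...  # stand-in for an ODE-style sampler with no extra noise source


def run(sampler, model_fn, x, sigmas, noise_sampler=None):
    kwargs = {}
    # Only pass `noise_sampler` when the selected sampler declares it, so
    # samplers without that parameter keep working unchanged.
    if noise_sampler is not None and "noise_sampler" in inspect.signature(sampler).parameters:
        kwargs["noise_sampler"] = noise_sampler
    return sampler(model_fn, x, sigmas, **kwargs)
```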
