
Commit c75e12d

Port over tweaked pipeline/scheduler
1 parent ec5449f commit c75e12d

2 files changed: +49 -50 lines

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 43 additions & 50 deletions
@@ -25,7 +25,6 @@
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
-    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -509,25 +508,13 @@ def enable_vae_slicing(self):
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.enable_slicing()

     def disable_vae_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
         computing decoding in one step.
         """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.disable_slicing()

     def enable_vae_tiling(self):
@@ -536,25 +523,13 @@ def enable_vae_tiling(self):
         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
         processing larger images.
         """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.enable_tiling()

     def disable_vae_tiling(self):
         r"""
         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
         computing decoding in one step.
         """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.disable_tiling()

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
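With the deprecation shims gone, these wrappers now delegate straight to the VAE, so either spelling works. A minimal usage sketch; the checkpoint id is illustrative, not taken from this commit:

import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)  # illustrative id

# The pipeline wrappers above simply forward to the VAE:
pipe.enable_vae_slicing()   # equivalent to pipe.vae.enable_slicing()
pipe.enable_vae_tiling()    # equivalent to pipe.vae.enable_tiling()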
@@ -688,11 +663,11 @@ def __call__(
                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                 will be used.
             guidance_scale (`float`, *optional*, defaults to 3.5):
-                Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
+                Embedded guidance scale is enabled by setting `guidance_scale > 1`. A higher `guidance_scale`
+                encourages the model to generate images more closely aligned with `prompt`, at the expense of lower
+                image quality.
+
+                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale > 1`. Refer
+                to the [paper](https://huggingface.co/papers/2210.03142) to learn more.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -701,7 +676,7 @@ def __call__(
             latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will be generated by sampling using the supplied random `generator`.
+                tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
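An end-to-end sketch of the parameters documented above; the checkpoint id and prompt are illustrative:

import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16).to("cuda")  # illustrative id
image = pipe(
    prompt="a watercolor fox in a snowy forest",
    guidance_scale=3.5,  # values > 1 enable guidance, per the docstring above
    generator=torch.Generator("cuda").manual_seed(0),  # fixes the sampled latents for reproducibility
).images[0]
image.save("chroma_sample.png")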
@@ -904,31 +879,49 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    attention_mask=attention_mask,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
                 if self.do_classifier_free_guidance:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                    # Batch positive and negative prompts for a single transformer call
+                    batched_latents = torch.cat([latents, latents], dim=0)
+                    batched_timestep = torch.cat([timestep, timestep], dim=0)
+                    batched_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+
+                    # Handle attention masks
+                    if attention_mask is not None and negative_attention_mask is not None:
+                        batched_attention_mask = torch.cat([attention_mask, negative_attention_mask], dim=0)
+                    else:
+                        batched_attention_mask = None
+
+                    # Single transformer call with batched inputs. `txt_ids` and
+                    # `img_ids` are position ids shared by every sample in the batch,
+                    # so they are passed through unduplicated.
+                    batched_noise_pred = self.transformer(
+                        hidden_states=batched_latents,
+                        timestep=batched_timestep / 1000,
+                        encoder_hidden_states=batched_encoder_hidden_states,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        attention_mask=batched_attention_mask,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+
+                    # Split the batched result back into positive and negative predictions
+                    noise_pred, neg_noise_pred = batched_noise_pred.chunk(2, dim=0)
+
+                    # Apply classifier-free guidance
+                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+                else:
+                    # No guidance, single forward pass
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_image_ids,
-                        attention_mask=negative_attention_mask,
+                        attention_mask=attention_mask,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
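The rewritten branch trades the second sequential forward pass for one call on a doubled batch. A self-contained sketch of the same batch-then-chunk pattern, with `model` standing in for the transformer (all names here are illustrative, not the pipeline's API):

import torch

def cfg_step(model, latents, timestep, cond_embeds, uncond_embeds, guidance_scale):
    # Stack the conditional and unconditional branches along the batch dimension...
    latents_in = torch.cat([latents, latents], dim=0)
    timestep_in = torch.cat([timestep, timestep], dim=0)
    embeds_in = torch.cat([cond_embeds, uncond_embeds], dim=0)

    # ...so a single forward pass covers both branches.
    pred = model(latents_in, timestep_in, embeds_in)

    # Undo the stacking and recombine with the usual CFG formula.
    noise_pred, neg_noise_pred = pred.chunk(2, dim=0)
    return neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)

The result matches two sequential calls; the gain is fewer kernel launches per step, paid for with roughly double the activation memory.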
@@ -971,4 +964,4 @@ def __call__(
         if not return_dict:
             return (image,)

-        return ChromaPipelineOutput(images=image)
+        return ChromaPipelineOutput(images=image)
\ No newline at end of file

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 6 additions & 0 deletions
@@ -104,6 +104,7 @@ def __init__(
         use_beta_sigmas: Optional[bool] = False,
         time_shift_type: str = "exponential",
         stochastic_sampling: bool = False,
+        custom_sigmas: Optional[List[float]] = None,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -132,6 +133,7 @@ def __init__(
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
+        self.custom_sigmas = custom_sigmas

     @property
     def shift(self):
@@ -343,6 +345,10 @@ def set_timesteps(
         else:
             sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

+        if self.custom_sigmas is not None:
+            sigmas = torch.tensor(self.custom_sigmas, device=sigmas.device, dtype=torch.float32)
+            timesteps = sigmas[:-1] * self.config.num_train_timesteps
+
         self.timesteps = timesteps
         self.sigmas = sigmas
         self._step_index = None
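A usage sketch for the new hook. `custom_sigmas` is fork-specific (not upstream diffusers) and is assumed here to be a descending schedule ending at 0.0, since `set_timesteps` derives timesteps from all but the last entry:

from diffusers import FlowMatchEulerDiscreteScheduler

custom = [1.0, 0.75, 0.5, 0.25, 0.0]  # assumed: monotonically decreasing, ending at 0.0
scheduler = FlowMatchEulerDiscreteScheduler(custom_sigmas=custom)

# The override in set_timesteps replaces whatever the default schedule computes.
scheduler.set_timesteps(num_inference_steps=len(custom) - 1)
print(scheduler.sigmas)     # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])
print(scheduler.timesteps)  # sigmas[:-1] * num_train_timesteps -> tensor([1000., 750., 500., 250.])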
