Skip to content

Commit c82f7ba

Browse files
[SDXL Flax] fix SDXL flax init (huggingface#5187)
* fix SDXL flax init * finish * Fix
1 parent d9e7857 commit c82f7ba

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -134,8 +134,18 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
134134

135135
added_cond_kwargs = None
136136
if self.addition_embed_type == "text_time":
137-
# TODO: how to get this from the config? It's no longer cross_attention_dim
138-
text_embeds_dim = 1280
137+
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
138+
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
139+
is_refiner = (
140+
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
141+
== self.config.projection_class_embeddings_input_dim
142+
)
143+
num_micro_conditions = 5 if is_refiner else 6
144+
145+
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
146+
num_micro_conditions * self.config.addition_time_embed_dim
147+
)
148+
139149
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
140150
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
141151
added_cond_kwargs = {

src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -215,14 +215,15 @@ def _generate(
215215
else:
216216
if latents.shape != latents_shape:
217217
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
218-
# scale the initial noise by the standard deviation required by the scheduler
219-
latents = latents * params["scheduler"].init_noise_sigma
220218

221219
# Prepare scheduler state
222220
scheduler_state = self.scheduler.set_timesteps(
223221
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
224222
)
225223

224+
# scale the initial noise by the standard deviation required by the scheduler
225+
latents = latents * scheduler_state.init_noise_sigma
226+
226227
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
227228

228229
# Denoising loop

0 commit comments

Comments
 (0)