Skip to content

Commit cac7ada

Browse files
[Flax SDXL] fix zero out sdxl (huggingface#5203)
1 parent a584d42 commit cac7ada

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,10 @@ def _generate(
188188
# Get unconditional embeddings
189189
batch_size = prompt_embeds.shape[0]
190190
if neg_prompt_ids is None:
191-
neg_prompt_ids = self.prepare_inputs([""] * batch_size)
192-
193-
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
191+
neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
192+
negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
193+
else:
194+
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
194195

195196
add_time_ids = self._get_add_time_ids(
196197
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype

0 commit comments

Comments
 (0)