Skip to content

Commit c91272d

Browse files
authored
fix indexing issue in sd reference pipeline (huggingface#4531)
1 parent f0725c5 commit c91272d

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

examples/community/stable_diffusion_reference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
153153
)
154154
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
155155

156-
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
157-
158156
# aligning device to prevent device errors when concating it with the latent model input
159157
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
160158
return ref_image_latents
@@ -733,6 +731,7 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
733731
1,
734732
),
735733
)
734+
ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
736735
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
737736

738737
MODE = "write"

0 commit comments

Comments
 (0)