Skip to content

Commit 543ee1e

Browse files
authored
[LDMTextToImagePipeline] make text model generic (huggingface#162)
make text model generic
1 parent 75b6c16 commit 543ee1e

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ def __call__(
4545
# get unconditional embeddings for classifier free guidance
4646
if guidance_scale != 1.0:
4747
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
48-
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
48+
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
4949

5050
# get prompt text embeddings
5151
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
52-
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
52+
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
5353

5454
latents = torch.randn(
5555
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
@@ -618,5 +618,4 @@ def forward(
618618
output_hidden_states=output_hidden_states,
619619
return_dict=return_dict,
620620
)
621-
sequence_output = outputs[0]
622-
return sequence_output
621+
return outputs

0 commit comments

Comments
 (0)