Skip to content

Commit dccf39f

Browse files
authored
Dreambooth lora flux bug 3dtensor to 2dtensor (huggingface#9653)
* fixed issue huggingface#9350, Tensor is deprecated * ran make style
1 parent 99d8747 commit dccf39f

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ def encode_prompt(
985985
text_input_ids_list=None,
986986
):
987987
prompt = [prompt] if isinstance(prompt, str) else prompt
988-
batch_size = len(prompt)
989988
dtype = text_encoders[0].dtype
990989

991990
pooled_prompt_embeds = _encode_prompt_with_clip(
@@ -1007,8 +1006,7 @@ def encode_prompt(
10071006
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
10081007
)
10091008

1010-
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
1011-
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
1009+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
10121010

10131011
return prompt_embeds, pooled_prompt_embeds, text_ids
10141012

0 commit comments

Comments
 (0)