
train_dreambooth_lora_flux crash on batch size greater than 1 #10994

Closed
@PluginBOXone

Description


Describe the bug

As soon as I launch train_dreambooth_lora_flux.py (via accelerate) with train_batch_size greater than 1, the script takes 40 to 50 seconds to load and then crashes with "RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1536 and 768x3072)".

  File "/workspace/./train_dreambooth_lora_flux.py", line 1935, in <module>
    main(args)
  File "/workspace/./train_dreambooth_lora_flux.py", line 1729, in main
    model_pred = transformer(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/diffusers/models/transformers/transformer_flux.py", line 493, in forward
    else self.time_text_embed(timestep, guidance, pooled_projections)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/diffusers/models/embeddings.py", line 1629, in forward
    pooled_projections = self.text_embedder(pooled_projection)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/diffusers/models/embeddings.py", line 2222, in forward
    hidden_states = self.linear_1(caption)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
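The numbers in the error are consistent with one reading, offered here as an unconfirmed assumption rather than a diagnosis: the text embedder's linear_1 expects 768 input features (the CLIP pooled dimension), and 1536 = 2 x 768, which is exactly what a (2, 768) pooled embedding becomes if it is reshaped with a batch size of 1, for example a batch size computed from a single instance prompt string instead of from the batch of token ids. The pure-Python sketch below reproduces only that shape arithmetic. `view_shape`, `linear_output_shape`, and the constants are hypothetical helpers written for illustration; they are not code from the training script, diffusers, or PyTorch.

```python
from math import prod


def view_shape(shape, new_shape):
    """Shape arithmetic of torch.Tensor.view with at most one -1 (no tensors involved)."""
    if list(new_shape).count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    total = prod(shape)
    known = prod(d for d in new_shape if d != -1)
    if total % known != 0:
        raise ValueError(f"shape {tuple(new_shape)} is invalid for input of size {total}")
    # Replace the single -1 with whatever size makes the element count match.
    return tuple(total // known if d == -1 else d for d in new_shape)


def linear_output_shape(input_shape, in_features, out_features):
    """Mimic the 2-D shape check F.linear does against an (out_features, in_features) weight."""
    rows, cols = input_shape
    if cols != in_features:
        # PyTorch reports the transposed weight as mat2, hence in_features x out_features.
        raise RuntimeError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    return (rows, out_features)


# Assumed values (hypothetical): CLIP pooled embeddings for train_batch_size=2,
# and a batch size taken from one prompt string (1) instead of from the token ids (2).
POOLED = (2, 768)
PROMPT_BATCH = 1
TOKEN_BATCH = POOLED[0]

# Reshaping with the prompt-derived batch size folds both rows into one.
bad = view_shape(POOLED, (PROMPT_BATCH, -1))
print(bad)  # (1, 1536)
try:
    linear_output_shape(bad, in_features=768, out_features=3072)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (1x1536 and 768x3072)

# Reshaping with the token-id batch size keeps one row per sample.
good = view_shape(POOLED, (TOKEN_BATCH, -1))
print(linear_output_shape(good, in_features=768, out_features=3072))  # (2, 3072)
```

If this reading is right, running with --train_batch_size 1 and --gradient_accumulation_steps 2 would avoid the bad reshape while keeping an effective batch of 2, and a proper fix would derive the batch size from the token ids when encoding the pooled prompt. Both are untested suggestions.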

Reproduction

Model used: FLUX.1 Dev.

accelerate launch --config_file /workspace/accelerate.yaml ./train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path /workspace/model/realflux1 \
  --output_dir /workspace/output_model/577cda6e-628d-4dc0-9ab4-cbcf448b730e-e1 \
  --resolution 1024 \
  --learning_rate 5e-4 \
  --mixed_precision bf16 \
  --instance_data_dir /workspace/job_files/577cda6e-628d-4dc0-9ab4-cbcf448b730e-e1/clean_data \
  --instance_prompt "WIXBSAHA black car" \
  --lr_warmup_steps 0 \
  --gradient_accumulation_steps 1 \
  --lr_scheduler cosine \
  --train_batch_size 2 \
  --max_train_steps 500 \
  --use_8bit_adam \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --num_train_epochs 10 \
  --checkpoints_total_limit 1 \
  --train_text_encoder \
  --rank 16 \
  --guidance_scale 1 \
  --cache_latents

Logs

Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'use_beta_sigmas', 'shift_terminal', 'invert_sigmas', 'use_exponential_sigmas', 'use_karras_sigmas'} was not found in config. Values will be initialized to default values.
{'out_channels', 'axes_dims_rope'} was not found in config. Values will be initialized to default values.

System Info

Platform: RunPod
GPU: NVIDIA A100 80GB PCIe (CUDA Version: 12.4)
accelerate 0.33.0
bitsandbytes 0.43.1
datasets 2.19.2
diffusers 0.33.0.dev0
Jinja2 3.1.6
peft 0.14.0
safetensors 0.5.3
sentencepiece 0.2.0
tensorboard 2.19.0
tokenizers 0.21.0
torch 2.4.1+cu124
torchaudio 2.4.1+cu124
torchvision 0.19.1+cu124
transformers 4.49.0

Who can help?

@sayakpaul

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), stale (Issues that haven't received updates)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
