Description
Describe the bug
When train_dreambooth_lora_flux.py is launched with accelerate and train_batch_size is greater than 1, it spends 40-50 seconds loading and then crashes with the error "RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1536 and 768x3072)".
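The shapes in the error message give a hint about what goes wrong: 1536 is exactly 2 x 768, and train_batch_size is 2. One possible reading, which is an assumption and has not been verified against the script, is that the pooled text embeddings of the two samples are reshaped as if the batch size were 1, so the linear layer expecting 768 input features receives a single row of 1536. The pure-Python sketch below only reproduces that shape arithmetic; `view_as_batch` and `linear_out_shape` are hypothetical helpers written for illustration, not diffusers or PyTorch functions.

```python
def view_as_batch(num_samples, hidden_size, assumed_batch_size):
    """Shape obtained by reshaping a (num_samples, hidden_size) tensor
    to (assumed_batch_size, -1)."""
    total = num_samples * hidden_size
    if total % assumed_batch_size != 0:
        raise ValueError("element count is not divisible by the batch size")
    return (assumed_batch_size, total // assumed_batch_size)


def linear_out_shape(input_shape, in_features, out_features):
    """Output shape of a linear layer applied to a 2-D input,
    or ValueError if the feature dimension does not match."""
    if input_shape[-1] != in_features:
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({input_shape[0]}x{input_shape[1]} and {in_features}x{out_features})"
        )
    return input_shape[:-1] + (out_features,)


POOLED_DIM = 768      # input features, taken from "768x3072" in the error
OUT_DIM = 3072        # output features, taken from "768x3072" in the error
TRAIN_BATCH_SIZE = 2  # --train_batch_size from the reproduction command

# Reshaping with the real batch size keeps one row per sample.
ok = view_as_batch(TRAIN_BATCH_SIZE, POOLED_DIM, assumed_batch_size=TRAIN_BATCH_SIZE)
print(ok, "->", linear_out_shape(ok, POOLED_DIM, OUT_DIM))

# Assumed failure mode: reshaping as if there were a single sample
# folds both embeddings into one row of 1536 features.
bad = view_as_batch(TRAIN_BATCH_SIZE, POOLED_DIM, assumed_batch_size=1)
try:
    linear_out_shape(bad, POOLED_DIM, OUT_DIM)
except ValueError as err:
    print(bad, "->", err)
```

If this reading is right, the place to look would be wherever the pooled prompt embeddings are reshaped before reaching the transformer, checking that the batch size used there matches the number of tokenized prompts in the batch.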
File "/workspace/./train_dreambooth_lora_flux.py", line 1935, in <module>
  main(args)
File "/workspace/./train_dreambooth_lora_flux.py", line 1729, in main
  model_pred = transformer(
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py", line 819, in forward
  return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
  return func(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/diffusers/models/transformers/transformer_flux.py", line 493, in forward
  else self.time_text_embed(timestep, guidance, pooled_projections)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/diffusers/models/embeddings.py", line 1629, in forward
  pooled_projections = self.text_embedder(pooled_projection)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/diffusers/models/embeddings.py", line 2222, in forward
  hidden_states = self.linear_1(caption)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/linear.py", line 117, in forward
  return F.linear(input, self.weight, self.bias)
Reproduction
FLUX.1 Dev was used as the base model.
accelerate launch --config_file /workspace/accelerate.yaml ./train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path /workspace/model/realflux1 \
  --output_dir /workspace/output_model/577cda6e-628d-4dc0-9ab4-cbcf448b730e-e1 \
  --resolution 1024 \
  --learning_rate 5e-4 \
  --mixed_precision bf16 \
  --instance_data_dir /workspace/job_files/577cda6e-628d-4dc0-9ab4-cbcf448b730e-e1/clean_data \
  --instance_prompt "WIXBSAHA black car" \
  --lr_warmup_steps 0 \
  --gradient_accumulation_steps 1 \
  --lr_scheduler cosine \
  --train_batch_size 2 \
  --max_train_steps 500 \
  --use_8bit_adam \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --num_train_epochs 10 \
  --checkpoints_total_limit 1 \
  --train_text_encoder \
  --rank 16 \
  --guidance_scale 1 \
  --cache_latents
Logs
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'use_beta_sigmas', 'shift_terminal', 'invert_sigmas', 'use_exponential_sigmas', 'use_karras_sigmas'} was not found in config. Values will be initialized to default values.
{'out_channels', 'axes_dims_rope'} was not found in config. Values will be initialized to default values.
System Info
RunPod
NVIDIA A100 80GB PCIe CUDA Version: 12.4
accelerate 0.33.0
bitsandbytes 0.43.1
datasets 2.19.2
diffusers 0.33.0.dev0
Jinja2 3.1.6
peft 0.14.0
safetensors 0.5.3
sentencepiece 0.2.0
tensorboard 2.19.0
tokenizers 0.21.0
torch 2.4.1+cu124
torchaudio 2.4.1+cu124
torchvision 0.19.1+cu124
transformers 4.49.0