Commit 288632a

[Training utils] add kohya conversion dict. (huggingface#7435)
* add kohya conversion dict.
* update readme
* typo
* add filename
1 parent 5ce79cb commit 288632a

File tree: 2 files changed (+20, -3 lines changed)


examples/dreambooth/README_sdxl.md

Lines changed: 7 additions & 3 deletions
@@ -259,13 +259,17 @@ The authors found that by using DoRA, both the learning capacity and training st
 > This is also aligned with some of the quantitative analysis shown in the paper.
 
 **Usage**
-1. To use DoRA you need to install `peft` from main:
+1. To use DoRA you need to upgrade the installation of `peft`:
 ```bash
-pip install git+https://github.com/huggingface/peft.git
+pip install -U peft
 ```
 2. Enable DoRA training by adding this flag
 ```bash
 --use_dora
 ```
 **Inference**
-The inference is the same as if you train a regular LoRA 🤗
+The inference is the same as if you train a regular LoRA 🤗
+
+## Format compatibility
+
+You can pass `--output_kohya_format` to additionally generate a state dictionary which should be compatible with other platforms and tools such as Automatic 1111, Comfy, Kohya, etc. The `output_dir` will contain a file named "pytorch_lora_weights_kohya.safetensors".
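
For reference, the same conversion can also be run by hand on a LoRA file that was saved without the flag. The snippet below is a minimal sketch, not part of the commit: it assumes the default `pytorch_lora_weights.safetensors` filename written by the training script and uses the conversion helpers this commit imports from `diffusers.utils`.

```python
# Minimal sketch: convert a diffusers-format LoRA file to the Kohya layout.
# Assumes the default filenames used by train_dreambooth_lora_sdxl.py.
from safetensors.torch import load_file, save_file

from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya

output_dir = "lora-dreambooth-model"  # hypothetical output directory

# Load the LoRA weights saved by the training script.
lora_state_dict = load_file(f"{output_dir}/pytorch_lora_weights.safetensors")

# Re-key into the PEFT naming scheme, then into the Kohya scheme.
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)

# Save the Kohya-compatible file next to the original weights.
save_file(kohya_state_dict, f"{output_dir}/pytorch_lora_weights_kohya.safetensors")
```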

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,7 @@
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
 from torch.utils.data import Dataset
 from torchvision import transforms
 from torchvision.transforms.functional import crop
@@ -62,7 +63,9 @@
 from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
+    convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
+    convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
@@ -396,6 +399,11 @@ def parse_args(input_args=None):
         default="lora-dreambooth-model",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--output_kohya_format",
+        action="store_true",
+        help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
+    )
     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
@@ -1890,6 +1898,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
         )
+        if args.output_kohya_format:
+            lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+            peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+            kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+            save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
 
         # Final inference
         # Load previous pipeline
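
After a run with the flag enabled, the extra file can be sanity-checked before importing it into Automatic 1111, Comfy, or Kohya. A minimal sketch, assuming the default `lora-dreambooth-model` output directory and the filename written by the script:

```python
from safetensors.torch import load_file

# Hypothetical path: default --output_dir plus the Kohya-format filename.
kohya_state_dict = load_file("lora-dreambooth-model/pytorch_lora_weights_kohya.safetensors")

# Quick look at the tensor count and a few key names.
print(f"{len(kohya_state_dict)} tensors")
for key in list(kohya_state_dict)[:5]:
    print(key, tuple(kohya_state_dict[key].shape))
```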
