|
41 | 41 | from peft.utils import get_peft_model_state_dict |
42 | 42 | from PIL import Image |
43 | 43 | from PIL.ImageOps import exif_transpose |
| 44 | +from safetensors.torch import load_file, save_file |
44 | 45 | from torch.utils.data import Dataset |
45 | 46 | from torchvision import transforms |
46 | 47 | from torchvision.transforms.functional import crop |
|
62 | 63 | from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr |
63 | 64 | from diffusers.utils import ( |
64 | 65 | check_min_version, |
| 66 | + convert_all_state_dict_to_peft, |
65 | 67 | convert_state_dict_to_diffusers, |
| 68 | + convert_state_dict_to_kohya, |
66 | 69 | convert_unet_state_dict_to_peft, |
67 | 70 | is_wandb_available, |
68 | 71 | ) |
@@ -396,6 +399,11 @@ def parse_args(input_args=None): |
396 | 399 | default="lora-dreambooth-model", |
397 | 400 | help="The output directory where the model predictions and checkpoints will be written.", |
398 | 401 | ) |
| 402 | + parser.add_argument( |
| 403 | + "--output_kohya_format", |
| 404 | + action="store_true", |
| 405 | + help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.", |
| 406 | + ) |
399 | 407 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") |
400 | 408 | parser.add_argument( |
401 | 409 | "--resolution", |
@@ -1890,6 +1898,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): |
1890 | 1898 | text_encoder_lora_layers=text_encoder_lora_layers, |
1891 | 1899 | text_encoder_2_lora_layers=text_encoder_2_lora_layers, |
1892 | 1900 | ) |
| 1901 | + if args.output_kohya_format: |
| 1902 | + lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") |
| 1903 | + peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) |
| 1904 | + kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) |
| 1905 | + save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors") |
1893 | 1906 |
|
1894 | 1907 | # Final inference |
1895 | 1908 | # Load previous pipeline |
|
0 commit comments