Skip to content

Commit 2d1f218

Browse files
Aiden-FrostRahul Ramansayakpaul
authored
example: Train Instruct pix2 pix with lora implementation (huggingface#6469)
* base template file - train_instruct_pix2pix.py * additional import and parser argument requried for lora * finetune only instructpix2pix model -- no need to include these layers * inject lora layers * freeze unet model -- only lora layers are trained * training modifications to train only lora parameters * store only lora parameters * move train script to research project * run quality and style code checks * move train script to a new folder * add README * update README * update references in README --------- Co-authored-by: Rahul Raman <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 3be7c96 commit 2d1f218

File tree

2 files changed

+1125
-0
lines changed

2 files changed

+1125
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# InstructPix2Pix text-to-edit-image fine-tuning
2+
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
3+
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
4+
5+
## Training script example
6+
7+
```bash
8+
export MODEL_ID="timbrooks/instruct-pix2pix"
9+
export DATASET_ID="instruction-tuning-sd/cartoonization"
10+
export OUTPUT_DIR="instructPix2Pix-cartoonization"
11+
12+
accelerate launch finetune_instruct_pix2pix.py \
13+
--pretrained_model_name_or_path=$MODEL_ID \
14+
--dataset_name=$DATASET_ID \
15+
--enable_xformers_memory_efficient_attention \
16+
--resolution=256 --random_flip \
17+
--train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
18+
--max_train_steps=15000 \
19+
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
20+
--learning_rate=5e-05 --lr_warmup_steps=0 \
21+
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
22+
--validation_prompt="Generate a cartoonized version of the natural image" \
23+
--seed=42 \
24+
--rank=4 \
25+
--output_dir=$OUTPUT_DIR \
26+
--report_to=wandb \
27+
--push_to_hub
28+
```
29+
30+
## Inference
31+
After training the model and the lora weight of the model is stored in the ```$OUTPUT_DIR```.
32+
33+
```bash
34+
# load the base model pipeline
35+
pipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
36+
37+
# Load LoRA weights from the provided path
38+
output_dir = "path/to/lora_weight_directory"
39+
pipe_lora.unet.load_attn_procs(output_dir)
40+
41+
input_image_path = "/path/to/input_image"
42+
input_image = Image.open(input_image_path)
43+
edited_images = pipe_lora(num_images_per_prompt=1, prompt=args.edit_prompt, image=input_image, num_inference_steps=1000).images
44+
edited_images[0].show()
45+
46+
```
47+
48+
## Results
49+
50+
Here is an example of using the script to train a instructpix2pix model.
51+
Trained on google colab T4 GPU
52+
53+
```bash
54+
MODEL_ID="timbrooks/instruct-pix2pix"
55+
DATASET_ID="instruction-tuning-sd/cartoonization"
56+
TRAIN_EPOCHS=100
57+
```
58+
59+
Below are few examples for given the input image, edit_prompt and the edited_image (output of the model)
60+
61+
<p align="center">
62+
<img src="https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/edited_image_results.png?raw=true" alt="instructpix2pix-inputs" width=600/>
63+
</p>
64+
65+
66+
Here are some rough statistics about the training model using this script
67+
68+
<p align="center">
69+
<img src="https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/results.png?raw=true" alt="instructpix2pix-inputs" width=600/>
70+
</p>
71+
72+
## References
73+
74+
* InstructPix2Pix - https://github.com/timothybrooks/instruct-pix2pix
75+
* Dataset and example training script - https://huggingface.co/blog/instruction-tuning-sd
76+
* For more information about the project - https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-

0 commit comments

Comments
 (0)