Commit 2f54d7e

Adding dcm image support
1 parent 7e808e7 commit 2f54d7e

File tree

9 files changed, +3276 -0 lines changed

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
## Textual Inversion fine-tuning example

[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.

## Running on Colab

Colab for training
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)

Colab for inference
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
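If you would rather skip the interactive questions, the `accelerate` CLI can also write out a default configuration for the current machine:

```bash
accelerate config default
```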
### Cat toy example

First, let's log in so that we can upload the checkpoint to the Hub during training:

```bash
huggingface-cli login
```

Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example.

Let's first download it locally:

```py
from huggingface_hub import snapshot_download

local_dir = "./cat"
snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
```

This will be our training data.
Now we can launch the training using:

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model.___**

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="./cat"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --push_to_hub \
  --output_dir="textual_inversion_cat"
```

A full training run takes ~1 hour on one V100 GPU.

**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618),
only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
However, one can also add multiple embedding vectors for the placeholder token
to increase the number of fine-tunable parameters. This can help the model to learn
more complex details. To use multiple embedding vectors, set `--num_vectors`
to a number larger than one, *e.g.*:

```bash
--num_vectors 5
```

The saved textual inversion vectors will then be larger in size compared to the default case.
### Inference

Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.

```python
from diffusers import StableDiffusionPipeline
import torch

model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")
```
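Alternatively, if you kept only the learned embedding file (the `learned_embeds.safetensors` written to the output directory), it can be loaded into the base model instead of loading the full fine-tuned pipeline. A minimal sketch, with the base checkpoint and output path taken from the training command above:

```python
from diffusers import StableDiffusionPipeline
import torch

# Base model used in the training example above.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load the learned <cat-toy> embedding into the tokenizer/text encoder.
pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.safetensors", token="<cat-toy>")

image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-from-embeds.png")
```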
## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install -U -r requirements_flax.txt
```

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="path-to-dir-containing-images"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --output_dir="textual_inversion_cat"
```

It should be at least 70% faster than the PyTorch script with the same configuration.

### Training with xFormers

You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
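Like `--num_vectors` above, the flag is simply appended to the PyTorch training command:

```bash
--enable_xformers_memory_efficient_attention
```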
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
## Textual Inversion fine-tuning example for SDXL

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATA_DIR="./cat"

accelerate launch textual_inversion_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --mixed_precision="bf16" \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=500 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --save_as_full_pipeline \
  --output_dir="./textual_inversion_cat_sdxl"
```

For now, only training of the first text encoder is supported.
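Because `--save_as_full_pipeline` writes out a complete pipeline, inference can then load that directory directly. A minimal sketch, with the output path taken from the command above and an illustrative prompt:

```python
from diffusers import StableDiffusionXLPipeline
import torch

# Load the full pipeline saved by --save_as_full_pipeline.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "./textual_inversion_cat_sdxl", torch_dtype=torch.float16
).to("cuda")

prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("cat-backpack-sdxl.png")
```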
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
pydicom
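`pydicom` is the new dependency behind the "Adding dcm image support" commit message. As a rough, illustrative sketch of how a `.dcm` file can be turned into an RGB PIL image for training (this helper is an assumption, not the training script's actual implementation):

```python
import numpy as np
import pydicom
from PIL import Image


def dcm_to_pil(path: str) -> Image.Image:
    """Illustrative helper: read a DICOM file and convert it to an 8-bit RGB PIL image."""
    ds = pydicom.dcmread(path)
    pixels = ds.pixel_array.astype(np.float32)
    # Rescale the raw pixel values into the 0-255 range expected by PIL.
    pixels = (pixels - pixels.min()) / max(pixels.max() - pixels.min(), 1e-8) * 255.0
    return Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
```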
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
transformers>=4.25.1
flax
optax
torch
torchvision
ftfy
tensorboard
Jinja2
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class TextualInversion(ExamplesTestsAccelerate):
    def test_textual_inversion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token a
                --save_steps 1
                --num_vectors 2
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))

    def test_textual_inversion_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token a
                --save_steps 1
                --num_vectors 2
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 3
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=1
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + test_args)

            # with checkpoints_total_limit=2, only the two most recent checkpoints should remain
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-3"},
            )

    def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token a
                --save_steps 1
                --num_vectors 2
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=1
                """.split()

            run_command(self._launch_args + test_args)

            # no checkpoints_total_limit, so both checkpoints are kept
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-1", "checkpoint-2"},
            )

            resume_run_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token a
                --save_steps 1
                --num_vectors 2
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=1
                --resume_from_checkpoint=checkpoint-2
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            # after resuming with checkpoints_total_limit=2, only the two most recent checkpoints remain
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-3"},
            )
