Commit 00f95b9

sayakpaul, a-r-r-o-w, and linoytsaban authored
Kontext training (#11813)
* support flux kontext
* make fix-copies
* add example
* add tests
* update docs
* update
* add note on integrity checker
* initial commit
* initial commit
* add readme section and fixes in the training script.
* add test
* rectify ckpt_id
* fix ckpt
* fixes
* change id
* update
* Update examples/dreambooth/train_dreambooth_lora_flux_kontext.py (Co-authored-by: Aryan <[email protected]>)
* Update examples/dreambooth/README_flux.md

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: linoytsaban <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
1 parent eea7689 commit 00f95b9

5 files changed (+2439, -1 lines changed)

docs/source/en/api/pipelines/flux.md

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ Flux comes in the following variants:
  | Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
  | Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
  | Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
- | Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-kontext) |
+ | Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |

  All checkpoints have different usage which we detail below.
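
For reference, the Kontext checkpoint linked in the row above is geared toward instruction-based image editing, though it can also be prompted without an input image. The snippet below is a minimal sketch using the `FluxKontextPipeline` API introduced alongside this commit; the input image path, prompt, and `guidance_scale` value are illustrative placeholders rather than recommended settings.

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load the Kontext checkpoint referenced in the table above.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Kontext conditions on an input image plus an editing instruction.
init_image = load_image("input.png")  # placeholder: any RGB image
edited = pipe(
    image=init_image,
    prompt="Make the background a snowy mountain landscape",
    guidance_scale=2.5,
).images[0]
edited.save("edited.png")
```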

examples/dreambooth/README_flux.md

Lines changed: 46 additions & 0 deletions
@@ -260,5 +260,51 @@ to enable `latent_caching` simply pass `--cache_latents`.
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a noticeable loss in quality. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.

## Training Kontext

[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply to this script, too.

Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.

Below is an example training command:

```bash
accelerate launch train_dreambooth_lora_flux_kontext.py \
  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
  --instance_data_dir="dog" \
  --output_dir="kontext-dog" \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --optimizer="adamw" \
  --use_8bit_adam \
  --cache_latents \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --seed="0"
```

Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not perform as expected.
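
Once training finishes, the LoRA saved in `--output_dir` can be loaded back into the pipeline for inference. The snippet below is a minimal sketch, assuming the `FluxKontextPipeline` API and the `kontext-dog` output directory from the command above; the prompt and file names are illustrative.

```python
import torch
from diffusers import FluxKontextPipeline

# Load the base Kontext checkpoint, then attach the freshly trained LoRA.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("kontext-dog")  # directory containing pytorch_lora_weights.safetensors

# Text-to-image generation with the learned `sks` subject.
image = pipe(prompt="a photo of sks dog wearing a red scarf", guidance_scale=2.5).images[0]
image.save("sks_dog.png")
```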

### Misc notes

* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it; the sketch below illustrates the difference.
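
To make the distinction concrete, the sketch below contrasts sampling from the VAE posterior with taking its mode, using `AutoencoderKL` from `diffusers`; the checkpoint path and dummy input are illustrative and not taken from the training script.

```python
import torch
from diffusers import AutoencoderKL

# encode() returns a diagonal Gaussian posterior over latents.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae", torch_dtype=torch.float32
)

pixels = torch.randn(1, 3, 64, 64)  # stand-in for a preprocessed training image in [-1, 1]
posterior = vae.encode(pixels).latent_dist

sampled_latents = posterior.sample()  # stochastic: a different draw on every call
mode_latents = posterior.mode()       # deterministic: what `--vae_encode_mode mode` corresponds to
```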

### Aspect Ratio Bucketing

We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.

To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of `height,width` pairs, such as:

`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`
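
To illustrate the expected format, here is a hypothetical helper (not part of the training script) that parses such a string into `(height, width)` buckets and picks the one whose aspect ratio best matches a given image:

```python
def parse_buckets(arg: str) -> list[tuple[int, int]]:
    # Turn "h1,w1;h2,w2;..." into a list of (height, width) tuples.
    buckets = []
    for pair in arg.split(";"):
        height, width = (int(x) for x in pair.split(","))
        buckets.append((height, width))
    return buckets


def closest_bucket(height: int, width: int, buckets: list[tuple[int, int]]) -> tuple[int, int]:
    # Pick the bucket whose aspect ratio is closest to the image's.
    target = height / width
    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - target))


buckets = parse_buckets("672,1568;1024,1024;1568,672")
print(closest_bucket(900, 1600, buckets))  # -> (672, 1568): a landscape image gets a landscape bucket
```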

Since Flux Kontext fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

## Other notes
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_flux_kontext.py"
    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"

    def test_dreambooth_lora_flux_kontext(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux_kontext(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict.
            starts_with_transformer = all(
                key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
