[Flux Kontext] Support Fal Kontext LoRA #11823

Merged 10 commits on Jul 2, 2025
222 changes: 222 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict


def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
num_layers = 19
num_single_layers = 38
inner_dim = 3072
    mlp_ratio = 4.0
    original_block_prefix = "base_model.model."

# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."

for lora_key in ["lora_A", "lora_B"]:
# norms
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
)
if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
)

# Q, K, V
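            # The fal checkpoint keeps Flux's fused qkv projection. lora_A maps the input
            # into the low-rank space, so a single copy serves q, k and v; lora_B produces
            # the fused qkv output, so it is chunked into per-projection slices below.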
if lora_key == "lora_A":
sample_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])

context_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])

context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])

if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])

if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])

# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
)

converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
)

# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
)

# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."

for lora_key in ["lora_A", "lora_B"]:
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
)

# Q, K, V, mlp
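            # single_blocks.linear1 fuses q, k, v and the MLP input projection into one matrix;
            # the same copy-vs-chunk rule as in the double blocks applies, with `split_size`
            # separating the three attention slices from the wider MLP slice.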
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

if lora_key == "lora_A":
lora_weight = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
else:
q, k, v, mlp = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
q_bias, k_bias, v_bias, mlp_bias = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

# output projections.
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
)

for lora_key in ["lora_A", "lora_B"]:
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
)
if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
)

if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

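Note on the conversion above: the fal checkpoint stores LoRA factors for the fused qkv projections, while diffusers expects separate q, k and v modules. A minimal sketch with toy shapes (not part of this PR) shows why `lora_A` can simply be copied to each projection while `lora_B` must be chunked:

import torch

# Toy dimensions: hidden size d, LoRA rank r; the fused qkv output has 3 * d rows.
d, r = 8, 2
lora_A = torch.randn(r, d)      # down-projection, shared by q, k and v
lora_B = torch.randn(3 * d, r)  # up-projection into the fused qkv output

# Chunking only lora_B yields per-projection LoRA factors.
b_q, b_k, b_v = torch.chunk(lora_B, 3, dim=0)

x = torch.randn(1, d)
fused_delta = x @ lora_A.T @ lora_B.T  # LoRA delta of the fused qkv layer
q_delta, k_delta, v_delta = torch.chunk(fused_delta, 3, dim=-1)

# Each unfused pair (lora_A, b_*) reproduces its slice of the fused delta exactly.
assert torch.allclose(q_delta, x @ lora_A.T @ b_q.T, atol=1e-6)
assert torch.allclose(v_delta, x @ lora_A.T @ b_v.T, atol=1e-6)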
12 changes: 12 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ def lora_state_dict(
return_metadata=return_lora_metadata,
)

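        # State dicts with the PEFT-style "base_model." key prefix (e.g. Fal-trained
        # Kontext LoRAs) are routed through the dedicated converter before loading.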
is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)

# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
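A minimal usage sketch of what this PR enables (the LoRA repository id below is a placeholder, not a real checkpoint): once `lora_state_dict` sees the PEFT-style `base_model.` prefix, the state dict is converted with `_convert_fal_kontext_lora_to_diffusers` and loads like any other Flux LoRA.

import torch
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder repo id: any Fal-trained Kontext LoRA whose keys carry the "base_model." prefix.
pipe.load_lora_weights("<your-fal-kontext-lora>", adapter_name="fal_kontext")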