
bug fix #40


Merged
merged 15 commits into from
Apr 29, 2025
add controlnet support
tenderness-git committed Apr 28, 2025
commit 75e33ab809c5e748393cd924da0e10143c1e1059
4 changes: 3 additions & 1 deletion diffsynth_engine/__init__.py
@@ -1,18 +1,20 @@
from .pipelines import (
FluxImagePipeline,
FluxControlNet,
SDXLImagePipeline,
SDImagePipeline,
WanVideoPipeline,
FluxModelConfig,
SDXLModelConfig,
SDModelConfig,
WanModelConfig,
WanModelConfig
)
from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
from .utils.video import load_video, save_video

__all__ = [
"FluxImagePipeline",
"FluxControlNet",
"SDXLImagePipeline",
"SDImagePipeline",
"WanVideoPipeline",
2 changes: 2 additions & 0 deletions diffsynth_engine/models/flux/__init__.py
@@ -1,9 +1,11 @@
from .flux_dit import FluxDiT, config as flux_dit_config
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2, config as flux_text_encoder_config
from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
from .flux_controlnet import FluxControlNet

__all__ = [
"FluxDiT",
"FluxControlNet",
"FluxTextEncoder1",
"FluxTextEncoder2",
"FluxVAEDecoder",
Expand Down
26 changes: 8 additions & 18 deletions diffsynth_engine/models/flux/flux_controlnet.py
@@ -52,21 +52,13 @@ def __init__(self, attn_impl: Optional[str] = None, device: str = "cuda:0", dtyp
[nn.Linear(3072, 3072, device=device, dtype=dtype) for _ in range(len(self.single_blocks))]
)

def prepare_image_ids(self, latents):
    batch_size, _, height, width = latents.shape
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
    latent_image_ids = latent_image_ids.reshape(
        batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

    return latent_image_ids

def get_patch_callback(self):
    def patch_callback(hidden_states, controlnet_outputs, index, patch_point: FluxPatchPoint):
        pass

    return patch_callback


def forward(
self,
@@ -77,6 +69,8 @@ def forward(
prompt_emb,
pooled_prompt_emb,
guidance,
image_ids,
text_ids
):
hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
condition = (
@@ -85,10 +79,6 @@
+ self.pooled_text_embedder(pooled_prompt_emb)
)
prompt_emb = self.context_embedder(prompt_emb)
image_ids = self.prepare_image_ids(hidden_states)
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(
device=self.device, dtype=prompt_emb.dtype
)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

# double block
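For reference, the DiT consumes these per-block ControlNet outputs through a patch callback invoked after each double and single block (see the flux_dit.py changes below). A minimal sketch of a custom callback, assuming the four-argument callback signature used in this diff; the weight parameter and the name scaled_patch_callback are illustrative, not part of this commit:

from diffsynth_engine.models.flux.flux_dit import FluxPatchPoint

def scaled_patch_callback(hidden_states, controlnet_outputs, index, patch_point: FluxPatchPoint, weight: float = 0.8):
    # this sketch only patches after double blocks
    if patch_point is not FluxPatchPoint.AFTER_EACH_DOUBLE_BLOCK:
        return hidden_states
    for controlnet_output in controlnet_outputs:
        # skip ControlNets that expose fewer block outputs than the main model
        if index < len(controlnet_output):
            hidden_states = hidden_states + weight * controlnet_output[index]
    return hidden_states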
15 changes: 11 additions & 4 deletions diffsynth_engine/models/flux/flux_dit.py
@@ -25,6 +25,13 @@ class FluxPatchPoint(Enum):
AFTER_EACH_DOUBLE_BLOCK = "after each double block"
AFTER_EACH_SINGLE_BLOCK = "after each single block"

def default_patch_callback(hidden_states, controlnet_outputs, index, patch_point: FluxPatchPoint):
    for controlnet_output in controlnet_outputs:
        if len(controlnet_output) <= index:
            continue
        # add the i-th block output of every ControlNet to the hidden_states produced by the main model's i-th block
        hidden_states = hidden_states + controlnet_output[index]
    return hidden_states

class FluxDiTStateDictConverter(StateDictConverter):
def __init__(self):
@@ -395,9 +402,9 @@ def forward(
text_ids,
image_ids=None,
use_gradient_checkpointing=False,
controlnet_block_outputs=None,
controlnet_double_block_outputs=None,
controlnet_single_block_outputs=None,
patch_callback=None,
patch_callback=default_patch_callback,
**kwargs,
):
fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -431,9 +438,9 @@
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
if controlnet_block_outputs is not None and patch_callback is not None:
if controlnet_double_block_outputs is not None and patch_callback is not None:
hidden_states = patch_callback(
hidden_states, controlnet_block_outputs, i, FluxPatchPoint.AFTER_EACH_DOUBLE_BLOCK
hidden_states, controlnet_double_block_outputs, i, FluxPatchPoint.AFTER_EACH_DOUBLE_BLOCK
)

hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
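To illustrate the accumulation semantics of default_patch_callback above, a small self-check, assuming the function and FluxPatchPoint are importable from flux_dit; the tensor shapes are arbitrary placeholders:

import torch
from diffsynth_engine.models.flux.flux_dit import FluxPatchPoint, default_patch_callback

hidden = torch.zeros(1, 4, 8)
cn_a = [torch.ones(1, 4, 8) for _ in range(2)]  # outputs for 2 double blocks
cn_b = [torch.ones(1, 4, 8)]                    # outputs for 1 double block

# at index 1, cn_b has no block output and is skipped, so only cn_a contributes
out = default_patch_callback(hidden, [cn_a, cn_b], 1, FluxPatchPoint.AFTER_EACH_DOUBLE_BLOCK)
assert torch.allclose(out, torch.ones(1, 4, 8))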
82 changes: 77 additions & 5 deletions diffsynth_engine/pipelines/flux_image.py
@@ -6,11 +6,12 @@
from tqdm import tqdm
from PIL import Image
from dataclasses import dataclass
from diffsynth_engine.models.flux import (
FluxTextEncoder1,
FluxTextEncoder2,
FluxVAEDecoder,
FluxVAEEncoder,
FluxControlNet,
FluxDiT,
flux_dit_config,
flux_text_encoder_config,
@@ -173,6 +174,12 @@ def calculate_shift(
mu = image_seq_len * m + b
return mu

@dataclass
class ControlNetParams:
model: nn.Module
scale: float
images: List[Image.Image | torch.Tensor]


@dataclass
class FluxModelConfig:
@@ -339,20 +346,21 @@ def predict_noise_with_cfg(
text_ids: torch.Tensor,
cfg_scale: float,
guidance: torch.Tensor,
controlnet_params: List[ControlNetParams],
use_cfg: bool = True,
batch_cfg: bool = True,
):
if cfg_scale <= 1.0 or not use_cfg:
return self.predict_noise(
latents, timestep, positive_prompt_emb, positive_add_text_embeds, image_ids, text_ids, guidance
latents, timestep, positive_prompt_emb, positive_add_text_embeds, image_ids, text_ids, guidance, controlnet_params
)
if not batch_cfg:
# cfg by predict noise one by one
positive_noise_pred = self.predict_noise(
latents, timestep, positive_prompt_emb, positive_add_text_embeds, image_ids, text_ids, guidance
latents, timestep, positive_prompt_emb, positive_add_text_embeds, image_ids, text_ids, guidance, controlnet_params
)
negative_noise_pred = self.predict_noise(
latents, timestep, negative_prompt_emb, negative_add_text_embeds, image_ids, text_ids, guidance
latents, timestep, negative_prompt_emb, negative_add_text_embeds, image_ids, text_ids, guidance, controlnet_params
)
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
return noise_pred
@@ -363,7 +371,7 @@ def predict_noise_with_cfg(
latents = torch.cat([latents, latents], dim=0)
timestep = torch.cat([timestep, timestep], dim=0)
positive_noise_pred, negative_noise_pred = self.predict_noise(
latents, timestep, prompt_emb, add_text_embeds, image_ids, text_ids, guidance
latents, timestep, prompt_emb, add_text_embeds, image_ids, text_ids, guidance, controlnet_params
)
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
return noise_pred
@@ -377,7 +385,18 @@ def predict_noise(
image_ids: torch.Tensor,
text_ids: torch.Tensor,
guidance: float,
controlnet_params: List[ControlNetParams]
):
double_block_output, single_block_output = self.predict_multicontrolnet(
latents=latents,
timestep=timestep,
prompt_emb=prompt_emb,
pooled_prompt_emb=add_text_embeds,
guidance=guidance,
text_ids=text_ids,
image_ids=image_ids,
controlnet_params=controlnet_params
)
noise_pred = self.dit(
hidden_states=latents,
timestep=timestep,
Expand All @@ -386,6 +405,8 @@ def predict_noise(
guidance=guidance,
text_ids=text_ids,
image_ids=image_ids,
controlnet_double_block_outputs=double_block_output,
controlnet_single_block_outputs=single_block_output,
)
return noise_pred

@@ -424,6 +445,51 @@ def prepare_latents(
sigmas, timesteps = sigmas.to(device=self.device), timesteps.to(self.device)
return init_latents, latents, sigmas, timesteps

def prepare_controlnets(self, controlnet_params: List[ControlNetParams]):
    results = []
    for param in controlnet_params:
        latents = []
        for image in param.images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
            latents.append(self.encode_image(image, tiled=False))
        results.append(
            ControlNetParams(
                model=param.model,
                scale=param.scale,
                images=latents,
            )
        )
    return results

def predict_multicontrolnet(
    self,
    latents: torch.Tensor,
    timestep: torch.Tensor,
    prompt_emb: torch.Tensor,
    add_text_embeds: torch.Tensor,
    image_ids: torch.Tensor,
    text_ids: torch.Tensor,
    guidance: float,
    controlnet_params: List[ControlNetParams],
):
    double_block_output, single_block_output = None, None
    for param in controlnet_params:
        # merge this ControlNet's condition latents into a single condition tensor
        condition = torch.sum(torch.stack(param.images), dim=0)
        output1, output2 = param.model(
            latents,
            condition,
            param.scale,
            timestep,
            prompt_emb,
            add_text_embeds,
            guidance,
            image_ids,
            text_ids,
        )
        if double_block_output is None and single_block_output is None:
            double_block_output = output1
            single_block_output = output2
        else:
            # accumulate block-wise outputs across multiple ControlNets
            double_block_output = [
                block_output_sum + block_output
                for block_output_sum, block_output in zip(double_block_output, output1)
            ]
            single_block_output = [
                block_output_sum + block_output
                for block_output_sum, block_output in zip(single_block_output, output2)
            ]
    return double_block_output, single_block_output

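A shape sanity check (illustrative only) for the condition merge in predict_multicontrolnet: stacking n condition latents adds a leading dimension, and summing over it restores the per-sample latent shape; the (1, 16, 64, 64) shape is an assumption for illustration:

import torch

latents = [torch.randn(1, 16, 64, 64) for _ in range(2)]
condition = torch.sum(torch.stack(latents), dim=0)  # (2, 1, 16, 64, 64) -> (1, 16, 64, 64)
assert condition.shape == (1, 16, 64, 64)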
def enable_fp8_linear(self):
enable_fp8_linear(self.dit)

@@ -444,8 +510,10 @@ def __call__(
tile_size: int = 128,
tile_stride: int = 64,
seed: int | None = None,
controlnet_params: List[ControlNetParams] = [],
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
):

if input_image is not None:
width, height = input_image.size
self.validate_image_size(height, width, minimum=64, multiple_of=16)
@@ -472,6 +540,9 @@
# Extra input
image_ids, text_ids, guidance = self.prepare_extra_input(latents, positive_prompt_emb, guidance=3.5)

# ControlNet
controlnet_params = self.prepare_controlnets(controlnet_params)

# Denoise
self.load_models_to_device(["dit"])
for i, timestep in enumerate(tqdm(timesteps)):
@@ -487,6 +558,7 @@
text_ids=text_ids,
cfg_scale=cfg_scale,
guidance=guidance,
controlnet_params=controlnet_params,
use_cfg=self.use_cfg,
batch_cfg=self.batch_cfg,
)
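Taken together, a hypothetical end-to-end usage sketch of the new controlnet_params argument; pipeline and ControlNet loading are elided, and the prompt, scale, and canny.png path are placeholders rather than part of this diff:

from PIL import Image
from diffsynth_engine.pipelines.flux_image import FluxImagePipeline, ControlNetParams

def generate_with_canny(pipe: FluxImagePipeline, controlnet) -> Image.Image:
    # pipe is an already-constructed FluxImagePipeline and controlnet a loaded
    # FluxControlNet; both are assumed here, loading is out of scope
    return pipe(
        prompt="a cat sitting on a windowsill",
        controlnet_params=[
            ControlNetParams(model=controlnet, scale=0.7, images=[Image.open("canny.png")]),
        ],
    )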