diff --git a/diffsynth_engine/configs/pipeline.py b/diffsynth_engine/configs/pipeline.py index 152610d2..a6b9396a 100644 --- a/diffsynth_engine/configs/pipeline.py +++ b/diffsynth_engine/configs/pipeline.py @@ -249,6 +249,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi # override OptimizationConfig fbcache_relative_l1_threshold = 0.009 + # svd + use_nunchaku: Optional[bool] = field(default=None, init=False) + use_nunchaku_awq: Optional[bool] = field(default=None, init=False) + use_nunchaku_attn: Optional[bool] = field(default=None, init=False) + @classmethod def basic_config( cls, diff --git a/diffsynth_engine/models/base.py b/diffsynth_engine/models/base.py index 0637e465..648134a2 100644 --- a/diffsynth_engine/models/base.py +++ b/diffsynth_engine/models/base.py @@ -40,7 +40,7 @@ def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True): for args in lora_args: - key = args["name"] + key = args["key"] module = self.get_submodule(key) if not isinstance(module, (LoRALinear, LoRAConv2d)): raise ValueError(f"Unsupported lora key: {key}") diff --git a/diffsynth_engine/models/basic/lora.py b/diffsynth_engine/models/basic/lora.py index b0d1b92d..e4065bd4 100644 --- a/diffsynth_engine/models/basic/lora.py +++ b/diffsynth_engine/models/basic/lora.py @@ -132,6 +132,7 @@ def add_frozen_lora( device: str, dtype: torch.dtype, save_original_weight: bool = True, + **kwargs, ): if save_original_weight and self._original_weight is None: if self.weight.dtype == torch.float8_e4m3fn: diff --git a/diffsynth_engine/models/basic/lora_nunchaku.py b/diffsynth_engine/models/basic/lora_nunchaku.py new file mode 100644 index 00000000..dea90aca --- /dev/null +++ b/diffsynth_engine/models/basic/lora_nunchaku.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +from collections import OrderedDict + +from .lora import LoRA +from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear +from nunchaku.lora.flux.nunchaku_converter import ( + pack_lowrank_weight, + unpack_lowrank_weight, +) + + +class LoRASVDQW4A4Linear(nn.Module): + def __init__( + self, + origin_linear: SVDQW4A4Linear, + ): + super().__init__() + + self.origin_linear = origin_linear + self.base_rank = self.origin_linear.rank + self._lora_dict = OrderedDict() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.origin_linear(x) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.origin_linear, name) + + def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int): + final_scale = scale * (alpha / rank) + + up_scaled = (up * final_scale).to( + dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device + ) + down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device) + + with torch.no_grad(): + pd_packed = self.origin_linear.proj_down.data + pu_packed = self.origin_linear.proj_up.data + pd = unpack_lowrank_weight(pd_packed, down=True) + pu = unpack_lowrank_weight(pu_packed, down=False) + + new_proj_down = torch.cat([pd, down_final], dim=0) + new_proj_up = torch.cat([pu, up_scaled], dim=1) + + self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True) + self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False) + + current_total_rank = self.origin_linear.rank + 
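        # Bookkeeping for the adapter that was just concatenated onto the low-rank
        # branch: record where its columns start and how wide they are, so that
        # modify_scale() can rescale exactly those columns and clear() can drop them.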
self.origin_linear.rank += rank + self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank} + + def add_frozen_lora( + self, + name: str, + scale: float, + rank: int, + alpha: int, + up: torch.Tensor, + down: torch.Tensor, + device: str, + dtype: torch.dtype, + **kwargs, + ): + if name in self._lora_dict: + raise ValueError(f"LoRA with name '{name}' already exists.") + + self._apply_lora_weights(name, down, up, alpha, scale, rank) + + def add_qkv_lora( + self, + name: str, + scale: float, + rank: int, + alpha: int, + q_up: torch.Tensor, + q_down: torch.Tensor, + k_up: torch.Tensor, + k_down: torch.Tensor, + v_up: torch.Tensor, + v_down: torch.Tensor, + device: str, + dtype: torch.dtype, + **kwargs, + ): + if name in self._lora_dict: + raise ValueError(f"LoRA with name '{name}' already exists.") + + fused_down = torch.cat([q_down, k_down, v_down], dim=0) + + fused_rank = 3 * rank + out_q, out_k = q_up.shape[0], k_up.shape[0] + fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype) + fused_up[:out_q, :rank] = q_up + fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up + fused_up[out_q + out_k :, 2 * rank :] = v_up + + self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank) + + def modify_scale(self, name: str, scale: float): + if name not in self._lora_dict: + raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}") + + info = self._lora_dict[name] + old_scale = info["scale"] + + if old_scale == scale: + return + + if old_scale == 0: + scale_factor = 0.0 + else: + scale_factor = scale / old_scale + + with torch.no_grad(): + lora_rank = info["rank"] + start_idx = info["start_idx"] + end_idx = start_idx + lora_rank + + pu_packed = self.origin_linear.proj_up.data + pu = unpack_lowrank_weight(pu_packed, down=False) + pu[:, start_idx:end_idx] *= scale_factor + + self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False) + + self._lora_dict[name]["scale"] = scale + + def clear(self, release_all_cpu_memory: bool = False): + if not self._lora_dict: + return + + with torch.no_grad(): + pd_packed = self.origin_linear.proj_down.data + pu_packed = self.origin_linear.proj_up.data + + pd = unpack_lowrank_weight(pd_packed, down=True) + pu = unpack_lowrank_weight(pu_packed, down=False) + + pd_reset = pd[: self.base_rank, :].clone() + pu_reset = pu[:, : self.base_rank].clone() + + self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True) + self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False) + + self.origin_linear.rank = self.base_rank + + self._lora_dict.clear() + + +class LoRAAWQW4A16Linear(nn.Module): + def __init__(self, origin_linear: AWQW4A16Linear): + super().__init__() + self.origin_linear = origin_linear + self._lora_dict = OrderedDict() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + quantized_output = self.origin_linear(x) + + for name, lora in self._lora_dict.items(): + quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype) + + return quantized_output + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.origin_linear, name) + + def add_lora( + self, + name: str, + scale: float, + rank: int, + alpha: int, + up: torch.Tensor, + down: torch.Tensor, + device: str, + dtype: torch.dtype, + **kwargs, + ): + up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device) + down_linear = 
nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device) + + up_linear.weight.data = up.reshape(self.out_features, rank) + down_linear.weight.data = down.reshape(rank, self.in_features) + + lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype) + self._lora_dict[name] = lora + + def modify_scale(self, name: str, scale: float): + if name not in self._lora_dict: + raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}") + self._lora_dict[name].scale = scale + + def add_frozen_lora(self, *args, **kwargs): + raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.") + + def clear(self, *args, **kwargs): + self._lora_dict.clear() + + +def patch_nunchaku_model_for_lora(model: nn.Module): + def _recursive_patch(module: nn.Module): + for name, child_module in module.named_children(): + replacement = None + if isinstance(child_module, AWQW4A16Linear): + replacement = LoRAAWQW4A16Linear(child_module) + elif isinstance(child_module, SVDQW4A4Linear): + replacement = LoRASVDQW4A4Linear(child_module) + + if replacement: + setattr(module, name, replacement) + else: + _recursive_patch(child_module) + + _recursive_patch(model) diff --git a/diffsynth_engine/models/qwen_image/__init__.py b/diffsynth_engine/models/qwen_image/__init__.py index 92beef69..972697ee 100644 --- a/diffsynth_engine/models/qwen_image/__init__.py +++ b/diffsynth_engine/models/qwen_image/__init__.py @@ -11,3 +11,11 @@ "Qwen2_5_VLVisionConfig", "Qwen2_5_VLConfig", ] + +try: + from .qwen_image_dit_nunchaku import QwenImageDiTNunchaku + + __all__.append("QwenImageDiTNunchaku") + +except (ImportError, ModuleNotFoundError): + pass diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py b/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py new file mode 100644 index 00000000..1d5f0aad --- /dev/null +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn +from typing import Any, Dict, List, Tuple, Optional +from einops import rearrange + +from diffsynth_engine.models.basic import attention as attention_ops +from diffsynth_engine.models.basic.timestep import TimestepEmbeddings +from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm +from diffsynth_engine.models.qwen_image.qwen_image_dit import ( + QwenFeedForward, + apply_rotary_emb_qwen, + QwenDoubleStreamAttention, + QwenImageTransformerBlock, + QwenImageDiT, + QwenEmbedRope, +) + +from nunchaku.models.utils import fuse_linears +from nunchaku.ops.fused import fused_gelu_mlp +from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear +from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d +from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear + + +class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention): + def __init__( + self, + dim_a, + dim_b, + num_heads, + head_dim, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + nunchaku_rank: int = 32, + ): + super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype) + + to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v]) + self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank) + self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank) + + del self.to_q, self.to_k, self.to_v + + add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj]) + 
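        # Mirror the image-stream path above: the text-stream Q/K/V projections are
        # fused into a single linear before SVD quantization, so one quantized
        # projection produces Q, K and V; per-projection LoRA weights are re-fused
        # later via add_qkv_lora().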
self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank) + self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank) + + del self.add_q_proj, self.add_k_proj, self.add_v_proj + + def forward( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1) + txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1) + + img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads) + img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads) + img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads) + + txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads) + txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads) + txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads) + + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + + if rotary_emb is not None: + img_freqs, txt_freqs = rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=1) + joint_k = torch.cat([txt_k, img_k], dim=1) + joint_v = torch.cat([txt_v, img_v], dim=1) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs) + + joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype) + + txt_attn_output = joint_attn_out[:, : text.shape[1], :] + img_attn_output = joint_attn_out[:, text.shape[1] :, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenFeedForwardNunchaku(QwenFeedForward): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dropout: float = 0.0, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + rank: int = 32, + ): + super().__init__(dim, dim_out, dropout, device=device, dtype=dtype) + self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank) + self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank) + self.net[2].act_unsigned = self.net[2].precision != "nvfp4" + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2]) + + +class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + scale_shift: float = 1.0, + use_nunchaku_awq: bool = True, + use_nunchaku_attn: bool = True, + nunchaku_rank: int = 32, + ): + super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype) + + self.use_nunchaku_awq = use_nunchaku_awq + if use_nunchaku_awq: + self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank) + + if use_nunchaku_attn: + self.attn = QwenDoubleStreamAttentionNunchaku( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + 
head_dim=attention_head_dim, + device=device, + dtype=dtype, + nunchaku_rank=nunchaku_rank, + ) + else: + self.attn = QwenDoubleStreamAttention( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + + self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank) + + if use_nunchaku_awq: + self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank) + + self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank) + + self.scale_shift = scale_shift + + def _modulate(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + if self.use_nunchaku_awq: + if self.scale_shift != 0: + scale.add_(self.scale_shift) + return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + temb: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.use_nunchaku_awq: + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6] + img_mod_params = ( + img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1) + ) + txt_mod_params = ( + txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1) + ) + + img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1) # [B, 3*dim] each + txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1) # [B, 3*dim] each + else: + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + + img_normed = self.img_norm1(image) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn) + + txt_normed = self.txt_norm1(text) + txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) + + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + ) + + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp) + + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + + img_mlp_out = self.img_mlp(img_modulated_2) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + + image = image + img_gate_2 * img_mlp_out + text = text + txt_gate_2 * txt_mlp_out + + return text, image + + +class QwenImageDiTNunchaku(QwenImageDiT): + def __init__( + self, + num_layers: int = 60, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + use_nunchaku_awq: bool = True, + use_nunchaku_attn: bool = True, + nunchaku_rank: int = 32, + ): + super().__init__() + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device) + + self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype) + + self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype) + + self.img_in = nn.Linear(64, 3072, 
device=device, dtype=dtype) + self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlockNunchaku( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + device=device, + dtype=dtype, + scale_shift=0, + use_nunchaku_awq=use_nunchaku_awq, + use_nunchaku_attn=use_nunchaku_attn, + nunchaku_rank=nunchaku_rank, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype) + self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype) + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, torch.Tensor], + device: str, + dtype: torch.dtype, + num_layers: int = 60, + use_nunchaku_awq: bool = True, + use_nunchaku_attn: bool = True, + nunchaku_rank: int = 32, + ): + model = cls( + device="meta", + dtype=dtype, + num_layers=num_layers, + use_nunchaku_awq=use_nunchaku_awq, + use_nunchaku_attn=use_nunchaku_attn, + nunchaku_rank=nunchaku_rank, + ) + model = model.requires_grad_(False) + model.load_state_dict(state_dict, assign=True) + model.to(device=device, non_blocking=True) + return model + + def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False): + fuse_dict = {} + for args in lora_args: + key = args["key"] + if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}): + fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj" + type = key.rsplit(".", 1)[-1].split("_")[1] + fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {}) + fuse_dict[fuse_key][type] = args + continue + + if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}): + fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv" + type = key.rsplit(".", 1)[-1].split("_")[1] + fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {}) + fuse_dict[fuse_key][type] = args + continue + + module = self.get_submodule(key) + if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)): + raise ValueError(f"Unsupported lora key: {key}") + + if fused and not isinstance(module, LoRAAWQW4A16Linear): + module.add_frozen_lora(**args) + else: + module.add_lora(**args) + + for key in fuse_dict.keys(): + module = self.get_submodule(key) + if not isinstance(module, LoRASVDQW4A4Linear): + raise ValueError(f"Unsupported lora key: {key}") + module.add_qkv_lora( + name=args["name"], + scale=fuse_dict[key]["q"]["scale"], + rank=fuse_dict[key]["q"]["rank"], + alpha=fuse_dict[key]["q"]["alpha"], + q_up=fuse_dict[key]["q"]["up"], + q_down=fuse_dict[key]["q"]["down"], + k_up=fuse_dict[key]["k"]["up"], + k_down=fuse_dict[key]["k"]["down"], + v_up=fuse_dict[key]["v"]["up"], + v_down=fuse_dict[key]["v"]["down"], + device=fuse_dict[key]["q"]["device"], + dtype=fuse_dict[key]["q"]["dtype"], + ) diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index 97bb9616..abe8532f 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -106,7 +106,8 @@ def load_loras( for key, param in state_dict.items(): lora_args.append( { - "name": key, + "name": lora_path, + "key": key, "scale": lora_scale, "rank": param["rank"], "alpha": param["alpha"], @@ -130,7 +131,10 @@ def unload_loras(self): @staticmethod def load_model_checkpoint( - checkpoint_path: str | List[str], device: str = "cpu", dtype: torch.dtype = torch.float16 + checkpoint_path: str | List[str], + device: str = "cpu", + dtype: torch.dtype = torch.float16, + convert_dtype: bool = True, ) -> Dict[str, torch.Tensor]: if isinstance(checkpoint_path, str): 
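            # accept a single checkpoint path as well as a list of sharded files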
checkpoint_path = [checkpoint_path] @@ -140,8 +144,11 @@ def load_model_checkpoint( raise FileNotFoundError(f"{path} is not a file") elif path.endswith(".safetensors"): state_dict_ = load_file(path, device=device) - for key, value in state_dict_.items(): - state_dict[key] = value.to(dtype) + if convert_dtype: + for key, value in state_dict_.items(): + state_dict[key] = value.to(dtype) + else: + state_dict.update(state_dict_) elif path.endswith(".gguf"): state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype)) diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py index 637467d3..5cd90a3b 100644 --- a/diffsynth_engine/pipelines/qwen_image.py +++ b/diffsynth_engine/pipelines/qwen_image.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist import math +import sys from typing import Callable, List, Dict, Tuple, Optional, Union from tqdm import tqdm from einops import rearrange @@ -38,11 +39,13 @@ from diffsynth_engine.utils import logging from diffsynth_engine.utils.fp8_linear import enable_fp8_linear from diffsynth_engine.utils.download import fetch_model +from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE logger = logging.get_logger(__name__) + class QwenImageLoRAConverter(LoRAStateDictConverter): def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: dit_dict = {} @@ -77,6 +80,7 @@ def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, key = key.replace(f".{lora_a_suffix}", "") key = key.replace("base_model.model.", "") + key = key.replace("transformer.", "") if key.startswith("transformer") and "attn.to_out.0" in key: key = key.replace("attn.to_out.0", "attn.to_out") @@ -177,6 +181,36 @@ def __init__( "vae", ] + @classmethod + def _setup_nunchaku_config( + cls, model_state_dict: Dict[str, torch.Tensor], config: QwenImagePipelineConfig + ) -> QwenImagePipelineConfig: + is_nunchaku_model = any("qweight" in key for key in model_state_dict) + + if is_nunchaku_model: + logger.info("Nunchaku quantized model detected. 
Configuring for nunchaku.") + config.use_nunchaku = True + config.nunchaku_rank = model_state_dict["transformer_blocks.0.img_mlp.net.0.proj.proj_up"].shape[1] + + if "transformer_blocks.0.img_mod.1.qweight" in model_state_dict: + config.use_nunchaku_awq = True + logger.info("Enable nunchaku AWQ.") + else: + config.use_nunchaku_awq = False + logger.info("Disable nunchaku AWQ.") + + if "transformer_blocks.0.attn.to_qkv.qweight" in model_state_dict: + config.use_nunchaku_attn = True + logger.info("Enable nunchaku attention quantization.") + else: + config.use_nunchaku_attn = False + logger.info("Disable nunchaku attention quantization.") + + else: + config.use_nunchaku = False + + return config + @classmethod def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) -> "QwenImagePipeline": if isinstance(model_path_or_config, str): @@ -185,7 +219,16 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) -> config = model_path_or_config logger.info(f"loading state dict from {config.model_path} ...") - model_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype) + model_state_dict = cls.load_model_checkpoint( + config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False + ) + + config = cls._setup_nunchaku_config(model_state_dict, config) + + # for svd quant model fp4/int4 linear layers, do not convert dtype here + if not config.use_nunchaku: + for key, value in model_state_dict.items(): + model_state_dict[key] = value.to(config.model_dtype) if config.vae_path is None: config.vae_path = fetch_model( @@ -221,6 +264,8 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) -> @classmethod def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline": + config = cls._setup_nunchaku_config(state_dicts.model, config) + if config.parallelism > 1: pipe = ParallelWrapper( cfg_degree=config.cfg_degree, @@ -270,13 +315,30 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip dtype=config.model_dtype, relative_l1_threshold=config.fbcache_relative_l1_threshold, ) + elif config.use_nunchaku: + if not NUNCHAKU_AVAILABLE: + from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR + raise ImportError(NUNCHAKU_IMPORT_ERROR) + + from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku + from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora + + dit = QwenImageDiTNunchaku.from_state_dict( + state_dicts.model, + device=init_device, + dtype=config.model_dtype, + use_nunchaku_awq=config.use_nunchaku_awq, + use_nunchaku_attn=config.use_nunchaku_attn, + nunchaku_rank=config.nunchaku_rank, + ) + patch_nunchaku_model_for_lora(dit) else: dit = QwenImageDiT.from_state_dict( state_dicts.model, device=("cpu" if config.use_fsdp else init_device), dtype=config.model_dtype, ) - if config.use_fp8_linear: + if config.use_fp8_linear and not config.use_nunchaku: enable_fp8_linear(dit) pipe = cls( diff --git a/diffsynth_engine/utils/flag.py b/diffsynth_engine/utils/flag.py index 7ac0b3e3..94b6afdb 100644 --- a/diffsynth_engine/utils/flag.py +++ b/diffsynth_engine/utils/flag.py @@ -50,3 +50,27 @@ logger.info("Video sparse attention is available") else: logger.info("Video sparse attention is not available") + +NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None +NUNCHAKU_IMPORT_ERROR = None +if NUNCHAKU_AVAILABLE: + logger.info("Nunchaku is available") +else: 
+ logger.info("Nunchaku is not available") + import sys + torch_version = getattr(torch, "__version__", "unknown") + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + NUNCHAKU_IMPORT_ERROR = ( + "\n\n" + "ERROR: This model requires the 'nunchaku' library for quantized inference, but it is not installed.\n" + "'nunchaku' is not available on PyPI and must be installed manually.\n\n" + "Please follow these steps:\n" + "1. Visit the nunchaku releases page: https://github.com/nunchaku-tech/nunchaku/releases\n" + "2. Find the wheel (.whl) file that matches your environment:\n" + f" - PyTorch version: {torch_version}\n" + f" - Python version: {python_version}\n" + f" - Operating System: {sys.platform}\n" + "3. Copy the URL of the correct wheel file.\n" + "4. Install it using pip, for example:\n" + " pip install nunchaku @ https://.../your_specific_nunchaku_file.whl\n" + ) \ No newline at end of file diff --git a/examples/wan_lora_low_noise.py b/examples/wan_lora_low_noise.py index 5a11a9c9..fb00c391 100644 --- a/examples/wan_lora_low_noise.py +++ b/examples/wan_lora_low_noise.py @@ -35,8 +35,16 @@ device=args.device, ) pipe = WanVideoPipeline.from_pretrained(config) - pipe.load_loras_high_noise([(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-high-000100.safetensors", 1.0)], fused=False, save_original_weight=False) - pipe.load_loras_low_noise([(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-low-4-000060.safetensors", 1.0)], fused=False, save_original_weight=False) + pipe.load_loras_high_noise( + [(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-high-000100.safetensors", 1.0)], + fused=False, + save_original_weight=False, + ) + pipe.load_loras_low_noise( + [(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-low-4-000060.safetensors", 1.0)], + fused=False, + save_original_weight=False, + ) video = pipe( prompt="白天,晴天光,侧光,硬光,暖色调,中近景,中心构图,一个银色短发少女戴着精致的皇冠,穿着华丽的长裙,站在阳光明媚的花园中。她面向镜头微笑,眼睛闪烁着光芒。阳光从侧面照来,照亮了她的银色短发和华丽的服饰,营造出一种温暖而高贵的氛围。微风轻拂,吹动着她裙摆上的蕾丝花边,增添了几分动感。背景是盛开的花朵和绿意盎然的植物,为画面增色不少。,anime style", diff --git a/tests/data/expect/qwen_image/qwen_image_svd_quant.png b/tests/data/expect/qwen_image/qwen_image_svd_quant.png new file mode 100644 index 00000000..ccdd2bbf Binary files /dev/null and b/tests/data/expect/qwen_image/qwen_image_svd_quant.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_svd_quant_lora.png b/tests/data/expect/qwen_image/qwen_image_svd_quant_lora.png new file mode 100644 index 00000000..1b9e1ff4 Binary files /dev/null and b/tests/data/expect/qwen_image/qwen_image_svd_quant_lora.png differ diff --git a/tests/data/input/man.png b/tests/data/input/man.png new file mode 100644 index 00000000..51e125d6 Binary files /dev/null and b/tests/data/input/man.png differ diff --git a/tests/data/input/puppy.png b/tests/data/input/puppy.png new file mode 100644 index 00000000..93f1535c Binary files /dev/null and b/tests/data/input/puppy.png differ diff --git a/tests/data/input/sofa.png b/tests/data/input/sofa.png new file mode 100644 index 00000000..8592e3e0 Binary files /dev/null and b/tests/data/input/sofa.png differ diff --git a/tests/test_pipelines/test_qwen_image_svd_quant.py b/tests/test_pipelines/test_qwen_image_svd_quant.py new file mode 100644 index 00000000..76c1890b --- /dev/null +++ b/tests/test_pipelines/test_qwen_image_svd_quant.py @@ -0,0 +1,72 @@ +import unittest +import torch +import math + +from diffsynth_engine import QwenImagePipelineConfig +from diffsynth_engine.pipelines import QwenImagePipeline +from 
diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import ImageTestCase + + +class TestQwenImagePipelineSVDQuant(ImageTestCase): + @classmethod + def setUpClass(cls): + config = QwenImagePipelineConfig( + model_path=fetch_model( + "nunchaku-tech/nunchaku-qwen-image-edit-2509", path="svdq-int4_r128-qwen-image-edit-2509.safetensors" + ), + encoder_path=fetch_model("MusePublic/Qwen-image", revision="v1", path="text_encoder/*.safetensors"), + vae_path=fetch_model("MusePublic/Qwen-image", revision="v1", path="vae/*.safetensors"), + model_dtype=torch.bfloat16, + encoder_dtype=torch.bfloat16, + vae_dtype=torch.float32, + offload_mode="cpu_offload", + ) + cls.pipe = QwenImagePipeline.from_pretrained(config) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_txt2img(self): + image = self.pipe( + prompt="Let the man in image 1 lie on the sofa in image 3, and let the puppy in image 2 lie on the floor to sleep. ", + negative_prompt=" ", + input_image=[ + self.get_input_image("man.png"), + self.get_input_image("puppy.png"), + self.get_input_image("sofa.png"), + ], + cfg_scale=4.0, + num_inference_steps=40, + seed=42, + ) + self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_svd_quant.png", threshold=0.90) + + def test_lora(self): + self.pipe.load_lora( + path=fetch_model( + "lightx2v/Qwen-Image-Lightning", path="Qwen-Image-Edit-Lightning-4steps-V1.0-bf16.safetensors" + ), + scale=1.0, + fused=True, + ) + self.pipe.apply_scheduler_config({"exponential_shift_mu": math.log(3)}) + image = self.pipe( + prompt="Let the man in image 1 lie on the sofa in image 3, and let the puppy in image 2 lie on the floor to sleep. ", + negative_prompt=" ", + input_image=[ + self.get_input_image("man.png"), + self.get_input_image("puppy.png"), + self.get_input_image("sofa.png"), + ], + cfg_scale=1.0, + num_inference_steps=4, + seed=42, + ) + self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_svd_quant_lora.png", threshold=0.90) + self.pipe.unload_loras() + + +if __name__ == "__main__": + unittest.main()
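
For reviewers: the LoRA fusion in LoRASVDQW4A4Linear._apply_lora_weights rests on a block-matrix identity, namely that concatenating the LoRA factors onto the existing low-rank branch is equivalent to adding the LoRA delta to the layer output. Below is a minimal, self-contained PyTorch sketch of that identity with made-up shapes (not part of the patch); it omits pack_lowrank_weight/unpack_lowrank_weight, which the code only uses to convert to and from the kernel's storage layout.

import torch

# Layout matches the code above: proj_down is (rank, in_features), proj_up is
# (out_features, rank); the low-rank branch computes x @ proj_down.T @ proj_up.T.
in_features, out_features = 64, 128
base_rank, lora_rank, alpha, scale = 32, 16, 16, 0.8

proj_down = torch.randn(base_rank, in_features)
proj_up = torch.randn(out_features, base_rank)
lora_down = torch.randn(lora_rank, in_features)
lora_up = torch.randn(out_features, lora_rank)
x = torch.randn(4, in_features)

# Reference: base low-rank branch plus an explicitly applied LoRA delta.
final_scale = scale * (alpha / lora_rank)
reference = x @ proj_down.T @ proj_up.T + final_scale * (x @ lora_down.T @ lora_up.T)

# Fused: append the LoRA factors to the low-rank branch, folding the scale into
# the up projection, exactly as _apply_lora_weights does before re-packing.
fused_down = torch.cat([proj_down, lora_down], dim=0)          # (base+lora rank, in)
fused_up = torch.cat([proj_up, final_scale * lora_up], dim=1)  # (out, base+lora rank)
fused = x @ fused_down.T @ fused_up.T

assert torch.allclose(reference, fused, rtol=1e-4, atol=1e-4)

This is also why modify_scale only needs to rescale the appended columns of proj_up, and clear() only needs to truncate both matrices back to base_rank.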