
Commit ae4faeb

support svd quant (#202)
* support svd quant
* auto set nunchaku config
* fix nunchaku transformer init
* fix qwen image init
* fix svd quant attn init
* mv nunchaku import error to flag
1 parent f7119c8

16 files changed: +758, -9 lines changed

diffsynth_engine/configs/pipeline.py

Lines changed: 5 additions & 0 deletions

@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
     @classmethod
     def basic_config(
         cls,
diffsynth_engine/models/base.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype
 
     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
         for args in lora_args:
-            key = args["name"]
+            key = args["key"]
             module = self.get_submodule(key)
             if not isinstance(module, (LoRALinear, LoRAConv2d)):
                 raise ValueError(f"Unsupported lora key: {key}")
diffsynth_engine/models/basic/lora.py

Lines changed: 1 addition & 0 deletions

@@ -132,6 +132,7 @@ def add_frozen_lora(
         device: str,
         dtype: torch.dtype,
         save_original_weight: bool = True,
+        **kwargs,
     ):
         if save_original_weight and self._original_weight is None:
             if self.weight.dtype == torch.float8_e4m3fn:
Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
from collections import OrderedDict

from .lora import LoRA
from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
from nunchaku.lora.flux.nunchaku_converter import (
    pack_lowrank_weight,
    unpack_lowrank_weight,
)


class LoRASVDQW4A4Linear(nn.Module):
    def __init__(
        self,
        origin_linear: SVDQW4A4Linear,
    ):
        super().__init__()

        self.origin_linear = origin_linear
        self.base_rank = self.origin_linear.rank
        self._lora_dict = OrderedDict()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.origin_linear(x)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.origin_linear, name)

    def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
        final_scale = scale * (alpha / rank)

        up_scaled = (up * final_scale).to(
            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
        )
        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)

        with torch.no_grad():
            pd_packed = self.origin_linear.proj_down.data
            pu_packed = self.origin_linear.proj_up.data
            pd = unpack_lowrank_weight(pd_packed, down=True)
            pu = unpack_lowrank_weight(pu_packed, down=False)

            new_proj_down = torch.cat([pd, down_final], dim=0)
            new_proj_up = torch.cat([pu, up_scaled], dim=1)

            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)

        current_total_rank = self.origin_linear.rank
        self.origin_linear.rank += rank
        self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}

    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        if name in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' already exists.")

        self._apply_lora_weights(name, down, up, alpha, scale, rank)

    def add_qkv_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        q_up: torch.Tensor,
        q_down: torch.Tensor,
        k_up: torch.Tensor,
        k_down: torch.Tensor,
        v_up: torch.Tensor,
        v_down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        if name in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' already exists.")

        fused_down = torch.cat([q_down, k_down, v_down], dim=0)

        fused_rank = 3 * rank
        out_q, out_k = q_up.shape[0], k_up.shape[0]
        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
        fused_up[:out_q, :rank] = q_up
        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
        fused_up[out_q + out_k :, 2 * rank :] = v_up

        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)

    def modify_scale(self, name: str, scale: float):
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")

        info = self._lora_dict[name]
        old_scale = info["scale"]

        if old_scale == scale:
            return

        if old_scale == 0:
            scale_factor = 0.0
        else:
            scale_factor = scale / old_scale

        with torch.no_grad():
            lora_rank = info["rank"]
            start_idx = info["start_idx"]
            end_idx = start_idx + lora_rank

            pu_packed = self.origin_linear.proj_up.data
            pu = unpack_lowrank_weight(pu_packed, down=False)
            pu[:, start_idx:end_idx] *= scale_factor

            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)

        self._lora_dict[name]["scale"] = scale

    def clear(self, release_all_cpu_memory: bool = False):
        if not self._lora_dict:
            return

        with torch.no_grad():
            pd_packed = self.origin_linear.proj_down.data
            pu_packed = self.origin_linear.proj_up.data

            pd = unpack_lowrank_weight(pd_packed, down=True)
            pu = unpack_lowrank_weight(pu_packed, down=False)

            pd_reset = pd[: self.base_rank, :].clone()
            pu_reset = pu[:, : self.base_rank].clone()

            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)

        self.origin_linear.rank = self.base_rank

        self._lora_dict.clear()


class LoRAAWQW4A16Linear(nn.Module):
    def __init__(self, origin_linear: AWQW4A16Linear):
        super().__init__()
        self.origin_linear = origin_linear
        self._lora_dict = OrderedDict()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        quantized_output = self.origin_linear(x)

        for name, lora in self._lora_dict.items():
            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)

        return quantized_output

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.origin_linear, name)

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)

        up_linear.weight.data = up.reshape(self.out_features, rank)
        down_linear.weight.data = down.reshape(rank, self.in_features)

        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
        self._lora_dict[name] = lora

    def modify_scale(self, name: str, scale: float):
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(self, *args, **kwargs):
        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")

    def clear(self, *args, **kwargs):
        self._lora_dict.clear()


def patch_nunchaku_model_for_lora(model: nn.Module):
    def _recursive_patch(module: nn.Module):
        for name, child_module in module.named_children():
            replacement = None
            if isinstance(child_module, AWQW4A16Linear):
                replacement = LoRAAWQW4A16Linear(child_module)
            elif isinstance(child_module, SVDQW4A4Linear):
                replacement = LoRASVDQW4A4Linear(child_module)

            if replacement:
                setattr(module, name, replacement)
            else:
                _recursive_patch(child_module)

    _recursive_patch(model)
diffsynth_engine/models/qwen_image/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -11,3 +11,11 @@
     "Qwen2_5_VLVisionConfig",
     "Qwen2_5_VLConfig",
 ]
+
+try:
+    from .qwen_image_dit_nunchaku import QwenImageDiTNunchaku
+
+    __all__.append("QwenImageDiTNunchaku")
+
+except (ImportError, ModuleNotFoundError):
+    pass
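This matches the "mv nunchaku import error to flag" item in the commit message: the nunchaku-backed DiT is exported only when the optional nunchaku dependency imports cleanly. A hedged sketch of how downstream code might feature-detect it via `__all__` instead of catching `ImportError` itself (the fallback is left abstract because it is not shown in this hunk):

import diffsynth_engine.models.qwen_image as qwen_image

if "QwenImageDiTNunchaku" in qwen_image.__all__:
    dit_cls = qwen_image.QwenImageDiTNunchaku  # nunchaku is installed and importable
else:
    dit_cls = None  # fall back to the regular Qwen-Image DiT (not shown here)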

0 commit comments
