import torch
import torch.nn as nn
from collections import OrderedDict

from .lora import LoRA
from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
from nunchaku.lora.flux.nunchaku_converter import (
    pack_lowrank_weight,
    unpack_lowrank_weight,
)


class LoRASVDQW4A4Linear(nn.Module):
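    """Wrapper around a Nunchaku ``SVDQW4A4Linear`` that merges LoRA branches
    into the layer's existing low-rank path by concatenating the LoRA down/up
    matrices onto the packed ``proj_down`` / ``proj_up`` weights.
    """
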
    def __init__(
        self,
        origin_linear: SVDQW4A4Linear,
    ):
        super().__init__()

        self.origin_linear = origin_linear
        self.base_rank = self.origin_linear.rank
        self._lora_dict = OrderedDict()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.origin_linear(x)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.origin_linear, name)

    def _apply_lora_weights(
        self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int
    ):
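        """Scale the LoRA branch, concatenate it onto the packed low-rank
        projections, and record the slice it occupies for later rescaling.
        """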
        final_scale = scale * (alpha / rank)

        up_scaled = (up * final_scale).to(
            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
        )
        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)

        with torch.no_grad():
            pd_packed = self.origin_linear.proj_down.data
            pu_packed = self.origin_linear.proj_up.data
            pd = unpack_lowrank_weight(pd_packed, down=True)
            pu = unpack_lowrank_weight(pu_packed, down=False)

            new_proj_down = torch.cat([pd, down_final], dim=0)
            new_proj_up = torch.cat([pu, up_scaled], dim=1)

            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)

        # Bookkeeping uses the number of components actually concatenated (for a
        # fused QKV LoRA this is 3 * rank, not rank) so that the layer's total
        # rank and each entry's start_idx stay consistent across multiple LoRAs.
        added_rank = down_final.shape[0]
        start_idx = self.origin_linear.rank
        self.origin_linear.rank += added_rank
        self._lora_dict[name] = {"rank": added_rank, "alpha": alpha, "scale": scale, "start_idx": start_idx}
    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
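        """Merge a single LoRA (``down``: ``rank x in_features``, ``up``:
        ``out_features x rank``) into the quantized layer's low-rank projections.
        """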
        if name in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' already exists.")

        self._apply_lora_weights(name, down, up, alpha, scale, rank)

    def add_qkv_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        q_up: torch.Tensor,
        q_down: torch.Tensor,
        k_up: torch.Tensor,
        k_down: torch.Tensor,
        v_up: torch.Tensor,
        v_down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
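        """Merge per-projection Q/K/V LoRAs as one fused LoRA: the down matrices
        are stacked along the rank dimension and the up matrices are placed
        block-diagonally so each branch only writes to its own output slice.
        """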
        if name in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' already exists.")

        fused_down = torch.cat([q_down, k_down, v_down], dim=0)

        fused_rank = 3 * rank
        out_q, out_k = q_up.shape[0], k_up.shape[0]
        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
        fused_up[:out_q, :rank] = q_up
        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
        fused_up[out_q + out_k :, 2 * rank :] = v_up

        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)

    def modify_scale(self, name: str, scale: float):
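        """Rescale an already-merged LoRA in place by multiplying its slice of
        ``proj_up`` by the ratio of the new scale to the old one.
        """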
        if name not in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' not found in {self.__class__.__name__}")

        info = self._lora_dict[name]
        old_scale = info["scale"]

        if old_scale == scale:
            return

        # The LoRA columns were merged with old_scale already baked in; if they
        # were merged at scale 0 they are all zeros and cannot be rescaled back.
        if old_scale == 0:
            scale_factor = 0.0
        else:
            scale_factor = scale / old_scale

        with torch.no_grad():
            lora_rank = info["rank"]
            start_idx = info["start_idx"]
            end_idx = start_idx + lora_rank

            pu_packed = self.origin_linear.proj_up.data
            pu = unpack_lowrank_weight(pu_packed, down=False)
            pu[:, start_idx:end_idx] *= scale_factor

            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)

        self._lora_dict[name]["scale"] = scale

    def clear(self, release_all_cpu_memory: bool = False):
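        """Drop all merged LoRA components and restore the original low-rank
        projections (the first ``base_rank`` components).
        """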
        if not self._lora_dict:
            return

        with torch.no_grad():
            pd_packed = self.origin_linear.proj_down.data
            pu_packed = self.origin_linear.proj_up.data

            pd = unpack_lowrank_weight(pd_packed, down=True)
            pu = unpack_lowrank_weight(pu_packed, down=False)

            pd_reset = pd[: self.base_rank, :].clone()
            pu_reset = pu[:, : self.base_rank].clone()

            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)

        self.origin_linear.rank = self.base_rank

        self._lora_dict.clear()


class LoRAAWQW4A16Linear(nn.Module):
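    """Wrapper around a Nunchaku ``AWQW4A16Linear`` that keeps each LoRA as a
    separate ``LoRA`` module and adds its output to the quantized layer's output
    at runtime; nothing is merged into the quantized weights.
    """
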
    def __init__(self, origin_linear: AWQW4A16Linear):
        super().__init__()
        self.origin_linear = origin_linear
        self._lora_dict = OrderedDict()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        quantized_output = self.origin_linear(x)

        for lora in self._lora_dict.values():
            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)

        return quantized_output

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.origin_linear, name)

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
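        """Register a runtime LoRA branch built from the given ``up``/``down``
        weights; it is evaluated in ``forward`` rather than merged into the
        quantized weights.
        """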
        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)

        # Move the provided weights onto the requested device/dtype before binding them.
        up_linear.weight.data = up.reshape(self.out_features, rank).to(device=device, dtype=dtype)
        down_linear.weight.data = down.reshape(rank, self.in_features).to(device=device, dtype=dtype)

        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
        self._lora_dict[name] = lora

    def modify_scale(self, name: str, scale: float):
        if name not in self._lora_dict:
            raise ValueError(f"LoRA with name '{name}' not found in {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(self, *args, **kwargs):
        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")

    def clear(self, *args, **kwargs):
        self._lora_dict.clear()


def patch_nunchaku_model_for_lora(model: nn.Module):
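    """Recursively replace every ``AWQW4A16Linear`` / ``SVDQW4A4Linear`` child of
    ``model`` with its LoRA-capable wrapper, modifying the model in place.
    """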
    def _recursive_patch(module: nn.Module):
        for name, child_module in module.named_children():
            replacement = None
            if isinstance(child_module, AWQW4A16Linear):
                replacement = LoRAAWQW4A16Linear(child_module)
            elif isinstance(child_module, SVDQW4A4Linear):
                replacement = LoRASVDQW4A4Linear(child_module)

            if replacement is not None:
                setattr(module, name, replacement)
            else:
                _recursive_patch(child_module)

    _recursive_patch(model)
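

# Hypothetical usage sketch (illustrative only, not part of this module): patch a
# loaded Nunchaku transformer in place, then drive a wrapped layer directly.
# `model`, the `transformer_blocks[0].attn.to_q` path, and the LoRA weight
# tensors below are assumptions for illustration.
#
#   patch_nunchaku_model_for_lora(model)
#   layer = model.transformer_blocks[0].attn.to_q  # now a LoRASVDQW4A4Linear
#   layer.add_frozen_lora(
#       name="style_lora", scale=1.0, rank=16, alpha=16,
#       up=up_weight, down=down_weight, device="cuda", dtype=torch.bfloat16,
#   )
#   layer.modify_scale("style_lora", 0.5)  # rescale the merged branch in place
#   layer.clear()                          # restore the original low-rank weights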