Skip to content

support qwen3-moe awq #4059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/export/quantize/moe/awq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Quantize Qwen3-30B-A3B (MoE) to 4-bit AWQ with ms-swift.
# Requires an AutoAWQ build with MoE support; replace any existing install first.
# -y: skip the interactive confirmation so the script runs unattended.
pip uninstall -y autoawq
pip install git+https://github.com/casper-hansen/AutoAWQ.git --no-deps # or "autoawq>=0.2.9"

# quant_n_samples / max_length control the calibration set size;
# quant_batch_size -1 lets swift choose the calibration batch size.
CUDA_VISIBLE_DEVICES=0,1 \
swift export \
    --model Qwen/Qwen3-30B-A3B \
    --dataset 'swift/Qwen3-SFT-Mixin' \
    --device_map auto \
    --quant_n_samples 64 \
    --quant_batch_size -1 \
    --max_length 8192 \
    --quant_method awq \
    --quant_bits 4 \
    --output_dir Qwen3-30B-A3B-AWQ
23 changes: 21 additions & 2 deletions swift/llm/export/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, get_model_arch,
is_moe_model, load_dataset, prepare_model_template, save_checkpoint, to_device)
from swift.utils import get_logger, get_model_parameter_info
from swift.utils import find_layers, get_logger, get_model_parameter_info

logger = get_logger()

Expand Down Expand Up @@ -136,6 +136,19 @@ def _move_embed(model, device: str):
finally:
awq_model.move_embed = _origin_move_embed

def get_awq_modules_to_not_convert(self):
    """Return module names that AWQ must leave unquantized for MoE models.

    Inspects the last transformer block to count the experts, then collects
    every ``nn.Linear`` whose ``out_features`` equals that count — i.e. the
    MoE router/gate layers, which must stay in full precision.

    Returns:
        List[str]: module-name patterns for AWQ's ``modules_to_not_convert``.
    """
    block_name = self.get_block_name_to_quantize(self.model)
    # Any block works for counting experts; use the last one.
    block = deep_getattr(self.model, block_name)[-1]
    # _get_experts returns (prefix, experts); only the expert list is needed.
    _, experts = self._get_experts(block)
    num_experts = len(experts)

    def cond(name, module):
        # A linear projecting onto one logit per expert is the router/gate.
        return isinstance(module, nn.Linear) and module.out_features == num_experts

    return find_layers(self.model, cond, min_name_len=2)  # min_name_len: fix Qwen3-MoE

def awq_model_quantize(self) -> None:
from awq.quantize import quantizer
from transformers import AwqConfig
Expand All @@ -150,6 +163,9 @@ def awq_model_quantize(self) -> None:
'w_bit': args.quant_bits,
'version': 'GEMM'
}
if is_moe_model(self.model):
quant_config['modules_to_not_convert'] = self.get_awq_modules_to_not_convert()
logger.info(f'quant_config: {quant_config}')
logger.info('Start quantizing the model...')
with self._patch_awq_move_embed(self.model):
self.model.quantize(
Expand Down Expand Up @@ -224,14 +240,17 @@ def gptq_model_quantize(self):
args = self.args
logger.info(f'Quantization dataset: {args.dataset}')
block_name_to_quantize = self.get_block_name_to_quantize(self.model)
modules_in_block_to_quantize = self.get_modules_in_block_to_quantize(self.model, block_name_to_quantize)
logger.info(f'block_name_to_quantize: {block_name_to_quantize}')
logger.info(f'modules_in_block_to_quantize: {modules_in_block_to_quantize}')
with self._patch_gptq():
gptq_quantizer = GPTQQuantizer(
bits=args.quant_bits,
group_size=args.group_size,
dataset=','.join(args.dataset),
batch_size=args.quant_batch_size,
block_name_to_quantize=block_name_to_quantize,
modules_in_block_to_quantize=self.get_modules_in_block_to_quantize(self.model, block_name_to_quantize))
modules_in_block_to_quantize=modules_in_block_to_quantize)
gptq_quantizer.serialization_keys.append('block_name_to_quantize')
logger.info('Start quantizing the model...')
logger.warning('The process of packing the model takes a long time and there is no progress bar. '
Expand Down
17 changes: 12 additions & 5 deletions swift/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,12 @@ def _sync_max_memory(max_memory: Dict[Union[int, str], int]) -> Dict[Union[int,
return new_max_memory


def find_layers(model: nn.Module,
cond: Callable[[str, nn.Module], bool],
sub_module: Optional[str] = None) -> List[str]:
def find_layers(
model: nn.Module,
cond: Callable[[str, nn.Module], bool],
sub_module: Optional[str] = None,
min_name_len: Optional[int] = None,
) -> List[str]:
# The content of target_module_names cannot exist in inner_nodes.
sub_module_str = sub_module
if sub_module is None:
Expand All @@ -170,13 +173,17 @@ def find_layers(model: nn.Module,
inner_nodes.add(name)
target_module_names = set()
for name, module in sub_module.named_modules():
name = f'{sub_module_str}.{name}'
if sub_module_str:
name = f'{sub_module_str}.{name}'
if cond(name, module):
module_name_list = name.split('.')
module_name = module_name_list.pop()
i = 1
for inner_node in inner_nodes:
while module_name_list and inner_node.endswith(re.sub(r'\d+\.', '{}.', module_name)):
while module_name_list and inner_node.endswith(re.sub(
r'\d+\.', '{}.', module_name)) or min_name_len and i < min_name_len:
module_name = f'{module_name_list.pop()}.{module_name}'
i += 1
target_module_names.add(module_name)
return list(target_module_names)

Expand Down