Skip to content

Commit a954dc0

Browse files
committed
Fix minimax & fix agent_template (#4618)
1 parent f14d61a commit a954dc0

File tree

3 files changed

+34
-13
lines changed

swift/llm/model/model/minimax.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ def get_model_tokenizer_minimax_vl(model_dir: str,
2727
device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
2828
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
2929
kwargs['model_config'] = config
30-
if kwargs.get('attn_impl') == 'flash_attn':
31-
config.attn_type_list = [1] * len(config.attn_type_list)
32-
else:
33-
config.attn_type_list = [0] * len(config.attn_type_list)
3430
if 'quantization_config' in model_kwargs:
3531
quantization_config = model_kwargs['quantization_config']
3632
from transformers import QuantoConfig
@@ -111,11 +107,6 @@ def get_model_tokenizer_minimax_text(model_dir: str,
111107
device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
112108
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
113109
kwargs['model_config'] = config
114-
if hasattr(config, 'attn_type_list'):
115-
if kwargs.get('attn_impl') == 'flash_attn':
116-
config.attn_type_list = [1] * len(config.attn_type_list)
117-
else:
118-
config.attn_type_list = [0] * len(config.attn_type_list)
119110
if 'quantization_config' in model_kwargs:
120111
quantization_config = model_kwargs['quantization_config']
121112
from transformers import QuantoConfig
@@ -150,7 +141,6 @@ def get_model_tokenizer_minimax_text(model_dir: str,
150141
LLMModelType.minimax, [
151142
ModelGroup([
152143
Model('MiniMax/MiniMax-Text-01', 'MiniMaxAI/MiniMax-Text-01'),
153-
Model('MiniMax/MiniMax-Text-01-hf', 'MiniMaxAI/MiniMax-Text-01-hf'),
154144
]),
155145
],
156146
TemplateType.minimax,

swift/plugin/agent_template/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def _format_tool_responses(
8585
def _parse_tool_call(content) -> Dict[str, Any]:
8686
obj = BaseAgentTemplate._parse_json(content)
8787
name = obj['name']
88-
arguments = obj.get('arguments') or obj.get('parameters')
88+
arguments = obj.get('arguments')
89+
if arguments is None:
90+
arguments = obj.get('parameters')
8991
arguments = BaseAgentTemplate._parse_json(arguments)
9092
assert arguments is not None, f'content: {content}'
9193
return {'name': name, 'arguments': arguments}
@@ -127,7 +129,9 @@ def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
127129
name_for_model = BaseAgentTemplate._get_tool_name(tool)
128130
name_for_human = tool.get('name_for_human') or name_for_model
129131

130-
description = tool.get('description') or tool.get('description_for_model')
132+
description = tool.get('description')
133+
if description is None:
134+
description = tool.get('description_for_model')
131135
parameters = tool.get('parameters') or {}
132136
parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False)
133137
args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'

tests/test_align/test_template/test_llm.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def test_glm_edge():
175175

176176

177177
def test_llama():
178+
from swift.llm import VllmEngine
178179
# pt_engine = PtEngine('LLM-Research/Meta-Llama-3.1-8B-Instruct-BNB-NF4')
179180
# pt_engine = PtEngine('LLM-Research/Meta-Llama-3.1-8B-Instruct')
180181
# pt_engine = PtEngine('LLM-Research/Meta-Llama-3-8B-Instruct')
@@ -397,6 +398,30 @@ def test_mimo():
397398
assert res == res2, f'res: {res}, res2: {res2}'
398399

399400

401+
def test_minicpm():
402+
pt_engine = PtEngine('OpenBMB/MiniCPM4-0.5B')
403+
res = _infer_model(pt_engine)
404+
pt_engine.default_template.template_backend = 'jinja'
405+
res2 = _infer_model(pt_engine)
406+
assert res == res2, f'res: {res}, res2: {res2}'
407+
408+
409+
def test_minimax():
410+
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
411+
from transformers import QuantoConfig
412+
quantization_config = QuantoConfig(weights='int8')
413+
messages = [{
414+
'role': 'system',
415+
'content': 'You are a helpful assistant.'
416+
}, {
417+
'role': 'user',
418+
'content': 'who are you?'
419+
}]
420+
pt_engine = PtEngine('MiniMax/MiniMax-M1-40k', quantization_config=quantization_config)
421+
res = _infer_model(pt_engine, messages=messages)
422+
print(f'res: {res}')
423+
424+
400425
if __name__ == '__main__':
401426
from swift.llm import PtEngine, RequestConfig
402427
from swift.utils import get_logger, seed_everything
@@ -435,4 +460,6 @@ def test_mimo():
435460
# test_gemma3()
436461
# test_glm4_0414()
437462
# test_qwen3()
438-
test_mimo()
463+
# test_mimo()
464+
# test_minicpm()
465+
test_minimax()

0 commit comments

Comments (0)