Skip to content

Commit b41c78d

Browse files
authored
Fix test bug (#4851)
* refactor web_ui
* fix ui page layout
* refactor rollout and fix bug
* modify external rollout display
* fix train_local
* remove deprecated params for rollout
* add resume textbox and fix whitespace in more_params_cmd
* Restored the accidentally deleted colocate parameters
* fix bugs for colocate mode and rlhf params filter
* fix reminder info err
* fix syntax error
* remove resume textbox and add resume checking
* Simplify grpo tab and fix some bug
* fix refresh tasks and queue backlog
* fix bug
1 parent e09b00d commit b41c78d

File tree

16 files changed

+306
-351
lines changed

16 files changed

+306
-351
lines changed

swift/ui/llm_grpo/external_rollout.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class LLMRollout(BaseUI):
110110

111111
@classmethod
112112
def do_build_ui(cls, base_tab: Type['BaseUI']):
113-
with gr.Accordion(elem_id='llm_rollout', visible=False):
113+
with gr.Accordion(elem_id='llm_rollout', open=False, visible=False):
114114
default_device = 'cpu'
115115
device_count = get_device_count()
116116
if device_count > 0:
@@ -119,7 +119,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
119119
with gr.Row():
120120
gr.Textbox(elem_id='tensor_parallel_size', lines=1, value='1', scale=4)
121121
gr.Textbox(elem_id='data_parallel_size', lines=1, value='1', scale=4)
122-
gr.Textbox(elem_id='max_model_len', lines=1, value='', scale=4)
123122
gr.Slider(elem_id='gpu_memory_utilization', minimum=0.0, maximum=1.0, step=0.05, value=0.9, scale=4)
124123
with gr.Row(equal_height=True):
125124
gr.Dropdown(

swift/ui/llm_grpo/external_runtime.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class RolloutRuntime(Runtime):
5656
'en': 'Logging content'
5757
},
5858
'info': {
59-
'zh': '如果日志无更新请再次点击"展示日志内容"',
60-
'en': 'Please press "Show log" if the log content is not updating'
59+
'zh': '如果日志无更新请再次点击"展示rollout状态"',
60+
'en': 'Please press "Show running status" if the log content is not updating'
6161
}
6262
},
6363
'rollout_running_tasks': {
@@ -90,6 +90,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
9090
with gr.Blocks():
9191
with gr.Row(equal_height=True):
9292
gr.Dropdown(elem_id='rollout_running_tasks', scale=10, allow_custom_value=True)
93+
with gr.Row(equal_height=True):
9394
gr.Button(elem_id='rollout_refresh_tasks', scale=1, variant='primary')
9495
gr.Button(elem_id='rollout_show_log', scale=1, variant='primary')
9596
gr.Button(elem_id='rollout_stop_show_log', scale=1)

swift/ui/llm_grpo/grpo_advanced.py

Lines changed: 159 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from functools import partial
23
from typing import Type
34

45
import gradio as gr
56

7+
from swift.llm import BaseArguments, ModelType
8+
from swift.llm.model.register import get_all_models
69
from swift.ui.base import BaseUI
710

811

@@ -92,20 +95,167 @@ class GrpoAdvanced(BaseUI):
9295
'en': 'Skip overlong truncated samples and exclude them from loss calculation'
9396
}
9497
},
98+
'beta': {
99+
'label': {
100+
'zh': 'KL正则项系数',
101+
'en': 'KL regularization coefficient'
102+
}
103+
},
104+
'vllm_enable_prefix_caching': {
105+
'label': {
106+
'zh': '开启前缀缓存',
107+
'en': 'Enable prefix cache'
108+
},
109+
'info': {
110+
'zh': 'Colocate模式中vLLM透传参数',
111+
'en': 'vLLM transparent transmission parameters in colocate mode'
112+
}
113+
},
114+
'log_completions': {
115+
'label': {
116+
'zh': '记录生成内容',
117+
'en': 'Record generated content'
118+
},
119+
'info': {
120+
'zh': '是否记录训练中的模型生成内容',
121+
'en': 'Whether to record the model generation content during training'
122+
}
123+
},
124+
'num_iterations': {
125+
'label': {
126+
'zh': '每个批次更新次数',
127+
'en': 'Num of updates per batch'
128+
}
129+
},
130+
'reward_model': {
131+
'label': {
132+
'zh': '奖励模型id或路径',
133+
'en': 'Reward Model id or path'
134+
},
135+
'info': {
136+
'zh': '实际的模型id',
137+
'en': 'The actual model id or model path'
138+
}
139+
},
140+
'reward_model_type': {
141+
'label': {
142+
'zh': '奖励模型类型',
143+
'en': 'Select Reward Model Type'
144+
},
145+
'info': {
146+
'zh': 'SWIFT已支持的模型类型',
147+
'en': 'Base model type supported by SWIFT'
148+
}
149+
},
150+
'reward_model_plugin': {
151+
'label': {
152+
'zh': '奖励模型逻辑',
153+
'en': 'Reward model logic'
154+
},
155+
'info': {
156+
'zh': '利用reward_model_plugin自定义奖励模型的处理逻辑',
157+
'en': 'Use reward_model_plugin to customize the processing logic of the reward model'
158+
}
159+
},
160+
'external_plugins': {
161+
'label': {
162+
'zh': '外部插件文件',
163+
'en': 'External plugin file'
164+
},
165+
'info': {
166+
'zh': '外部插件文件列表,将被注册进插件模块中',
167+
'en': 'List of external plugin files that will be registered into the plugin module'
168+
}
169+
},
170+
'ref_model_type': {
171+
'label': {
172+
'zh': 'Ref模型类型',
173+
'en': 'Ref model type'
174+
},
175+
'info': {
176+
'zh': 'SWIFT已支持的模型类型',
177+
'en': 'Model type supported by SWIFT'
178+
}
179+
},
180+
'ref_model': {
181+
'label': {
182+
'zh': 'Ref模型id或路径',
183+
'en': 'Ref model id or path'
184+
},
185+
'info': {
186+
'zh': '实际的模型id或路径',
187+
'en': 'The actual model id or path'
188+
}
189+
},
95190
}
96191

97192
@classmethod
98193
def do_build_ui(cls, base_tab: Type['BaseUI']):
99194
with gr.TabItem(elem_id='grpo_advanced_tab'):
100195
with gr.Blocks():
101196
with gr.Row():
102-
gr.Dropdown(elem_id='loss_type', choices=['grpo', 'bnpo', 'dr_grpo'], value='grpo', scale=20)
103-
gr.Textbox(elem_id='epsilon', value=0.2, lines=1, scale=20)
104-
gr.Textbox(elem_id='epsilon_high', value=None, lines=1, scale=20)
105-
gr.Textbox(elem_id='move_model_batches', lines=1, scale=20)
197+
gr.Dropdown(elem_id='loss_type', choices=['grpo', 'bnpo', 'dr_grpo'], value='grpo', scale=4)
198+
gr.Textbox(elem_id='epsilon', value=0.2, lines=1, scale=4)
199+
gr.Textbox(elem_id='epsilon_high', value=None, lines=1, scale=4)
200+
gr.Textbox(elem_id='beta', value=0.04, lines=1, scale=4)
201+
gr.Textbox(elem_id='num_iterations', lines=1, scale=4)
106202
with gr.Row():
107-
gr.Textbox(elem_id='multi_turn_scheduler', lines=1, scale=20)
108-
gr.Textbox(elem_id='max_turns', lines=1, scale=20)
109-
gr.Checkbox(elem_id='dynamic_sample', scale=20)
110-
gr.Slider(elem_id='max_resample_times', minimum=1, maximum=16, step=1, value=3, scale=20)
111-
gr.Checkbox(elem_id='overlong_filter', scale=20)
203+
gr.Textbox(elem_id='move_model_batches', lines=1, scale=4)
204+
gr.Checkbox(elem_id='dynamic_sample', scale=4)
205+
gr.Slider(elem_id='max_resample_times', minimum=1, maximum=16, step=1, value=3, scale=4)
206+
gr.Checkbox(elem_id='overlong_filter', scale=4)
207+
gr.Checkbox(elem_id='vllm_enable_prefix_caching', scale=4)
208+
with gr.Row():
209+
gr.Checkbox(elem_id='log_completions', scale=4)
210+
gr.Textbox(elem_id='multi_turn_scheduler', lines=1, scale=4)
211+
gr.Textbox(elem_id='max_turns', lines=1, scale=4)
212+
gr.Textbox(elem_id='external_plugins', lines=1, scale=8)
213+
214+
with gr.Row():
215+
gr.Textbox(elem_id='reward_model_plugin', lines=1, scale=8)
216+
gr.Dropdown(elem_id='reward_model', multiselect=True, choices=get_all_models(), scale=8)
217+
gr.Dropdown(
218+
elem_id='reward_model_type',
219+
multiselect=True,
220+
choices=ModelType.get_model_name_list(),
221+
allow_custom_value=True,
222+
scale=4)
223+
with gr.Blocks():
224+
with gr.Row():
225+
gr.Dropdown(
226+
elem_id='ref_model', scale=12, value=None, choices=get_all_models(), allow_custom_value=True)
227+
gr.Dropdown(elem_id='ref_model_type', choices=ModelType.get_model_name_list(), value=None, scale=8)
228+
229+
@classmethod
def after_build_ui(cls, base_tab: Type['BaseUI']):
    """Register the change handlers of the two model dropdowns.

    A change of the ``ref_model`` element is handled by
    ``cls.update_input_model`` and writes to ``ref_model_type``; a change of
    the ``reward_model`` element is handled by ``cls.update_input_models``
    and writes to ``reward_model_type``.

    Args:
        base_tab: Not used by this method.
    """
    # Pre-bind the keyword arguments so each handler only receives the
    # value of the element listed in ``inputs``.
    on_ref_model_change = partial(
        cls.update_input_model, allow_keys=['ref_model_type'], has_record=False, is_ref_model=True)
    on_reward_model_change = partial(
        cls.update_input_models, allow_keys=['reward_model_type'], is_reward_model=True, has_record=False)

    cls.element('ref_model').change(
        on_ref_model_change,
        inputs=[cls.element('ref_model')],
        outputs=[cls.element('ref_model_type')],
    )
    cls.element('reward_model').change(
        on_reward_model_change,
        inputs=[cls.element('reward_model')],
        outputs=[cls.element('reward_model_type')],
    )
239+
240+
@classmethod
def update_input_models(cls,
                        models,
                        allow_keys=None,
                        has_record=False,
                        arg_cls=BaseArguments,
                        is_reward_model=False):
    """Combine the per-model results of ``update_input_model`` into one update.

    Every entry of ``models`` is passed to ``cls.update_input_model`` with
    the remaining keyword arguments forwarded unchanged. The ``'value'``
    item of each result is collected and the items are joined with single
    spaces.

    Args:
        models: Iterable of model entries, or ``None``.
        allow_keys: Forwarded to ``cls.update_input_model``.
        has_record: Forwarded to ``cls.update_input_model``.
        arg_cls: Forwarded to ``cls.update_input_model``.
        is_reward_model: Forwarded to ``cls.update_input_model``.

    Returns:
        ``gr.update()`` without arguments when ``models`` is ``None``;
        otherwise ``gr.update(value=...)`` holding the space-joined string
        with surrounding whitespace stripped (an empty string when
        ``models`` is empty).
    """
    if models is None:
        return gr.update()

    # Gather all values first and join once, rather than re-joining the
    # growing string on every iteration.
    rm_types = [
        cls.update_input_model(
            model,
            allow_keys=allow_keys,
            has_record=has_record,
            arg_cls=arg_cls,
            is_reward_model=is_reward_model)['value'] for model in models
    ]
    return gr.update(value=' '.join(rm_types).strip())

swift/ui/llm_grpo/llm_grpo.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict, Type
44

55
import gradio as gr
6+
from packaging import version
67

78
from swift.llm.argument.base_args.base_args import get_supported_tuners
89
from swift.ui.base import BaseUI
@@ -14,14 +15,14 @@
1415
from swift.ui.llm_grpo.model import GRPOModel
1516
from swift.ui.llm_grpo.optimizer import GRPOOptimizer
1617
from swift.ui.llm_grpo.quantization import GRPOQuantization
17-
from swift.ui.llm_grpo.ref_model import RefModel
1818
from swift.ui.llm_grpo.report_to import GRPOReportTo
1919
from swift.ui.llm_grpo.reward import Reward
2020
from swift.ui.llm_grpo.rollout import Rollout
2121
from swift.ui.llm_grpo.runtime import GRPORuntime
2222
from swift.ui.llm_grpo.save import GRPOSave
2323
from swift.ui.llm_grpo.tuner import GRPOTuner
2424
from swift.ui.llm_train.llm_train import LLMTrain
25+
from swift.ui.llm_train.runtime import Runtime
2526
from swift.utils import get_device_count, get_logger
2627

2728
logger = get_logger()
@@ -32,7 +33,7 @@ class LLMGRPO(LLMTrain):
3233

3334
sub_ui = [
3435
GRPOModel, GRPODataset, Reward, GRPORuntime, Rollout, GRPOSave, GRPOTuner, GRPOOptimizer, GRPOHyper,
35-
GRPOQuantization, GRPOAdvanced, RefModel, GrpoAdvanced, GRPOReportTo, LLMRollout
36+
GRPOQuantization, GRPOAdvanced, GrpoAdvanced, GRPOReportTo, LLMRollout
3637
]
3738

3839
locale_dict: Dict[str, Dict] = {
@@ -146,16 +147,6 @@ class LLMGRPO(LLMTrain):
146147
'en': 'The data parallel size of DDP'
147148
}
148149
},
149-
'tuner_backend': {
150-
'label': {
151-
'zh': 'Tuner backend',
152-
'en': 'Tuner backend'
153-
},
154-
'info': {
155-
'zh': 'Tuner实现框架',
156-
'en': 'The tuner backend'
157-
}
158-
},
159150
'use_liger_kernel': {
160151
'label': {
161152
'zh': '使用Liger kernel',
@@ -239,11 +230,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
239230
with gr.Accordion(elem_id='train_param', open=True):
240231
with gr.Row():
241232
gr.Dropdown(elem_id='train_type', scale=4, choices=list(get_supported_tuners()))
242-
gr.Dropdown(elem_id='tuner_backend', scale=4)
243233
gr.Textbox(elem_id='seed', scale=4)
244234
gr.Dropdown(elem_id='torch_dtype', scale=4)
245-
with gr.Row():
246235
gr.Checkbox(elem_id='use_liger_kernel', scale=4)
236+
gr.Textbox(elem_id='sequence_parallel_size', lines=1, scale=4)
237+
with gr.Row():
238+
gr.Dropdown(
239+
elem_id='gpu_id',
240+
multiselect=True,
241+
choices=[str(i) for i in range(device_count)] + ['cpu'],
242+
value=default_device,
243+
scale=8)
247244
gr.Checkbox(elem_id='use_ddp', value=False, scale=4)
248245
gr.Textbox(elem_id='ddp_num', value='1', scale=4)
249246
gr.Dropdown(
@@ -252,25 +249,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
252249
allow_custom_value=True,
253250
value=None,
254251
choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
255-
gr.Textbox(elem_id='sequence_parallel_size', lines=1, scale=4)
256252
GRPOHyper.build_ui(base_tab)
257253
GRPORuntime.build_ui(base_tab)
258254
with gr.Row(equal_height=True):
259-
gr.Dropdown(
260-
elem_id='gpu_id',
261-
multiselect=True,
262-
choices=[str(i) for i in range(device_count)] + ['cpu'],
263-
value=default_device,
264-
scale=8)
265-
gr.Textbox(elem_id='envs', scale=8)
255+
gr.Textbox(elem_id='envs', scale=12)
266256
gr.Checkbox(elem_id='dry_run', value=False, scale=4)
267257
submit = gr.Button(elem_id='submit', scale=4, variant='primary')
268258

269259
Rollout.build_ui(base_tab)
270260
LLMRollout.set_lang(cls.lang)
271261
LLMRollout.build_ui(LLMRollout)
272262
GRPOTuner.build_ui(base_tab)
273-
RefModel.build_ui(base_tab)
274263
with gr.Accordion(elem_id='extra_params', open=True):
275264
with gr.Tabs():
276265
GrpoAdvanced.build_ui(base_tab)
@@ -286,13 +275,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
286275
inputs=[base_tab.element('train_type')],
287276
outputs=[cls.element('learning_rate')])
288277

289-
base_tab.element('gpu_id').change(
290-
cls.update_ddp_num,
291-
[base_tab.element('gpu_id'), base_tab.element('use_ddp')], base_tab.element('ddp_num'))
292-
base_tab.element('use_ddp').change(
293-
cls.update_ddp_num,
294-
[base_tab.element('gpu_id'), base_tab.element('use_ddp')], base_tab.element('ddp_num'))
295-
296278
submit.click(
297279
cls.train_local,
298280
list(cls.valid_elements().values()), [
@@ -312,15 +294,35 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
312294
cls.element('template')],
313295
[LLMRollout.element('rollout_runtime_tab'),
314296
LLMRollout.element('rollout_running_tasks')])
315-
base_tab.element('running_tasks').change(
316-
partial(GRPORuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
317-
list(base_tab.valid_elements().values()) + [cls.element('log')] + GRPORuntime.all_plots)
297+
318298
GRPORuntime.element('kill_task').click(
319299
GRPORuntime.kill_task,
320300
[GRPORuntime.element('running_tasks')],
321301
[GRPORuntime.element('running_tasks')] + [GRPORuntime.element('log')] + GRPORuntime.all_plots,
322302
).then(GRPORuntime.reset, [], [GRPORuntime.element('logging_dir')] + [GRPOHyper.element('output_dir')])
323303

304+
base_tab.element('gpu_id').change(
305+
cls.update_ddp_num,
306+
[base_tab.element('gpu_id'), base_tab.element('use_ddp')], base_tab.element('ddp_num'))
307+
base_tab.element('use_ddp').change(
308+
cls.update_ddp_num,
309+
[base_tab.element('gpu_id'), base_tab.element('use_ddp')], base_tab.element('ddp_num'))
310+
base_tab.element('ddp_num').change(Rollout.update_num_gen, [
311+
GRPOHyper.element('per_device_train_batch_size'),
312+
GRPOHyper.element('gradient_accumulation_steps'),
313+
cls.element('ddp_num')
314+
], [Rollout.element('num_generations')])
315+
GRPOHyper.element('gradient_accumulation_steps').change(Rollout.update_num_gen, [
316+
GRPOHyper.element('per_device_train_batch_size'),
317+
GRPOHyper.element('gradient_accumulation_steps'),
318+
cls.element('ddp_num')
319+
], [Rollout.element('num_generations')])
320+
GRPOHyper.element('per_device_train_batch_size').change(Rollout.update_num_gen, [
321+
GRPOHyper.element('per_device_train_batch_size'),
322+
GRPOHyper.element('gradient_accumulation_steps'),
323+
cls.element('ddp_num')
324+
], [Rollout.element('num_generations')])
325+
324326
@classmethod
325327
def prepare_sub_to_filter(cls):
326328
tabs_relation_dict = {

0 commit comments

Comments
 (0)