3
3
from typing import Dict , Type
4
4
5
5
import gradio as gr
6
+ from packaging import version
6
7
7
8
from swift .llm .argument .base_args .base_args import get_supported_tuners
8
9
from swift .ui .base import BaseUI
14
15
from swift .ui .llm_grpo .model import GRPOModel
15
16
from swift .ui .llm_grpo .optimizer import GRPOOptimizer
16
17
from swift .ui .llm_grpo .quantization import GRPOQuantization
17
- from swift .ui .llm_grpo .ref_model import RefModel
18
18
from swift .ui .llm_grpo .report_to import GRPOReportTo
19
19
from swift .ui .llm_grpo .reward import Reward
20
20
from swift .ui .llm_grpo .rollout import Rollout
21
21
from swift .ui .llm_grpo .runtime import GRPORuntime
22
22
from swift .ui .llm_grpo .save import GRPOSave
23
23
from swift .ui .llm_grpo .tuner import GRPOTuner
24
24
from swift .ui .llm_train .llm_train import LLMTrain
25
+ from swift .ui .llm_train .runtime import Runtime
25
26
from swift .utils import get_device_count , get_logger
26
27
27
28
logger = get_logger ()
@@ -32,7 +33,7 @@ class LLMGRPO(LLMTrain):
32
33
33
34
sub_ui = [
34
35
GRPOModel , GRPODataset , Reward , GRPORuntime , Rollout , GRPOSave , GRPOTuner , GRPOOptimizer , GRPOHyper ,
35
- GRPOQuantization , GRPOAdvanced , RefModel , GrpoAdvanced , GRPOReportTo , LLMRollout
36
+ GRPOQuantization , GRPOAdvanced , GrpoAdvanced , GRPOReportTo , LLMRollout
36
37
]
37
38
38
39
locale_dict : Dict [str , Dict ] = {
@@ -146,16 +147,6 @@ class LLMGRPO(LLMTrain):
146
147
'en' : 'The data parallel size of DDP'
147
148
}
148
149
},
149
- 'tuner_backend' : {
150
- 'label' : {
151
- 'zh' : 'Tuner backend' ,
152
- 'en' : 'Tuner backend'
153
- },
154
- 'info' : {
155
- 'zh' : 'Tuner实现框架' ,
156
- 'en' : 'The tuner backend'
157
- }
158
- },
159
150
'use_liger_kernel' : {
160
151
'label' : {
161
152
'zh' : '使用Liger kernel' ,
@@ -239,11 +230,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
239
230
with gr .Accordion (elem_id = 'train_param' , open = True ):
240
231
with gr .Row ():
241
232
gr .Dropdown (elem_id = 'train_type' , scale = 4 , choices = list (get_supported_tuners ()))
242
- gr .Dropdown (elem_id = 'tuner_backend' , scale = 4 )
243
233
gr .Textbox (elem_id = 'seed' , scale = 4 )
244
234
gr .Dropdown (elem_id = 'torch_dtype' , scale = 4 )
245
- with gr .Row ():
246
235
gr .Checkbox (elem_id = 'use_liger_kernel' , scale = 4 )
236
+ gr .Textbox (elem_id = 'sequence_parallel_size' , lines = 1 , scale = 4 )
237
+ with gr .Row ():
238
+ gr .Dropdown (
239
+ elem_id = 'gpu_id' ,
240
+ multiselect = True ,
241
+ choices = [str (i ) for i in range (device_count )] + ['cpu' ],
242
+ value = default_device ,
243
+ scale = 8 )
247
244
gr .Checkbox (elem_id = 'use_ddp' , value = False , scale = 4 )
248
245
gr .Textbox (elem_id = 'ddp_num' , value = '1' , scale = 4 )
249
246
gr .Dropdown (
@@ -252,25 +249,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
252
249
allow_custom_value = True ,
253
250
value = None ,
254
251
choices = ['zero0' , 'zero1' , 'zero2' , 'zero3' , 'zero2_offload' , 'zero3_offload' ])
255
- gr .Textbox (elem_id = 'sequence_parallel_size' , lines = 1 , scale = 4 )
256
252
GRPOHyper .build_ui (base_tab )
257
253
GRPORuntime .build_ui (base_tab )
258
254
with gr .Row (equal_height = True ):
259
- gr .Dropdown (
260
- elem_id = 'gpu_id' ,
261
- multiselect = True ,
262
- choices = [str (i ) for i in range (device_count )] + ['cpu' ],
263
- value = default_device ,
264
- scale = 8 )
265
- gr .Textbox (elem_id = 'envs' , scale = 8 )
255
+ gr .Textbox (elem_id = 'envs' , scale = 12 )
266
256
gr .Checkbox (elem_id = 'dry_run' , value = False , scale = 4 )
267
257
submit = gr .Button (elem_id = 'submit' , scale = 4 , variant = 'primary' )
268
258
269
259
Rollout .build_ui (base_tab )
270
260
LLMRollout .set_lang (cls .lang )
271
261
LLMRollout .build_ui (LLMRollout )
272
262
GRPOTuner .build_ui (base_tab )
273
- RefModel .build_ui (base_tab )
274
263
with gr .Accordion (elem_id = 'extra_params' , open = True ):
275
264
with gr .Tabs ():
276
265
GrpoAdvanced .build_ui (base_tab )
@@ -286,13 +275,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
286
275
inputs = [base_tab .element ('train_type' )],
287
276
outputs = [cls .element ('learning_rate' )])
288
277
289
- base_tab .element ('gpu_id' ).change (
290
- cls .update_ddp_num ,
291
- [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
292
- base_tab .element ('use_ddp' ).change (
293
- cls .update_ddp_num ,
294
- [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
295
-
296
278
submit .click (
297
279
cls .train_local ,
298
280
list (cls .valid_elements ().values ()), [
@@ -312,15 +294,35 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
312
294
cls .element ('template' )],
313
295
[LLMRollout .element ('rollout_runtime_tab' ),
314
296
LLMRollout .element ('rollout_running_tasks' )])
315
- base_tab .element ('running_tasks' ).change (
316
- partial (GRPORuntime .task_changed , base_tab = base_tab ), [base_tab .element ('running_tasks' )],
317
- list (base_tab .valid_elements ().values ()) + [cls .element ('log' )] + GRPORuntime .all_plots )
297
+
318
298
GRPORuntime .element ('kill_task' ).click (
319
299
GRPORuntime .kill_task ,
320
300
[GRPORuntime .element ('running_tasks' )],
321
301
[GRPORuntime .element ('running_tasks' )] + [GRPORuntime .element ('log' )] + GRPORuntime .all_plots ,
322
302
).then (GRPORuntime .reset , [], [GRPORuntime .element ('logging_dir' )] + [GRPOHyper .element ('output_dir' )])
323
303
304
+ base_tab .element ('gpu_id' ).change (
305
+ cls .update_ddp_num ,
306
+ [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
307
+ base_tab .element ('use_ddp' ).change (
308
+ cls .update_ddp_num ,
309
+ [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
310
+ base_tab .element ('ddp_num' ).change (Rollout .update_num_gen , [
311
+ GRPOHyper .element ('per_device_train_batch_size' ),
312
+ GRPOHyper .element ('gradient_accumulation_steps' ),
313
+ cls .element ('ddp_num' )
314
+ ], [Rollout .element ('num_generations' )])
315
+ GRPOHyper .element ('gradient_accumulation_steps' ).change (Rollout .update_num_gen , [
316
+ GRPOHyper .element ('per_device_train_batch_size' ),
317
+ GRPOHyper .element ('gradient_accumulation_steps' ),
318
+ cls .element ('ddp_num' )
319
+ ], [Rollout .element ('num_generations' )])
320
+ GRPOHyper .element ('per_device_train_batch_size' ).change (Rollout .update_num_gen , [
321
+ GRPOHyper .element ('per_device_train_batch_size' ),
322
+ GRPOHyper .element ('gradient_accumulation_steps' ),
323
+ cls .element ('ddp_num' )
324
+ ], [Rollout .element ('num_generations' )])
325
+
324
326
@classmethod
325
327
def prepare_sub_to_filter (cls ):
326
328
tabs_relation_dict = {
0 commit comments