Skip to content

Commit f891985

Browse files
authored
update stream & fix bugs (#4842)
1 parent 77ad540 commit f891985

File tree

11 files changed

+41
-6
lines changed

11 files changed

+41
-6
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@
110110
- top_p: top_p参数,默认为None。读取generation_config.json。
111111
- repetition_penalty: 重复惩罚项。默认为None,读取generation_config.json。
112112
- num_beams: beam search的并行保留数量,默认为1。
113-
- 🔥stream: 流式输出,默认为`False`
113+
- 🔥stream: 流式输出,默认为`None`,即使用交互式界面时为True,数据集批量推理时为False。
114+
- "ms-swift<3.6"的stream默认值为False。
114115
- stop_words: 除了eos_token外额外的停止词,默认为`[]`
115116
- 注意:eos_token会在输出response中被删除,额外停止词会在输出中保留。
116117
- logprobs: 是否输出logprobs,默认为False。

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ Refer to the [generation_config](https://huggingface.co/docs/transformers/main_c
112112
- top_p: The top_p parameter, defaults to None. It is read from generation_config.json.
113113
- repetition_penalty: The repetition penalty. Defaults to None and is read from generation_config.json.
114114
- num_beams: The number of beams reserved for parallel beam search, default is 1.
115-
- 🔥stream: Stream output, default is `False`.
115+
- 🔥stream: Streaming output. Default is `None`, which means it is set to True when using the interactive interface and False during batch inference on datasets.
116+
- For "ms-swift<3.6", the default value of stream is False.
116117
- stop_words: Additional stop words beyond eos_token, default is `[]`.
117118
- Note: eos_token will be removed in the output response, whereas additional stop words will be retained in the output.
118119
- logprobs: Whether to output logprobs, default is False.

swift/llm/argument/app_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class AppArguments(WebUIArguments, DeployArguments):
1919

2020
lang: Literal['en', 'zh'] = 'en'
2121
verbose: bool = False
22+
stream: bool = True
2223

2324
def _init_torch_dtype(self) -> None:
2425
if self.base_url:

swift/llm/argument/base_args/base_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def __post_init__(self):
153153
self._init_custom_register()
154154
self._import_external_plugins()
155155
self._init_model_kwargs()
156+
self._init_stream()
156157
# The Seq2SeqTrainingArguments has a property called world_size, which cannot be assigned a value.
157158
self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
158159
logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '

swift/llm/argument/base_args/generation_args.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ class GenerationArguments:
3232
repetition_penalty: Optional[float] = None
3333
num_beams: int = 1
3434

35-
stream: bool = False
35+
stream: Optional[bool] = None
3636
stop_words: List[str] = field(default_factory=list)
3737
logprobs: bool = False
3838
top_logprobs: Optional[int] = None
3939

40+
def _init_stream(self):
41+
if self.stream is None:
42+
self.stream = False
43+
4044
def get_request_config(self):
4145
if getattr(self, 'task_type') != 'causal_lm':
4246
return

swift/llm/argument/deploy_args.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from swift.llm import safe_snapshot_download
66
from swift.utils import find_free_port, get_logger
7+
from .base_args import BaseArguments
78
from .infer_args import InferArguments
89

910
logger = get_logger()
@@ -66,7 +67,7 @@ def _init_ckpt_dir(self, adapters=None):
6667
return super()._init_ckpt_dir(self.adapters + list(self.adapter_mapping.values()))
6768

6869
def _init_stream(self):
69-
pass
70+
return BaseArguments._init_stream(self)
7071

7172
def _init_eval_human(self):
7273
pass

swift/llm/argument/infer_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def _init_result_path(self, folder_name: str) -> None:
180180

181181
def _init_stream(self):
182182
self.eval_human = not (self.dataset and self.split_dataset_ratio > 0 or self.val_dataset)
183-
183+
if self.stream is None:
184+
self.stream = self.eval_human
184185
if self.stream and self.num_beams != 1:
185186
self.stream = False
186187
logger.info('Setting args.stream: False')
@@ -199,7 +200,6 @@ def __post_init__(self) -> None:
199200
VllmArguments.__post_init__(self)
200201
self._init_result_path('infer_result')
201202
self._init_eval_human()
202-
self._init_stream()
203203
self._init_ddp()
204204

205205
def _init_eval_human(self):

swift/llm/infer/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass, field
55
from typing import List, Literal, Optional
66

7+
from swift.llm.utils import update_generation_config_eos_token
78
from swift.plugin import extra_tuners
89
from swift.tuners import Swift
910
from swift.utils import get_logger
@@ -144,4 +145,5 @@ def prepare_model_template(args, **kwargs):
144145
model, processor = args.get_model_processor(**kwargs)
145146
model = prepare_adapter(args, model)
146147
template = args.get_template(processor)
148+
update_generation_config_eos_token(model.generation_config, template)
147149
return model, template

swift/llm/template/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ def encode(self,
501501
lengths += value
502502
if self.is_training:
503503
encoded['length'] = max(lengths)
504+
else:
505+
encoded.pop('length', None)
504506
if return_template_inputs:
505507
encoded['template_inputs'] = inputs
506508
if not self.remove_unused_columns:

swift/llm/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,21 @@ def get_ckpt_dir(model_dir: str, adapters_dir: Optional[List[str]]) -> str:
300300
ckpt_dir = model_dir
301301
break
302302
return ckpt_dir
303+
304+
305+
def update_generation_config_eos_token(generation_config, template):
306+
stop_words = template.template_meta.stop_words
307+
eos_token_id = generation_config.eos_token_id
308+
if isinstance(eos_token_id, int):
309+
eos_token_id = [eos_token_id]
310+
modified = False
311+
for stop_word in stop_words:
312+
if stop_word is None:
313+
continue
314+
if isinstance(stop_word, str):
315+
stop_word = template._tokenize(stop_word)
316+
if isinstance(stop_word, (list, tuple)) and len(stop_word) == 1 and stop_word[0] not in eos_token_id:
317+
eos_token_id.append(stop_word[0])
318+
modified = True
319+
if modified:
320+
generation_config.eos_token_id = eos_token_id

0 commit comments

Comments
 (0)