
[Feat] Add load_time_quantization model loader #2711

Status: Open. Wants to merge 1 commit into base: develop.
15 changes: 15 additions & 0 deletions docs/get_started/ernie-4.5.md
@@ -30,6 +30,21 @@ python -m fastdeploy.entrypoints.openai.api_server \
--max-num-seqs 32
```

To speed up model loading, set the environment variable `FD_USE_FASTSAFETENSOR=1` and pass the `--load_format "load_time_quantization"` option:

```shell
export FD_USE_FASTSAFETENSOR=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
--cache-queue-port 8182 --metrics-port 8182 \
--tensor-parallel-size 8 \
--quantization wint4 \
--max-model-len 32768 \
--max-num-seqs 32 \
--load_format "load_time_quantization"
```

## Request the Service
After starting the service, the following output indicates successful initialization:

23 changes: 21 additions & 2 deletions docs/quantization/online_quantization.md
@@ -24,7 +24,7 @@ python -m fastdeploy.entrypoints.openai.api_server \

- By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md).
- By setting `--quantization` to `wint8` or `wint4`, online INT8/INT4 quantization can be selected.
- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G * 8 cards, while WINT4 requires 80GB * 4 cards.
- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G *8 cards, while WINT4 requires 80GB* 4 cards.
- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md).

## 2. Block-wise FP8
@@ -51,4 +51,23 @@ python -m fastdeploy.entrypoints.openai.api_server \
- By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md).
- By setting `--quantization` to `block_wise_fp8`, online Block-wise FP8 quantization can be selected.
- Deploying ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 requires at least 80G * 8 cards.
- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md)
- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md)

# LoadTimeQuantization
To speed up loading with FastSafeTensor and load large bfloat16 models onto the GPU, we shift quantization into the weight-loading stage and perform it dynamically. Supported quantization formats include INT4, INT8, and FP8.
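Conceptually, each tensor is quantized as it is read from the checkpoint, so the full-precision model never has to be materialized on the GPU at once. The sketch below is illustrative only, with hypothetical helper names and NumPy standing in for the real tensor library; it is not FastDeploy's actual loader.

```python
# Minimal sketch of load-time INT8 quantization (hypothetical names, not FastDeploy code).
import numpy as np


def quantize_int8(weight: np.ndarray):
    """Symmetric per-tensor INT8 quantization: int8 weights plus one fp32 scale."""
    scale = max(float(np.max(np.abs(weight))), 1e-8) / 127.0
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, np.float32(scale)


def load_with_quantization(checkpoint):
    """Quantize tensor by tensor while loading, instead of first materializing
    the whole bfloat16 model and quantizing it afterwards."""
    quantized = {}
    for name, weight in checkpoint.items():  # stream one tensor at a time
        q, scale = quantize_int8(np.asarray(weight, dtype=np.float32))
        quantized[name] = {"weight": q, "scale": scale}
    return quantized


if __name__ == "__main__":
    fake_ckpt = {"linear.weight": np.random.randn(8, 8).astype(np.float32)}
    print(load_with_quantization(fake_ckpt)["linear.weight"]["scale"])
```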

## 1. Run the load-time quantization model loader
To speed up model loading, set the environment variable `FD_USE_FASTSAFETENSOR=1` and pass the `--load_format "load_time_quantization"` option:

```shell
export FD_USE_FASTSAFETENSOR=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
--cache-queue-port 8182 --metrics-port 8182 \
--tensor-parallel-size 8 \
--quantization wint8 \
--max-model-len 32768 \
--max-num-seqs 32 \
--load_format "load_time_quantization"
```
15 changes: 15 additions & 0 deletions docs/zh/get_started/ernie-4.5.md
@@ -31,6 +31,21 @@ python -m fastdeploy.entrypoints.openai.api_server \
--max-num-seqs 32
```

To speed up weight loading, set the environment variable `FD_USE_FASTSAFETENSOR=1` and add the `--load_format "load_time_quantization"` option:

```shell
export FD_USE_FASTSAFETENSOR=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
--cache-queue-port 8182 --metrics-port 8182 \
--tensor-parallel-size 8 \
--quantization wint4 \
--max-model-len 32768 \
--max-num-seqs 32 \
--load_format "load_time_quantization"
```

## Request the Service
After running the startup command, the following terminal output indicates that the service has started successfully.

23 changes: 20 additions & 3 deletions docs/zh/quantization/online_quantization.md
@@ -23,8 +23,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
```

- By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be downloaded automatically from AIStudio. FastDeploy depends on Paddle-format models; for more information, see the [Supported Model List](../supported_models.md).
- By setting `--quantization` to `wint8` or `wint4`, online INT8/INT4 quantization can be selected.
- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G * 8 cards, while WINT4 requires 80GB * 4 cards.
- By setting `--quantization` to `wint8` or `wint4`, online INT8/INT4 quantization can be selected.
- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G *8 cards, while WINT4 requires 80GB* 4 cards.
- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md).

## 2. Block-wise FP8
@@ -49,9 +49,26 @@ python -m fastdeploy.entrypoints.openai.api_server \
```

- By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be downloaded automatically from AIStudio. FastDeploy depends on Paddle-format models; for more information, see the [Supported Model List](../supported_models.md).
- By setting `--quantization` to `block_wise_fp8`, online Block-wise FP8 quantization can be selected.
- By setting `--quantization` to `block_wise_fp8`, online Block-wise FP8 quantization can be selected.
- Deploying ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 requires at least 80G * 8 cards.
- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md).

# LoadTimeQuantization
To improve weight-loading performance with FastSafeTensor and load the 300B model onto the GPU, we provide a new model loader (load_time_quantization) that quantizes weights dynamically while they are being loaded.
This model loader supports INT4, INT8, and FP8 dynamic quantization.

## 1. Use the load_time_quantization model loader
Enable FastSafeTensor by setting the environment variable `FD_USE_FASTSAFETENSOR=1`, and enable quantization during weight loading by passing `--load_format "load_time_quantization"`:

```shell
export FD_USE_FASTSAFETENSOR=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
--cache-queue-port 8182 --metrics-port 8182 \
--tensor-parallel-size 8 \
--quantization wint8 \
--max-model-len 32768 \
--max-num-seqs 32 \
--load_format "load_time_quantization"
```
10 changes: 9 additions & 1 deletion fastdeploy/config.py
@@ -18,7 +18,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional, Union

from paddleformers.transformers.configuration_utils import PretrainedConfig

@@ -55,6 +55,7 @@ class ModelConfig(PretrainedConfig):
frequency_score = 0.0
presence_score = 0.0
min_length = 1
weight_infos_dict: Dict[str, Any] = {}

def __init__(
self,
@@ -343,6 +344,12 @@ def __init__(self,
self.graph_opt_level = 1


class LoadFormat(str, Enum):
"""LoadFormat"""
DEFAULT = "default"
LoadTimeQuant = "load_time_quantization"


@dataclass
class LoadConfig:
"""
@@ -357,6 +364,7 @@ class LoadConfig:
- 'meta': provide RL training worker, no_weights_load
- None: No dynamic loading
"""
load_format: Union[str, LoadFormat] = LoadFormat.DEFAULT.value
use_fastsafetensor: bool = False
dynamic_load_weight: bool = False
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
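To make the new fields concrete, here is a rough, self-contained usage sketch of the LoadFormat/LoadConfig pair added above; the classes are re-declared in simplified form for illustration, and the real ones in fastdeploy/config.py carry additional fields.

```python
# Simplified re-declaration of the enum and dataclass from this diff, showing how a
# CLI value such as "load_time_quantization" selects the loader (illustrative only).
from dataclasses import dataclass
from enum import Enum
from typing import Union


class LoadFormat(str, Enum):
    DEFAULT = "default"
    LoadTimeQuant = "load_time_quantization"


@dataclass
class LoadConfig:
    load_format: Union[str, LoadFormat] = LoadFormat.DEFAULT.value
    use_fastsafetensor: bool = False


def select_loader(cfg: LoadConfig) -> str:
    # Because LoadFormat subclasses str, this comparison works whether load_format
    # arrives as a plain string (e.g. from argparse) or as a LoadFormat member.
    if cfg.load_format == LoadFormat.LoadTimeQuant:
        return "load-time quantization loader"
    return "default loader"


print(select_loader(LoadConfig(load_format="load_time_quantization")))  # load-time quantization loader
```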
16 changes: 16 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -293,6 +293,14 @@ class EngineArgs:
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
"""

load_format: str = "default"
"""The format of the model weights to load.
Options include:
- "default": default loader.
-"load_time_quantization": Quantization applied during model loading, \
such as INT8, INT4, or FP8 formats.
"""

def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -413,6 +421,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=
"Disabled any whitespaces when using guided decoding backend XGrammar."
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
load_group.add_argument("--load_format",
type=str,
default=EngineArgs.load_format,
help="The format of the model weights to load.\
default/load_time_quantization.")

# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -784,4 +799,5 @@ def create_engine_config(self) -> Config:
max_capture_batch_size=self.max_capture_batch_size,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
load_format=self.load_format,
)
6 changes: 5 additions & 1 deletion fastdeploy/engine/config.py
@@ -494,6 +494,7 @@ class Config:
splitwise_role (str): Splitwise role.
innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
Temporary configuration, will be removed in the future.
load_format (str): The format of the model weights to load. Default is "default".
"""

def __init__(
@@ -526,6 +527,7 @@ def __init__(
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
load_format: str = "default",
):
"""
Initialize the Config class.
@@ -554,6 +556,7 @@ def __init__(
guided_decoding_backend(str): Guided decoding backend. Default is None.
disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
Default is False.
load_format (str): The format of the model weights to load. Default is "default".
"""
self.model_config = model_config
self.cache_config = cache_config
@@ -585,7 +588,8 @@ def __init__(
self.is_master = True
self._str_to_list("innode_prefill_ports", int)
self._str_to_list("pod_ips", str)

self.load_format = load_format

if self.pod_ips is None:
self.nnode = 1
else:
7 changes: 4 additions & 3 deletions fastdeploy/engine/engine.py
@@ -998,8 +998,8 @@ def _start_worker_service(self):
py_script = os.path.join(current_dir_path, worker_path)

ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
else len(self.data_processor.tokenizer.vocab)
)

@@ -1032,7 +1032,8 @@ def _start_worker_service(self):
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --load_format {self.cfg.load_format}")

worker_append_flag = {
"enable_expert_parallel":
14 changes: 11 additions & 3 deletions fastdeploy/model_executor/layers/activation.py
@@ -78,9 +78,17 @@ def __init__(
self.shift = shift
self.smooth = smooth
self.quant_scale = quant_scale
self.quant_round_type = fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound = fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound = fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
if fd_config.quant_config:
self.quant_round_type = fd_config.quant_config.get_quant_method(
self).quant_round_type
self.quant_max_bound = fd_config.quant_config.get_quant_method(
self).quant_max_bound
self.quant_min_bound = fd_config.quant_config.get_quant_method(
self).quant_min_bound
else:
self.quant_round_type = 0
self.quant_max_bound = 0
self.quant_min_bound = 0

self._dtype = self._helper.get_default_dtype()
if self._dtype == "bfloat16":
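As a side note, the activation.py change above switches from reading quantization bounds off the global quant config to asking the quant method resolved for the layer. Below is a small sketch of that delegation pattern, using hypothetical stand-in classes rather than FastDeploy's real API.

```python
# Illustrative stand-ins only; the real objects come from fd_config.quant_config.
class FakeQuantMethod:
    quant_round_type = 0      # assumed convention: 0 = round-to-nearest
    quant_max_bound = 127.0   # clamp range for INT8-style quantization
    quant_min_bound = -127.0


class FakeQuantConfig:
    def get_quant_method(self, layer):
        # A real config could return a different method per layer type.
        return FakeQuantMethod()


class FakeActivationLayer:
    def __init__(self, quant_config=None):
        if quant_config:
            method = quant_config.get_quant_method(self)
            self.quant_round_type = method.quant_round_type
            self.quant_max_bound = method.quant_max_bound
            self.quant_min_bound = method.quant_min_bound
        else:
            # Unquantized path: neutral defaults, mirroring the diff above.
            self.quant_round_type = 0
            self.quant_max_bound = 0
            self.quant_min_bound = 0


print(FakeActivationLayer(FakeQuantConfig()).quant_max_bound)  # 127.0
```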
@@ -53,13 +53,24 @@ def create_weights(self, layer: nn.Layer) -> None:
is_bias=False,
)

def process_loaded_weights(self, layer: nn.Layer,
weight: paddle.Tensor) -> None:
def process_quantized_weights(self, layer, state_dict) -> None:
"""process_quantized_weights"""
# (tangbinhan:todo) quant_utils support xpu
layer.linear_weight.set_value(state_dict.pop(layer.weight_key))
layer.linear_weight_scale.set_value(state_dict.pop(layer.weight_scale))

def apply_weight_quantization(self, weight):
"""apply_weight_quantization"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
return quanted_weight_tensor, weight_scale_tensor

def process_unquantized_weights(self, layer, weight) -> None:
"""
Quantize the loaded weights using XPU-specific quantization.
"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
quanted_weight_tensor, weight_scale_tensor = self.apply_weight_quantization(
weight)
layer.linear_weight.set_value(
paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.linear_weight_scale.set_value(weight_scale_tensor)