Cuda graph runner #10595

Draft · wants to merge 2 commits into base: develop
2 changes: 1 addition & 1 deletion csrc/gpu/cpp_extensions.cu
@@ -201,7 +201,7 @@ void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,

void UpdateInputesV2(const paddle::Tensor& stop_flags,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop, // cpu
paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
12 changes: 10 additions & 2 deletions csrc/gpu/update_inputs_v2.cu
@@ -110,7 +110,7 @@ __global__ void update_inputs_kernel_v2(

void UpdateInputesV2(const paddle::Tensor& stop_flags,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop, // cpu
paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -126,6 +126,7 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
const int input_ids_stride = input_ids.shape()[1];
const int end_length = end_ids.shape()[0];

std::cout << "before: not_need_stop data_prt:" << not_need_stop.data() << const_cast<bool*>(not_need_stop.data<bool>())[0];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);

update_inputs_kernel_v2<1024><<<1, 1024, 0, input_ids.stream()>>>(
@@ -147,10 +148,17 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
input_ids_stride,
end_length
);

// func 0
auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
std::cout << "after func1: not_need_stop data_prt:" << not_need_stop.data() << " data: " << const_cast<bool*>(not_need_stop.data<bool>())[0];

// func 1
not_need_stop.copy_(not_need_stop_gpu, not_need_stop.place(), false);
std::cout << "after func2: not_need_stop data_prt:" << not_need_stop.data() << " data: " << const_cast<bool*>(not_need_stop.data<bool>())[0];

// func 2
}

PD_BUILD_OP(update_inputs_v2)
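The signature change above makes not_need_stop a mutable reference, and the two labelled blocks ("func 0" / "func 1") are alternative ways of writing the GPU-computed flag back into that same host tensor instead of allocating a new one, so the host address stays valid when the op is replayed from a CUDA graph. Below is a minimal Python-side sketch of the same in-place pattern, assuming a pinned host flag and using copy_ the way the custom op does; the reduction is a stand-in for what the kernel computes.

import paddle

# Hedged sketch of the in-place update performed by update_inputs_v2: keep one
# pinned host flag whose address never changes and refresh it from the GPU.
not_need_stop = paddle.full([1], True, dtype="bool").pin_memory()  # stable host buffer
stop_flags = paddle.full([4, 1], False, dtype="bool")              # device-side per-sequence flags

# stand-in for the reduction the kernel performs over stop_flags
not_need_stop_gpu = paddle.logical_not(paddle.all(stop_flags)).reshape([1])

# "func 1" style: copy into the existing host tensor, preserving its pointer
not_need_stop.copy_(not_need_stop_gpu, False)
print(not_need_stop)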
1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -193,6 +193,7 @@ def get_gencode_flags():
sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"]
nvcc_compile_args += ["-gencode", "arch=compute_89,code=compute_89"]
elif cc >= 90:
os.environ.pop('PADDLE_CUDA_ARCH_LIST', None)
sources += [
"./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
"./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu",
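The one-line change above drops PADDLE_CUDA_ARCH_LIST from the environment before the SM90 kernels are added, so a stale arch list cannot override the flags this branch sets. A small sketch of that guard follows; the flag values are illustrative assumptions, not taken from the file.

import os

# Hedged sketch: clear a possibly conflicting arch override before appending
# explicitly chosen nvcc flags (flag values here are illustrative only).
def arch_flags_for(cc: int, nvcc_compile_args: list) -> list:
    if cc >= 90:
        os.environ.pop("PADDLE_CUDA_ARCH_LIST", None)  # drop any stale override
        nvcc_compile_args += ["-gencode", "arch=compute_90,code=compute_90"]
    return nvcc_compile_args

print(arch_flags_for(90, []))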
6 changes: 3 additions & 3 deletions llm/predict/predictor.py
@@ -191,7 +191,7 @@ class PredictorArgument:
metadata={"help": "Quantization type of moe. Supported values: weight_only_int4, weight_only_int8"},
)
output_via_mq: bool = field(
default=True,
default=False,
metadata={"help": "Controls whether the message queue is enabled for output"},
)
dynamic_insert: bool = field(default=False, metadata={"help": "whether use dynamic insert"})
@@ -929,7 +929,7 @@ def init_model_inputs(self, config: PredictorArgument):
)
self.model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64")
self.model_inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
self.model_inputs["msg_queue_id"] = paddle.full(shape=[1], fill_value=self.msg_queue_id, dtype="int32").cpu()
self.model_inputs["msg_queue_id"] = paddle.full(shape=[1], fill_value=self.msg_queue_id, dtype="int32").pin_memory()

# bloom model needs src_mask and tgt_mask!
if "bloom" in self.architectures:
@@ -1045,7 +1045,7 @@ def _preprocess(self, input_text: list[str] = None, input_ids: list[list[int]] =
shape=[self.config.batch_size, 1], fill_value=0, dtype="int32"
)
self.model_inputs["step_idx"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=0, dtype="int64")
self.model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool").cpu() # cpu
self.model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool").pin_memory() # pinned memory
self.model_inputs["stop_flags"] = paddle.full(
shape=[self.config.batch_size, 1], fill_value=False, dtype="bool"
)
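Both .cpu() → .pin_memory() hunks above move the control flags read by the captured decode step from pageable host memory to page-locked (pinned) memory, which keeps a fixed address and supports non-blocking host/device copies during graph replay. A small hedged sketch of the difference, assuming a CUDA build of Paddle:

import paddle

# Hedged sketch: pageable vs. pinned host allocation for the control flags.
not_need_stop_pageable = paddle.full([1], True, dtype="bool").cpu()        # ordinary host memory
not_need_stop_pinned = paddle.full([1], True, dtype="bool").pin_memory()   # page-locked host memory

# The pinned buffer keeps a stable, DMA-friendly address, so an in-place copy_
# from a GPU tensor can run asynchronously between graph replays.
print(not_need_stop_pageable.place, not_need_stop_pinned.place)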
228 changes: 216 additions & 12 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -15,9 +15,9 @@

import os
from typing import Dict, List, Optional, Union

import paddle
import paddle.nn.functional as F
from paddle import nn

from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList

@@ -29,6 +29,165 @@ def use_faster_top_p_sampling():
return os.getenv("USE_FASTER_TOP_P_SAMPLING", "False") in ["True", "1", "true"]


class PostProcessLayer(nn.Layer):
def __init__(self):
super().__init__()

def forward(self,
encoder_outputs : paddle.Tensor,
**kwargs
):
"""Explicitly passing tensor type input"""
step_idx = kwargs["step_idx"]
logits = paddle.cast(encoder_outputs, paddle.float32)

from paddlenlp_ops import set_preids_token_penalty_multi_scores

set_preids_token_penalty_multi_scores(
kwargs["pre_ids"],
kwargs["input_ids"],
kwargs["seq_lens_encoder"],
kwargs["seq_lens_decoder"],
step_idx,
kwargs["stop_flags"],
logits,
kwargs["penalty_score"],
kwargs["frequency_score"],
kwargs["presence_score"],
kwargs["temperature"],
kwargs["bad_tokens"],
step_idx,
kwargs["min_dec_len"],
kwargs["eos_token_id"],
)

# sample
probs = F.softmax(logits)

# compute next_tokens
if use_faster_top_p_sampling():
from paddlenlp_ops import top_p_sampling_reject

next_tokens = top_p_sampling_reject(probs, kwargs["top_p"], 0)
else:
_, next_tokens = paddle.tensor.top_p_sampling(probs, kwargs["top_p"])

if kwargs["tensor_parallel_degree"] > 1:
paddle.distributed.broadcast(next_tokens, 0)

with paddle.base.framework._stride_in_no_check_dy2st_diff():
from paddlenlp_ops import update_inputs_v2

update_inputs_v2(
kwargs["stop_flags"],
kwargs["step_idx"],
kwargs["not_need_stop"],
kwargs["seq_lens_this_time"],
kwargs["seq_lens_encoder"],
kwargs["seq_lens_decoder"],
kwargs["max_dec_len"],
kwargs["input_ids"],
kwargs["stop_nums"],
next_tokens,
kwargs["is_block_step"],
kwargs["eos_token_id"],
kwargs["next_tokens"],
)

# if kwargs["dynamic_insert"]:
# print("save_output_dygraph")
# from paddlenlp_ops import save_output_dygraph

# save_output_dygraph(
# kwargs["all_token_ids"], next_tokens, kwargs["result_id"], kwargs["step_idx"]
# )
# elif kwargs["output_via_mq"]:
# print("output_via_mq")
# from paddlenlp_ops import save_output
# save_output(
# next_tokens,
# kwargs["not_need_stop"],
# kwargs["msg_queue_id"],
# kwargs["tensor_parallel_rank"]
# )
return next_tokens

class CudaGraphRunner(nn.Layer):
def __init__(self, model: nn.Layer):
super().__init__()
self.model = model

self.input_buffers: Dict[str, paddle.Tensor] = {}
self.output_buffers: Dict[str, paddle.Tensor] = {}

self._graph:Optional[paddle.device.cuda.graphs.CUDAGraph] = None

self._NUM_WARMUP_ITERS = 2

def graph(self):
assert self._graph is not None
return self._graph

def graph_is_none(self):
return self._graph is None

def prepare_input_buffer(self, **kwargs) -> None:
""" """
for (name, value) in kwargs.items():
if type(value) == paddle.Tensor:
if (name not in self.input_buffers.keys()):
self.input_buffers[name] = paddle.zeros_like(value)
self.input_buffers[name].copy_(value, False)
else:
self.input_buffers[name] = value
# print(f"parameter name: {name}, value: {value}, buffer value: {self.input_buffers[name]}")
self.output_buffers["origional_stop_tensor"] = kwargs["not_need_stop"]

def capture(self, graph_inputs, **kwargs) -> None:
assert self._graph is None
# prepare input buffer
self.input_buffers["graph_inputs"] = paddle.zeros_like(graph_inputs)
self.input_buffers["graph_inputs"].copy_(graph_inputs, False)
self.prepare_input_buffer(**kwargs)

paddle.device.synchronize()  # all ranks (8 GPUs) must synchronize before capture
for _ in range(self._NUM_WARMUP_ITERS):
self.model(encoder_outputs = self.input_buffers["graph_inputs"], **kwargs)

self._graph = paddle.device.cuda.graphs.CUDAGraph()
self._graph.capture_begin()
model_out_put = self.model(encoder_outputs = self.input_buffers["graph_inputs"], **kwargs)
self._graph.capture_end()

intermediate_output = paddle.zeros_like(model_out_put)
model_out_put._share_buffer_to(intermediate_output)
model_out_put._clear()

paddle.device.synchronize()
# process output buffer
self.output_buffers["output_ids"] = intermediate_output

self._graph.print_to_dot_files("/root/paddlejob/workspace/env_run/output/gongshaotian/Log/cudaGraph/test_r1", 1 << 0)

def forward(self, graph_inputs, **kwargs) -> paddle.Tensor:
print("---------------------- start cuda graph runner replay ----------------------")
try:
# copy input_tensors to input_buffers
self.input_buffers["graph_inputs"].copy_(graph_inputs, False)
self.prepare_input_buffer(**kwargs)

self._graph.replay()

# process output buffer
# self.output_buffers["origional_stop_tensor"].copy_(self.input_buffers["not_need_stop"], False)
print("kwargs['not_need_stop']", kwargs['not_need_stop'])
print("input_buffers['not_need_stop]", self.input_buffers["not_need_stop"])
return self.output_buffers["output_ids"]
except :
pass
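For reference, here is a minimal, self-contained sketch of the capture/replay contract that CudaGraphRunner wraps. The buffer names and the toy decode_step are illustrative, not from the PR, and it assumes a CUDA build of Paddle with the default GPU place.

import paddle
from paddle.device.cuda.graphs import CUDAGraph

def decode_step(x):
    # toy stand-in for the captured post-processing work
    return paddle.nn.functional.softmax(x.astype("float32"))

static_in = paddle.zeros([8, 32000])        # address baked into the graph
for _ in range(2):                          # warm-up before capture
    static_out = decode_step(static_in)

graph = CUDAGraph()
graph.capture_begin()
static_out = decode_step(static_in)         # kernels recorded against fixed buffers
graph.capture_end()

static_in.copy_(paddle.randn([8, 32000]), False)  # refresh inputs in place
graph.replay()                                     # re-run the recorded kernels
print(static_out[0, :3])                           # output lands in the captured buffer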



class ForcedDecodingEOSTokenLogitsProcessor(LogitsProcessor):
"""
This `LogitsProcessor` enforces the last generated token to be the selected `forced_eos_token`.
@@ -407,6 +566,11 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):


class GenerationBlockInferenceModel(GenerationMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_post_process_layer = PostProcessLayer()
self.cuda_graph_runner = CudaGraphRunner(_post_process_layer)

@classmethod
def get_cache_kvs_shape(cls, max_batch_size: int = None, max_length: int = None) -> list[list[int]]:
raise NotImplementedError
@@ -705,6 +869,13 @@ def _post_process_(
temperature,
model_kwargs,
):
# print type
print(type(outputs), type(top_k), type(top_p), type(penalty_score), type(frequency_score), type(presence_score), type(temperature))
for (key, value) in model_kwargs.items():
print(key, type(value), value.place)
print(type(self.config.tensor_parallel_degree), type(self.config.tensor_parallel_rank))

# origin code
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
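The prints added above dump the type and place of every tensor reaching _post_process_; the point of that inspection is that everything fed to a graph-captured step has to live on the GPU or in pinned host memory. A hedged helper sketch of the same check follows; the place-string matching is an assumption about how Paddle prints places, and the input names are illustrative.

import paddle

def check_graph_safe(name, value):
    # every tensor handed to the captured step should be GPU-resident or pinned
    if isinstance(value, paddle.Tensor):
        place = str(value.place).lower()
        assert "gpu" in place or "pinned" in place, f"{name} lives on {place}"

inputs = {
    "step_idx": paddle.zeros([4, 1], dtype="int64"),                     # GPU in a CUDA build
    "not_need_stop": paddle.full([1], True, dtype="bool").pin_memory(),  # pinned host flag
}
for name, value in inputs.items():
    check_graph_safe(name, value)
print("all inputs are graph-safe")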

@@ -780,18 +951,51 @@ def _post_process_(

# encoder
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
# first decoder
next_tokens = _post_process_(
outputs,
top_k,
top_p,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs,
)

# check shape
print(f"for_ward_ outputs :{outputs.shape, outputs.place}")

# first decoder
if self.cuda_graph_runner.graph_is_none():
self.cuda_graph_runner.capture(
outputs,
top_k=top_k,
top_p=top_p,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
temperature=temperature,
tensor_parallel_degree=self.config.tensor_parallel_degree,
tensor_parallel_rank=self.config.tensor_parallel_rank,
dynamic_insert=self.config.dynamic_insert,
output_via_mq=self.config.output_via_mq,
eos_token_id=eos_token_id,
**model_kwargs)
next_tokens = self.cuda_graph_runner(
outputs,
top_k=top_k,
top_p=top_p,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
temperature=temperature,
tensor_parallel_degree=self.config.tensor_parallel_degree,
tensor_parallel_rank=self.config.tensor_parallel_rank,
dynamic_insert=self.config.dynamic_insert,
output_via_mq=self.config.output_via_mq,
eos_token_id=eos_token_id,
**model_kwargs)
print(f"next token:{next_tokens}")
# next_tokens = _post_process_(
# outputs,
# top_k,
# top_p,
# penalty_score,
# frequency_score,
# presence_score,
# temperature,
# model_kwargs,
# )
return next_tokens

def speculate_decoding(
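One more self-contained sketch, showing why CudaGraphRunner.forward copies fresh tensors into input_buffers before calling replay(): the recorded kernels only ever see the buffers that existed at capture time, so new data must be copied into those buffers in place. Shapes and values here are toy examples, not from the PR.

import paddle
from paddle.device.cuda.graphs import CUDAGraph

buf = paddle.zeros([4])            # the only tensor the graph will ever read
graph = CUDAGraph()
graph.capture_begin()
out = buf * 2.0                    # recorded against buf's address
graph.capture_end()

fresh = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
graph.replay()
print(out)                         # still zeros: fresh was never copied into buf

buf.copy_(fresh, False)            # refresh the captured buffer in place
graph.replay()
print(out)                         # now 2 * fresh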