support_dynamic_server for wintx #10589

Open
wants to merge 10 commits into develop
Fix bug in triton code
lixcli committed May 13, 2025
commit 89404c7276d3922d89148e796c7f1b0f3ab764a0
3 changes: 1 addition & 2 deletions paddlenlp/experimental/wintx/wintx_fused_moe_decode.py
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.framework import in_dynamic_or_pir_mode

from paddlenlp.ops.triton_ops.triton_utils import (

get_dtype_str,
paddle_use_triton,
rendering_common_template,
)

__all__ = ["fused_moe_wintx_decode_wint2_75", "fused_moe_wintx_decode_wint2_5"]
BLOCK_SIZE_M = 16


def invoke_fused_moe_kernel(

A,
B,
C,
@@ -50,16 +50,16 @@

# bit_shift = paddle.to_tensor([4,2,0],dtype='int8')

KK = A.shape[-1]
NN = B.shape[-1]
EEM = sorted_token_ids.shape[0]
sstride_am, sstride_ak = A.shape[1], 1
sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1
sstride_cm, sstride_cn = C.shape[-1], 1
sstride_bse, sstride_bsk, sstride_bsn = B_scale.shape[1], 1, 1
nnum_valid_tokens = topk_ids.numel().tolist()

prepare_attr_for_triton_kernel = """

auto N = B.shape()[2];
auto K = A.shape()[1];
auto EM = sorted_token_ids.shape()[0];
@@ -82,26 +82,25 @@
auto bzp = bbzp;
"""

config = {

"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4,
}
configs = []

configs.append(dict(config))

op_name = f"fused_moe_paddle_wintx_{ppack_num}_{ww_mask}_{ss_mask}_{bbzp}"
op_name += f"{get_dtype_str(A.dtype)}"
op_name += f"{B.shape[0]}"
op_name += f"{B.shape[1]}"
op_name += f"{B.shape[2]}"


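# The first call for a given op_name (a shape/dtype combination) compiles the
# Triton kernel and registers it as a Paddle custom op; subsequent calls skip
# straight to _C_ops._run_custom_op below.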
if op_name not in OpProtoHolder.instance().op_proto_map.keys():
prepare_ptr_for_triton_kernel = """

CUdeviceptr input_ptrs[9] = {
get_tensor_ptr(A),
get_tensor_ptr(B),
@@ -114,14 +113,14 @@
get_tensor_ptr(bit_shift),
};
"""
template_used = rendering_common_template(

invoke_fused_moe_kernel,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
)
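# One 1-D launch grid of ceil(EM / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N)
# programs; the kernel decomposes each program id back into an (m, n) tile.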
grid = ("(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",)

fused_moe_decode_kernel_paddle[(op_name, template_used, grid, configs)](

A,
B,
C,
@@ -147,14 +146,14 @@
sstride_bsn,
MUL_ROUTED_WEIGHT=int(mul_routed_weight),
top_k=top_k,
- BLOCK_SIZE_K=group_size,
+ BLOCK_SIZE_K=group_size,  # must equal group_size for this kernel
pack_num=ppack_num,
w_mask=ww_mask,
s_mask=ss_mask,
bzp=bbzp,
)
if in_dynamic_or_pir_mode():
outs = _C_ops._run_custom_op(

op_name,
A,
B,
@@ -174,13 +173,13 @@
ss_mask,
bbzp,
)
return outs[0]


@paddle_use_triton(

key=["1"],
)
def fused_moe_decode_kernel_paddle(

# Pointers to matrices
a_ptr,
b_ptr,
@@ -252,80 +251,80 @@
multiplication across different blocks processed by the same expert.
"""

real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1

pid = tl.program_id(axis=0)
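# Standard Triton grouped ordering: program ids are remapped in groups of
# GROUP_SIZE_M along M so consecutive programs improve L2 reuse of A/B tiles.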
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# may be more efficient to use bf16 as the compute type here
compute_type = c_ptr.dtype.element_ty

num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)

token_mask = offs_token < num_valid_tokens

offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] // pack_num * stride_bk + offs_bn[None, :] * stride_bn)

# may be more efficient to eliminate this load
b_shift_bits = tl.load(bit_shift_ptr + offs_k[:, None] % pack_num)

bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

super_bs = tl.load(bs_ptrs) # super scale
scale_idx = tl.arange(0, BLOCK_SIZE_K)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

a = tl.load(

a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs)

# maybe more efficient
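# The row at k == BLOCK_SIZE_K - 1 of each packed block carries the group
# scale: the tl.where/tl.sum below extracts that row, s_mask isolates the
# scale bits, and super_bs applies the per-(expert, n) super scale.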
bs = tl.sum(tl.where(scale_idx[:, None] == BLOCK_SIZE_K - 1, b, 0), 0)
bs = (bs & s_mask)[None, :] * super_bs
b = (((b >> b_shift_bits) & w_mask) - bzp) * bs
accumulator += tl.dot(a, b.to(a.dtype))

b_ptrs += real_k_size * stride_bk
a_ptrs += BLOCK_SIZE_K * stride_ak

if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]

accumulator = accumulator.to(compute_type)

# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)


def fused_moe_wintx_decode_impl(

hidden_states,
w1,
w2,
@@ -338,52 +337,52 @@
bit="wint2.75",
):

assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert group_size > 0, "Group size must be greater than 0"

num_tokens, K = hidden_states.shape
E, _, N = w1.shape
M = num_tokens

if group_size < 0:
group_size = K // w1_scale.shape[1]

top_k = topk_ids.shape[1]

intermediate_cache1 = paddle.zeros(

[M, top_k, N],
dtype=hidden_states.dtype,
)
intermediate_cache2 = paddle.zeros(

(M * top_k, N // 2),
dtype=hidden_states.dtype,
)
intermediate_cache3 = paddle.zeros(

(M, top_k, K),
dtype=hidden_states.dtype,
)

from paddlenlp_ops import preprocess_for_moe

sorted_token_ids, expert_ids, num_tokens_post_padded = preprocess_for_moe(topk_ids, E, BLOCK_SIZE_M)

if bit == "wint2.75":
bit_shift = paddle.to_tensor([4, 2, 0], dtype="int8")
ppack_num = 3
ww_mask = 0xF
ss_mask = 0xF
bbzp = 8
elif bit == "wint2.5":
ppack_num = 7
ww_mask = 0x7
ss_mask = 0x1FFF
bbzp = 4
bit_shift = paddle.to_tensor([13, 11, 9, 6, 4, 2, 0], dtype="int16")
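# In both schemes each packed element stores pack_num weights plus scale
# bits: bit_shift selects each sub-field, w_mask extracts a weight, bzp is
# the zero point, and s_mask isolates the group-scale bits.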

invoke_fused_moe_kernel(

A=hidden_states,
B=w1,
C=intermediate_cache1,
@@ -403,9 +402,9 @@
bbzp=bbzp,
)

intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))

invoke_fused_moe_kernel(

A=intermediate_cache2,
B=w2,
C=intermediate_cache3,
@@ -425,15 +424,15 @@
bbzp=bbzp,
)

out_hidden_states = paddle.sum(intermediate_cache3, axis=1)

del intermediate_cache1, intermediate_cache2, intermediate_cache3
del sorted_token_ids, expert_ids, num_tokens_post_padded

return out_hidden_states


def fused_moe_wintx_decode_wint2_75(

hidden_states,
w1,
w2,
@@ -443,9 +442,9 @@
w2_scale=None,
):

topk_weights, topk_ids = paddle.topk(scores, k=topk, axis=-1, sorted=False)

return fused_moe_wintx_decode_impl(

hidden_states,
w1,
w2,
@@ -458,7 +457,7 @@
)


def fused_moe_wintx_decode_wint2_5(

hidden_states,
w1,
w2,
@@ -468,9 +467,9 @@
w2_scale=None,
):

topk_weights, topk_ids = paddle.topk(scores, k=topk, axis=-1, sorted=False)

return fused_moe_wintx_decode_impl(

hidden_states,
w1,
w2,
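
For reference, a minimal NumPy sketch of the per-K-block dequantization the kernel above performs for wint2.75 (pack_num=3, w_mask=0xF, s_mask=0xF, bzp=8). Shapes and values here are illustrative assumptions, not taken from this PR:

import numpy as np

# Hypothetical block sizes; the kernel requires BLOCK_SIZE_K == group_size.
BLOCK_SIZE_K, BLOCK_SIZE_N = 64, 8
rng = np.random.default_rng(0)

b = rng.integers(0, 256, size=(BLOCK_SIZE_K, BLOCK_SIZE_N)).astype(np.int32)  # packed weights
super_bs = rng.random((1, BLOCK_SIZE_N), dtype=np.float32)  # per-(expert, n) super scale

shift = np.array([4, 2, 0])                                 # bit_shift table for wint2.75
b_shift_bits = shift[np.arange(BLOCK_SIZE_K) % 3][:, None]  # offs_k % pack_num

bs = (b[-1:, :] & 0xF) * super_bs                           # last row of the block holds the group scale
w = (((b >> b_shift_bits) & 0xF) - 8) * bs                  # shift, mask, subtract zero point, rescale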