Skip to content

Commit ccc62cb

Browse files
authored
fix: fix rms_norm params (#18)
1 parent b03e086 commit ccc62cb

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

lmdeploy/pytorch/kernels/ascend/rms_norm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,13 @@
33
from torch import Tensor
44

55

6-
def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6):
7-
return ext_ops.rms_norm(hidden_states, weight, epsilon)
6+
def rms_norm(hidden_states: Tensor,
7+
weight: Tensor,
8+
eps: float = 1e-6,
9+
out: Tensor = None):
10+
rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps)
11+
if out is None:
12+
out = rms_norm_out
13+
else:
14+
out.copy_(rms_norm_out)
15+
return rms_norm_out

0 commit comments

Comments
 (0)