We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b03e086 commit ccc62cbCopy full SHA for ccc62cb
lmdeploy/pytorch/kernels/ascend/rms_norm.py
@@ -3,5 +3,13 @@
3
from torch import Tensor
4
5
6
-def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6):
7
- return ext_ops.rms_norm(hidden_states, weight, epsilon)
+def rms_norm(hidden_states: Tensor,
+ weight: Tensor,
8
+ eps: float = 1e-6,
9
+ out: Tensor = None):
10
+ rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps)
11
+ if out is None:
12
+ out = rms_norm_out
13
+ else:
14
+ out.copy_(rms_norm_out)
15
+ return rms_norm_out
0 commit comments