Commit 7c1c6a7

fix rotary_embedding
1 parent 9acae55 commit 7c1c6a7

File tree

1 file changed (+6, -37 lines)


lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py

Lines changed: 6 additions & 37 deletions
@@ -5,6 +5,7 @@
 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
                                 RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
+from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding, Llama3RotaryEmbeddingImpl


 class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
@@ -23,11 +24,6 @@ def __init__(self,
                                  self.dim)).float().cuda()
         self.register_buffer('inv_freq', inv_freq, persistent=False)

-    def dump_tensor(self, name, t):
-        import pickle
-        with open(f'/tzy/dev_ops/{name}.pkl', 'wb') as f:
-            pickle.dump(t.cpu(), f)
-
     def forward(self, x, position_ids):
         """forward."""
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -47,7 +43,6 @@ def forward(self, x, position_ids):
         device_type = x.device.type
         device_type = device_type if isinstance(
             device_type, str) and device_type != 'mps' else 'cpu'
-        # with torch.autocast(device_type=device_type, enabled=False):
         inv_freq_expanded = inv_freq_expanded
         position_ids_expanded = position_ids_expanded
         tmp = torch.bmm(inv_freq_expanded, position_ids_expanded)
@@ -78,37 +73,11 @@ def build(
         elif emb_type == RopeType.DynamicNTKScaling:
             return LlamaDynamicNTKScalingRotaryEmbedding(
                 dim, base, scaling_factor, max_position_embeddings)
+        elif emb_type == RopeType.Llama3:
+            return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
+                                             llama3_params.low_freq_factor,
+                                             llama3_params.high_freq_factor,
+                                             max_position_embeddings)
         else:
             raise NotImplementedError(
                 f'Unsupported embedding type: {emb_type}')
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling.
-
-    Credits to the Reddit users /u/bloc97 and /u/emozilla
-    """
-
-    def __init__(self,
-                 dim: int,
-                 base: int = 10000,
-                 scaling_factor: float = 1.0,
-                 max_position_embeddings: int = 2048):
-        super().__init__(dim, base, scaling_factor)
-        self.max_position_embeddings = max_position_embeddings
-
-    def forward(self, x, position_ids):
-        """forward."""
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_position_embeddings:
-            base = self.base * ((self.scaling_factor * seq_len /
-                                 self.max_position_embeddings) -
-                                (self.scaling_factor - 1))**(self.dim /
-                                                             (self.dim - 2))
-            inv_freq = 1.0 / (base**(torch.arange(
-                0, self.dim, 2, dtype=torch.int64).float().to(x.device) /
-                self.dim))
-            self.register_buffer('inv_freq', inv_freq, persistent=False)
-
-        cos, sin = super().forward(x, position_ids)
-        return cos, sin
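
For reference, a minimal standalone sketch of the dynamic NTK base rescaling that the deleted LlamaDynamicNTKScalingRotaryEmbedding performed; the commit drops this local copy in favour of the identically named class imported from ..default.rotary_embedding, so the behaviour stays the same. The numeric values below (dim, base, scaling_factor, seq_len) are illustrative only, not taken from lmdeploy.

import torch

# Illustrative values; not lmdeploy defaults or configs.
dim = 128
base = 10000.0
scaling_factor = 2.0
max_position_embeddings = 2048
seq_len = 4096  # longer than the trained context, so the base gets rescaled

if seq_len > max_position_embeddings:
    # Dynamic NTK rescaling of the rotary base, as in the deleted forward().
    base = base * ((scaling_factor * seq_len / max_position_embeddings) -
                   (scaling_factor - 1))**(dim / (dim - 2))

# Inverse frequencies recomputed from the (possibly rescaled) base,
# mirroring the inv_freq expression in the deleted code.
inv_freq = 1.0 / (base**(
    torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)  # torch.Size([64])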
