 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
                                 RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
+from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding, Llama3RotaryEmbeddingImpl


 class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
@@ -23,11 +24,6 @@ def __init__(self,
             self.dim)).float().cuda()
         self.register_buffer('inv_freq', inv_freq, persistent=False)

-    def dump_tensor(self, name, t):
-        import pickle
-        with open(f'/tzy/dev_ops/{name}.pkl', 'wb') as f:
-            pickle.dump(t.cpu(), f)
-
     def forward(self, x, position_ids):
         """forward."""
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -47,7 +43,6 @@ def forward(self, x, position_ids):
         device_type = x.device.type
         device_type = device_type if isinstance(
             device_type, str) and device_type != 'mps' else 'cpu'
-        # with torch.autocast(device_type=device_type, enabled=False):
         inv_freq_expanded = inv_freq_expanded
         position_ids_expanded = position_ids_expanded
         tmp = torch.bmm(inv_freq_expanded, position_ids_expanded)
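For context on the lines this hunk touches: forward builds the rotary angles as a batched outer product of the inverse frequencies and the position ids, then takes their cos/sin. A minimal standalone sketch of that step follows; the example sizes and any tensor names not visible above are assumptions, not code from this file.

import torch

# Sketch of the angle computation around the bmm above (dim=64, seq_len=8 assumed).
dim = 64
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
position_ids = torch.arange(8).float().unsqueeze(0)                # [bs, seq_len]

inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :]
freqs = torch.bmm(inv_freq_expanded, position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)  # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()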
@@ -78,37 +73,11 @@ def build(
         elif emb_type == RopeType.DynamicNTKScaling:
             return LlamaDynamicNTKScalingRotaryEmbedding(
                 dim, base, scaling_factor, max_position_embeddings)
+        elif emb_type == RopeType.Llama3:
+            return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
+                                             llama3_params.low_freq_factor,
+                                             llama3_params.high_freq_factor,
+                                             max_position_embeddings)
         else:
             raise NotImplementedError(
                 f'Unsupported embedding type: {emb_type}')
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling.
-
-    Credits to the Reddit users /u/bloc97 and /u/emozilla
-    """
-
-    def __init__(self,
-                 dim: int,
-                 base: int = 10000,
-                 scaling_factor: float = 1.0,
-                 max_position_embeddings: int = 2048):
-        super().__init__(dim, base, scaling_factor)
-        self.max_position_embeddings = max_position_embeddings
-
-    def forward(self, x, position_ids):
-        """forward."""
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_position_embeddings:
-            base = self.base * ((self.scaling_factor * seq_len /
-                                 self.max_position_embeddings) -
-                                (self.scaling_factor - 1))**(self.dim /
-                                                             (self.dim - 2))
-            inv_freq = 1.0 / (base**(torch.arange(
-                0, self.dim, 2, dtype=torch.int64).float().to(x.device) /
-                self.dim))
-            self.register_buffer('inv_freq', inv_freq, persistent=False)
-
-        cos, sin = super().forward(x, position_ids)
-        return cos, sin
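The Llama3RotaryEmbeddingImpl now reused from the default backend is not shown in this diff. As a rough reference, here is a minimal sketch of the Llama 3 style frequency rescaling such an implementation is expected to apply to inv_freq; the function name, argument names, and the exact interpolation are assumptions based on the published Llama 3 RoPE-scaling recipe, not this repository's code.

import math

import torch


def llama3_scale_inv_freq(inv_freq: torch.Tensor, scaling_factor: float,
                          low_freq_factor: float, high_freq_factor: float,
                          original_max_position_embeddings: int) -> torch.Tensor:
    """Rescale rotary inverse frequencies in the Llama 3 style (sketch)."""
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Long-wavelength (low-frequency) components are divided by the factor,
    # short-wavelength ones are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / scaling_factor,
                         inv_freq)
    # The band in between is smoothly interpolated between the two regimes.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)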