1 file changed: +9 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,8 @@ def __init__(
2525 scale_grad_by_freq : bool = False ,
2626 sparse : bool = False ,
2727 _weight : Optional [Tensor ] = None ,
28+ device = None ,
29+ dtype = None ,
2830 ) -> None :
2931 super ().__init__ (
3032 num_embeddings ,
@@ -35,8 +37,10 @@ def __init__(
3537 scale_grad_by_freq ,
3638 sparse ,
3739 _weight ,
40+ device ,
41+ dtype ,
3842 )
39- self .norm = torch .nn .LayerNorm (embedding_dim )
43+ self .norm = torch .nn .LayerNorm (embedding_dim , device = device )
4044 GlobalOptimManager .get_instance ().register_module_override (
4145 self , "weight" , {"optim_bits" : 32 }
4246 )
@@ -68,7 +72,10 @@ def forward(self, input: Tensor) -> Tensor:
6872 self .sparse ,
6973 )
7074
71- return self .norm (emb )
75+ # always apply layer norm in full precision
76+ emb = emb .to (torch .get_default_dtype ())
77+
78+ return self .norm (emb ).to (self .weight .dtype )
7279
7380
7481class Embedding (torch .nn .Embedding ):
You can’t perform that action at this time.
0 commit comments