Skip to content

Commit 9d353ca

Browse files
authored
Merge pull request bitsandbytes-foundation#87 from lostmsu/main
Add `device` and `dtype` parameters to `StableEmbedding`
2 parents 7a6563b + 62d39a2 commit 9d353ca

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__(
2525
scale_grad_by_freq: bool = False,
2626
sparse: bool = False,
2727
_weight: Optional[Tensor] = None,
28+
device=None,
29+
dtype=None,
2830
) -> None:
2931
super().__init__(
3032
num_embeddings,
@@ -35,8 +37,10 @@ def __init__(
3537
scale_grad_by_freq,
3638
sparse,
3739
_weight,
40+
device,
41+
dtype,
3842
)
39-
self.norm = torch.nn.LayerNorm(embedding_dim)
43+
self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
4044
GlobalOptimManager.get_instance().register_module_override(
4145
self, "weight", {"optim_bits": 32}
4246
)
@@ -68,7 +72,10 @@ def forward(self, input: Tensor) -> Tensor:
6872
self.sparse,
6973
)
7074

71-
return self.norm(emb)
75+
# always apply layer norm in full precision
76+
emb = emb.to(torch.get_default_dtype())
77+
78+
return self.norm(emb).to(self.weight.dtype)
7279

7380

7481
class Embedding(torch.nn.Embedding):

0 commit comments

Comments (0)