
Commit ab649f5

Do the reverse embedding in the same dtype as the input embedding (keras-team#1548)
1 parent c157ac2 · commit ab649f5
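For context, the "reverse" path of `ReversibleEmbedding` maps hidden states back to vocabulary logits by multiplying with the transposed embedding matrix. A minimal sketch of the idea in plain NumPy (all names and sizes here are illustrative, not from this commit):

import numpy as np

vocab_size, hidden_dim = 8, 4
embeddings = np.random.rand(vocab_size, hidden_dim).astype("float16")

# Forward: token ids -> embedding vectors.
token_ids = np.array([1, 5])
embedded = embeddings[token_ids]        # shape (2, hidden_dim)

# Reverse: hidden states -> vocabulary logits via the transposed table.
# After this commit, this matmul runs in the input's dtype (float16 here)
# unless an explicit reverse_dtype is passed.
hidden = np.random.rand(2, hidden_dim).astype("float16")
logits = hidden @ embeddings.T          # shape (2, vocab_size)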

File tree: 4 files changed (+6, −10 lines)


keras_nlp/layers/modeling/reversible_embedding.py

Lines changed: 5 additions & 5 deletions
@@ -48,8 +48,7 @@ class ReversibleEmbedding(keras.layers.Embedding):
         mask_zero: Boolean, whether or not the input value 0 is a special
             "padding" value that should be masked out.
         reverse_dtype: The dtype for the reverse projection computation.
-            For stability, it is usually best to use full precision even when
-            working with half or mixed precision training.
+            Defaults to the `compute_dtype` of the layer.
         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
             including `name`, `trainable`, `dtype` etc.
 
@@ -90,7 +89,7 @@ def __init__(
         embeddings_regularizer=None,
         embeddings_constraint=None,
         mask_zero=False,
-        reverse_dtype="float32",
+        reverse_dtype=None,
         **kwargs,
     ):
         super().__init__(
@@ -122,8 +121,9 @@ def call(self, inputs, reverse=False):
                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
             else:
                 kernel = self.reverse_embeddings
-            inputs = ops.cast(inputs, self.reverse_dtype)
-            kernel = ops.cast(kernel, self.reverse_dtype)
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
             return ops.matmul(inputs, kernel)
 
         return super().call(inputs)
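What this change means in practice: with `reverse_dtype` left at its new default of `None`, the reverse projection is no longer force-cast to float32. A minimal sketch under mixed precision, assuming keras-nlp at this commit and the Keras 3 API (layer sizes are illustrative):

import keras
from keras_nlp.layers import ReversibleEmbedding

keras.mixed_precision.set_global_policy("mixed_float16")

embedding = ReversibleEmbedding(input_dim=1000, output_dim=64)
hidden = keras.random.normal((2, 10, 64), dtype="float16")

# reverse_dtype now defaults to None, so no cast happens and the reverse
# projection runs in the layer's compute dtype (float16 here).
logits = embedding(hidden, reverse=True)

# Passing reverse_dtype explicitly restores the old full-precision behavior.
stable = ReversibleEmbedding(
    input_dim=1000, output_dim=64, reverse_dtype="float32"
)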

keras_nlp/models/llama/llama_backbone.py

Lines changed: 0 additions & 1 deletion
@@ -109,7 +109,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []

keras_nlp/models/mistral/mistral_backbone.py

Lines changed: 0 additions & 1 deletion
@@ -121,7 +121,6 @@ def __init__(
             tie_weights=False,
             embeddings_initializer=_mistral_kernel_initializer(stddev=0.01),
             dtype=dtype,
-            reverse_dtype=dtype,
             name="token_embedding",
         )
         self.transformer_layers = []
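Because `None` now means "use the layer's compute_dtype", explicitly forwarding `reverse_dtype=dtype` in the Llama and Mistral backbones became redundant, which is why both diffs above simply drop the argument. A minimal sketch of the equivalence (vocabulary and hidden sizes are illustrative):

from keras_nlp.layers import ReversibleEmbedding

# Before this commit the backbones pinned the reverse projection dtype:
old = ReversibleEmbedding(
    input_dim=32000, output_dim=4096, tie_weights=False,
    dtype="bfloat16", reverse_dtype="bfloat16",
)

# After: omitting reverse_dtype leaves it None, and the reverse matmul
# already runs in the layer's compute dtype (bfloat16 here).
new = ReversibleEmbedding(
    input_dim=32000, output_dim=4096, tie_weights=False,
    dtype="bfloat16",
)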

keras_nlp/samplers/sampler.py

Lines changed: 1 addition & 3 deletions
@@ -145,10 +145,8 @@ def compute_probabilities(self, logits):
         This will always be done in full precision, regardless of dtype, and
         scale by `temperature`.
         """
-        logits_dtype = logits.dtype
         logits = ops.cast(logits, "float32")
-        probs = keras.activations.softmax(logits / self.temperature)
-        return ops.cast(probs, logits_dtype)
+        return keras.activations.softmax(logits / self.temperature)
 
     def run_loop(
         self, cond, body, model=None, loop_vars=None, maximum_iterations=None
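A standalone sketch of the updated `compute_probabilities` logic, using `keras.ops` in place of keras-nlp's backend ops (an assumption for self-containment; the free function and temperature default are illustrative):

import keras
from keras import ops

def compute_probabilities(logits, temperature=1.0):
    # Softmax still runs in full precision for stability, but after this
    # commit the result is returned as float32 instead of being cast back
    # to the original logits dtype.
    logits = ops.cast(logits, "float32")
    return keras.activations.softmax(logits / temperature)

probs = compute_probabilities(ops.ones((1, 8), dtype="float16"))
print(probs.dtype)  # float32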
