Fix mixed precision for BART #1121

Closed
abheesht17 wants to merge 4 commits

Conversation

abheesht17 (Collaborator) commented on Jul 9, 2023:

With the following policy set globally,

policy = keras.mixed_precision.Policy("mixed_float16")
keras.mixed_precision.set_global_policy(policy)

this error shows up:

TypeError: Exception encountered when calling layer "tf.linalg.matmul" (type TFOpLambda).

Input 'y' of 'BatchMatMulV2' Op has type float32 that does not match type float16 of argument 'x'.

Call arguments received by layer "tf.linalg.matmul" (type TFOpLambda):
  • a=tf.Tensor(shape=(None, None, 768), dtype=float16)
  • b=<tf.Variable 'token_embedding/embeddings:0' shape=(50265, 768) dtype=float32>
  • transpose_a=False
  • transpose_b=True
  • adjoint_a=False
  • adjoint_b=False
  • a_is_sparse=False
  • b_is_sparse=False
  • output_type=None
  • name=None
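
A minimal standalone sketch (not the actual BART code) that reproduces the same dtype mismatch, with shapes roughly matching the trace above (batch and sequence sizes are arbitrary):

import tensorflow as tf

embeddings = tf.Variable(tf.random.normal([50265, 768]))  # float32 variable, like token_embedding/embeddings
x = tf.random.normal([1, 16, 768], dtype=tf.float16)      # float16 hidden states under mixed_float16

# Raises: Input 'y' of 'BatchMatMulV2' Op has type float32 that does not
# match type float16 of argument 'x'.
outputs = tf.matmul(x, embeddings, transpose_b=True)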

abheesht17 requested a review from mattdangerw on Jul 9, 2023, 04:39.
@@ -193,7 +193,7 @@ def __init__(
# Use token embedding weights to project from the token representation
# to vocabulary logits.
outputs = tf.matmul(
x,
mattdangerw (Member) commented on the change above:

I think we actually want this the other way: we should cast the variable to the compute dtype (the lower-precision type when using mixed precision) before multiplying with x, i.e. tf.cast(backbone.token_embedding.embeddings, x.dtype).

Can you check if that fixes things as well?
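
A self-contained sketch of that direction (shapes are illustrative, not taken from the BART config):

import tensorflow as tf

embeddings = tf.Variable(tf.random.normal([50265, 768]))  # float32 embedding variable
x = tf.random.normal([1, 16, 768], dtype=tf.float16)      # float16 activations under mixed precision

# Cast the variable down to x's compute dtype before the tied projection.
logits = tf.matmul(x, tf.cast(embeddings, x.dtype), transpose_b=True)
assert logits.dtype == tf.float16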

abheesht17 (Collaborator, author) replied:

Hmmm, why does GPT-2 output tf.float32 instead of tf.float16, then?

mattdangerw (Member) replied:

In what context does it? There might be some casting of the outputs going on.

My understanding is that we should be casting our variables to the compute_dtype. Take a look at
https://github.com/keras-team/keras/blob/v2.12.0/keras/engine/base_layer.py#L2216-L2235
and https://github.com/keras-team/keras/blob/v2.12.0/keras/mixed_precision/autocast_variable.py

But maybe I am missing something!
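
A quick way to see that behavior under mixed precision, using a generic Keras layer (nothing BART-specific):

import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

dense = keras.layers.Dense(8)
y = dense(tf.random.normal([2, 8]))

print(dense.kernel.dtype)   # float32: the variable (storage) dtype
print(dense.compute_dtype)  # float16: what the variable is cast to inside call()
print(y.dtype)              # float16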

abheesht17 (Collaborator, author) replied:

The GPT-2 backbone outputs tf.float32, while the BART backbone outputs tf.float16. Not sure why that's happening.

https://colab.research.google.com/drive/18AbKIwbUAtJySgAYWXa0ggvRbp5TriEa?usp=sharing
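
A hedged sketch of that kind of check (the preset names below are assumptions for illustration, not taken from the linked notebook):

import tensorflow as tf
from tensorflow import keras
import keras_nlp

keras.mixed_precision.set_global_policy("mixed_float16")

gpt2 = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
bart = keras_nlp.models.BartBackbone.from_preset("bart_base_en")

# Inspect the symbolic output dtypes of each functional backbone.
print(tf.nest.map_structure(lambda t: t.dtype, gpt2.output))
print(tf.nest.map_structure(lambda t: t.dtype, bart.output))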

mattdangerw (Member) replied:

This actually may no longer be necessary with the port. We had to wrap things in a Layer to avoid some errors with keras-core, which might mean that variables are now autocast automatically (I think they should be).

Can you try again on the latest from master?
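
A hypothetical stand-in (not the actual keras-nlp layer) showing why wrapping the tied projection in a Layer avoids the original error: the float32 embedding variable is read as the compute dtype inside call(), so the matmul dtypes match.

import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

class ReverseEmbeddingSketch(keras.layers.Layer):
    def __init__(self, vocabulary_size, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = self.add_weight(
            name="embeddings",
            shape=(vocabulary_size, hidden_dim),
            initializer="random_normal",
        )

    def call(self, x):
        # self.embeddings autocasts to float16 here, matching x.
        return tf.matmul(x, self.embeddings, transpose_b=True)

layer = ReverseEmbeddingSketch(vocabulary_size=50265, hidden_dim=768)
logits = layer(tf.random.normal([1, 16, 768]))
print(layer.embeddings.dtype, logits.dtype)  # float32 variable, float16 logits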

mattdangerw (Member) commented:

We think this is no longer relevant with the ReversibleEmbedding layer. Closing; we can reopen if the issue resurfaces.
