Fix mixed precision for BART #1121
Conversation
@@ -193,7 +193,7 @@ def __init__(
        # Use token embedding weights to project from the token representation
        # to vocabulary logits.
        outputs = tf.matmul(
            x,
I think we actually want this the other way. We should cast the variable to the compute dtype (the lower precision type when using mixed precision) before multiplying with x, i.e. tf.cast(backbone.token_embedding.embeddings, x.dtype).
Can you check if that fixes things as well?
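For reference, a minimal sketch of that suggestion in place. The tiny BartBackbone config, the dummy x tensor, and the transpose_b=True argument are assumptions made only to keep the snippet runnable; they are not taken from the PR.

```python
import tensorflow as tf
from tensorflow import keras
import keras_nlp

keras.mixed_precision.set_global_policy("mixed_float16")

# Tiny BART backbone just to make the sketch runnable; sizes are arbitrary.
backbone = keras_nlp.models.BartBackbone(
    vocabulary_size=100,
    num_layers=1,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    max_sequence_length=16,
)

# Stand-in for the decoder hidden state `x` from the diff above: under the
# mixed policy it arrives in float16, with shape (batch, sequence, hidden_dim).
x = tf.zeros((1, 16, 32), dtype=tf.float16)

# The suggested fix: cast the float32 embedding variable down to the compute
# dtype before the matmul, rather than casting `x` up to float32.
outputs = tf.matmul(
    x,
    tf.cast(backbone.token_embedding.embeddings, x.dtype),
    transpose_b=True,  # the embedding matrix has shape (vocab_size, hidden_dim)
)
print(outputs.dtype)  # float16
```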
Hmmm, why does GPT-2 output tf.float32 instead of tf.float16, then?
In what context does it? There might be some casting going on to the outputs.
My understanding is that we should be casting our variables to the compute_dtype. Take a look at
https://github.com/keras-team/keras/blob/v2.12.0/keras/engine/base_layer.py#L2216-L2235
and https://github.com/keras-team/keras/blob/v2.12.0/keras/mixed_precision/autocast_variable.py
But maybe I am missing something!
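As an aside, the split those links describe (variables stored in float32, computation in the compute dtype) can be seen with any stock layer under a mixed policy; this is just an illustration, not code from the PR:

```python
import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

layer = keras.layers.Dense(4)
y = layer(tf.ones((2, 8)))

print(layer.compute_dtype)  # float16 -- computation runs in half precision
print(layer.dtype)          # float32 -- variables are stored in full precision
print(layer.kernel.dtype)   # float32 -- an AutoCastVariable under the hood
print(y.dtype)              # float16 -- outputs follow the compute dtype
```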
The GPT-2 backbone outputs tf.float32, but the BART backbone outputs tf.float16. Not sure why that's happening.
https://colab.research.google.com/drive/18AbKIwbUAtJySgAYWXa0ggvRbp5TriEa?usp=sharing
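A hedged sketch of that check, using small arbitrary configs rather than whatever presets the Colab loads; the dtypes in the comments simply restate the report above and may differ across versions:

```python
import numpy as np
import keras_nlp
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

# Small arbitrary configs so this runs quickly; not the Colab's exact setup.
gpt2 = keras_nlp.models.GPT2Backbone(
    vocabulary_size=100, num_layers=1, num_heads=2,
    hidden_dim=32, intermediate_dim=64, max_sequence_length=16,
)
bart = keras_nlp.models.BartBackbone(
    vocabulary_size=100, num_layers=1, num_heads=2,
    hidden_dim=32, intermediate_dim=64, max_sequence_length=16,
)

token_ids = np.ones((1, 16), dtype="int32")
padding_mask = np.ones((1, 16), dtype="int32")

gpt2_out = gpt2({"token_ids": token_ids, "padding_mask": padding_mask})
bart_out = bart({
    "encoder_token_ids": token_ids,
    "encoder_padding_mask": padding_mask,
    "decoder_token_ids": token_ids,
    "decoder_padding_mask": padding_mask,
})

print(gpt2_out.dtype)                             # float32 per the comment above
print(bart_out["decoder_sequence_output"].dtype)  # float16 per the comment above
```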
This actually may no longer be necessary after the port. We had to wrap things in a Layer to avoid some errors with keras-core, which might mean that variables are now autocast automatically (I think they should be).
Can you try again on the latest from master?
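If it helps, the autocasting behavior being described can be checked in isolation with a toy layer; this is only a sketch of the mechanism, not the actual keras-nlp code:

```python
import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

class ReverseEmbedding(keras.layers.Layer):
    """Toy stand-in for the logit-projection step."""

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            name="embeddings", shape=(10, input_shape[-1])
        )

    def call(self, x):
        # Inside a Layer's call(), the float32 variable is read back in the
        # compute dtype (float16) via AutoCastVariable, so no explicit
        # tf.cast is needed here.
        return tf.matmul(x, self.embeddings, transpose_b=True)

layer = ReverseEmbedding()
y = layer(tf.ones((2, 8)))
print(layer.embeddings.dtype)  # float32 -- stored dtype of the variable
print(y.dtype)                 # float16 -- computed in the compute dtype
```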
We think this is no longer relevant with the […]
With […], this error shows up: