MaskedLMHead should support dtype=bfloat16 #1195

Closed
@g-dspencer

Description

Describe the bug

I claim that MaskedLMHead should support a dtype argument of tf.bfloat16 (and tf.float16) so that users can evaluate the effect of reducing their memory usage. This matters more as the vocabulary gets larger.
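For a rough sense of scale (back-of-the-envelope with hypothetical sizes, not numbers measured from keras-nlp): the head's output projection has on the order of hidden_dim * vocabulary_size parameters, so halving the element size halves that footprint.

# Illustrative only: hypothetical sizes, not from this report.
hidden_dim = 768
vocabulary_size = 30_000

params = hidden_dim * vocabulary_size + vocabulary_size  # projection weights + bias
print(f"float32 : {params * 4 / 1e6:.1f} MB")  # 4 bytes per element -> ~92.3 MB
print(f"bfloat16: {params * 2 / 1e6:.1f} MB")  # 2 bytes per element -> ~46.1 MB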

To Reproduce

In the Google corp Colab, "File -> Save a copy as GitHub Gist" fails after entering an OTP with a "github auth fails" message, so I'll just include the code inline:

!pip install keras-nlp --upgrade --quiet

import tensorflow as tf
import keras_nlp

# Based on test_valid_call()
# https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/modeling/masked_lm_head_test.py#L25
def test_dtype(dtype):
  head = keras_nlp.layers.MaskedLMHead(
      vocabulary_size=100,
      activation="softmax",
      dtype=dtype, # this is the point
  )
  encoded_tokens = tf.keras.Input(shape=(10, 16))
  positions = tf.keras.Input(shape=(5,), dtype="int32")
  outputs = head(encoded_tokens, mask_positions=positions)
  model = tf.keras.Model((encoded_tokens, positions), outputs)

  token_data = tf.random.uniform(shape=(4, 10, 16))
  position_data = tf.random.uniform(minval=0, maxval=10, shape=(4, 5), dtype=tf.int32)
  # For bfloat16 (and float16) this call is where it crashes, with:
  # TypeError: Input 'y' of 'AddV2' Op has type bfloat16 that does not match type float32 of argument 'x'.
  model((token_data, position_data))

  for w in head.weights:
      assert w.dtype == dtype, ("Wrong type: " + w.name)

print("float32")
test_dtype(tf.float32) # this works

print("bfloat16")
test_dtype(tf.bfloat16) # this fails

print("float64")
test_dtype(tf.float64)

Expected behavior

No crash: the test should run to completion.
The loop checking dtypes (assert w.dtype == dtype, ("Wrong type: " + w.name)) should arguably pass as well, unless we are hitting some subtle case where mixed dtypes are actually wanted.
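For reference, if mixed precision (float32 variables with bfloat16 compute) is what's actually intended here, I'd expect that to be requested through a dtype policy rather than a plain dtype= argument; a minimal sketch, assuming the standard Keras mixed-precision API:

import tensorflow as tf

# Sketch: request mixed precision globally instead of a per-layer dtype.
# Under "mixed_bfloat16", variables stay float32 while compute runs in bfloat16,
# so the weight-dtype assertion above would legitimately not hold in that mode.
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

policy = tf.keras.mixed_precision.global_policy()
print(policy.variable_dtype, policy.compute_dtype)  # float32 bfloat16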

Additional context

The error I get is:

TypeError: Exception encountered when calling layer "masked_lm_head_1" (type MaskedLMHead).

in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/layers/modeling/masked_lm_head.py", line 196, in call  *
        outputs = outputs + self._bias

    TypeError: Input 'y' of 'AddV2' Op has type bfloat16 that does not match type float32 of argument 'x'.


Call arguments received by layer "masked_lm_head_1" (type MaskedLMHead):
  • inputs=tf.Tensor(shape=(None, 10, 16), dtype=bfloat16)
  • mask_positions=tf.Tensor(shape=(None, 5), dtype=int32)

I suspect we just need to thread a few dtype= parameters through the layer code, roughly along the lines of the sketch below.
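To illustrate the kind of change I mean (a hypothetical toy layer, not the actual masked_lm_head.py): the bias has to end up in the same dtype as the projected outputs, either by creating it with the layer's dtype or by casting at the add site.

import tensorflow as tf

# Hypothetical toy layer, just to show where a dtype= could be threaded through;
# this is NOT the real masked_lm_head.py.
class TinyHead(tf.keras.layers.Layer):
    def __init__(self, vocabulary_size, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size

    def build(self, input_shape):
        self._bias = self.add_weight(
            name="bias",
            shape=(self.vocabulary_size,),
            initializer="zeros",
            dtype=self.dtype,  # follow the dtype the layer was constructed with
        )

    def call(self, inputs):
        # Cast at the add site so AddV2 never sees mixed dtypes,
        # whatever dtype the bias variable ended up in.
        return inputs + tf.cast(self._bias, inputs.dtype)

head = TinyHead(vocabulary_size=100, dtype=tf.bfloat16)
out = head(tf.cast(tf.random.uniform((4, 5, 100)), tf.bfloat16))
print(out.dtype, head._bias.dtype)  # both bfloat16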

Would you like to help us fix it?
yes
