Adding GPTNeoXBackbone #1056
Changes from 1 commit
The first changed file is the `GPTNeoXAttention` layer:

@@ -14,96 +14,138 @@
 import tensorflow as tf
 from tensorflow import keras

 from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
 from keras_nlp.models.gpt_neox.rotary_embedding import RotaryEmbedding
+from keras_nlp.utils.keras_utils import clone_initializer


 class GPTNeoXAttention(keras.layers.Layer):
     def __init__(
         self,
         num_heads,
         hidden_dim,
         rotary_pct=0.25,
-        max_position_embeddings=2048,
+        dropout=0.1,
+        max_position_embeddings=512,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
         **kwargs
     ):
         super().__init__()
         self.num_heads = num_heads
         self.hidden_dim = hidden_dim
-        self.head_dim = hidden_dim // num_heads
-        self.rotary_dim = self.head_dim * rotary_pct
-        # self.rotary_pct = 4
+        self.dropout = dropout
+        self.attn_head_size = hidden_dim // num_heads
+        self.rotary_dim = self.attn_head_size * self.hidden_dim
         self.max_position_embeddings = max_position_embeddings
-        self.rotary_embedding = RotaryEmbedding(self.rotary_pct)
-        self.qkv = keras.layers.Dense(3 * self.hidden_dim)
-        self.dense = keras.layers.Dense(self.hidden_dim)
+        self.rotary_embedding = RotaryEmbedding(self.attn_head_size)

-    def _compute_attention(
-        self, query, key, value, attention_mask=None, head_mask=None
-    ):
-        batch_size, _, query_len, _ = tf.shape(query)
-        key_len = tf.shape(key)[-2]
-        # causal_mask = self.bias[:, :, key_len - query_len : key_len, :key_len]
-        causal_mask = compute_causal_mask(batch_size, key_len, key_len)
+        self._kernel_initializer = keras.initializers.get(kernel_initializer)
+        self._bias_initializer = keras.initializers.get(bias_initializer)

-        query = tf.reshape(
-            query, [batch_size * self.num_heads, query_len, self.head_dim]
+        self._query_dense = keras.layers.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, self.num_heads, self.attn_head_size),
+            bias_axes="de",
+            **self._get_common_kwargs_for_sublayer(use_bias=True),
+            name="attention_output",
         )
-        key = tf.reshape(
-            key, [batch_size * self.num_heads, query_len, self.head_dim]
+        self._key_dense = keras.layers.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, self.num_heads, self.attn_head_size),
+            bias_axes="de",
+            **self._get_common_kwargs_for_sublayer(use_bias=True),
+            name="key",
         )
-        attention_scores = tf.zeros(
-            [batch_size * self.num_heads, query_len, self.head_dim],
-            dtype=query.dtype,
+        self._value_dense = keras.layers.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, self.num_heads, self.attn_head_size),
+            bias_axes="de",
+            **self._get_common_kwargs_for_sublayer(use_bias=True),
+            name="value",
         )

-        attention_scores = tf.linalg.matmul(
-            attention_scores,
-            query,
-            tf.transpose(key, perm=[0, 2, 1]),
-            beta=1.0,
-            alpha=(tf.constant(1.0)),
+        self._attn_dropout_layer = keras.layers.Dropout(
+            self.dropout, name="attention_dropout"
         )
-        attention_scores = tf.reshape(
-            attention_scores, [batch_size, self.num_heads, query_len, key_len]

+        self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

+        # Output.
+        self._output_dense = keras.layers.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, self.hidden_dim),
+            bias_axes="d",
+            **self._get_common_kwargs_for_sublayer(use_bias=True),
+            name="attention_output",
         )
-        mask_value = tf.constant(float("-inf"), dtype=attention_scores.dtype)
-        attention_scores = tf.where(causal_mask, attention_scores, mask_value)

+    def _get_common_kwargs_for_sublayer(self, use_bias=True):
+        common_kwargs = {}
+
+        kernel_initializer = clone_initializer(self._kernel_initializer)
+        bias_initializer = clone_initializer(self._bias_initializer)
+
+        common_kwargs["kernel_initializer"] = kernel_initializer
+        if use_bias:
+            common_kwargs["bias_initializer"] = bias_initializer
+
+        return common_kwargs
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        # print(attention_scores[0].shape, attention_scores[1].shape)
+        # print(attention_mask.shape, attention_scores.shape)
         if attention_mask is not None:
-            attention_scores += attention_mask
+            mask_expansion_axis = -3
+            for _ in range(
+                attention_scores.shape.rank - attention_mask.shape.rank
+            ):
+                attention_mask = tf.expand_dims(
+                    attention_mask, axis=mask_expansion_axis
+                )
+        return self._softmax(attention_scores, attention_mask)

-        attention_scores = tf.cast(
-            tf.nn.softmax(attention_scores, axis=-1), dtype=value.dtype
-        )
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, training=None
+    ):
+        attention_scores = tf.einsum("aecd,abcd->acbe", key, query)
+
+        # batch_size, _, key_len, _ = tf.shape(key)
+        # causal_mask = compute_causal_mask(batch_size, key_len, key_len)
+        # attention_mask = attention_mask & causal_mask
+        # mask_value = tf.constant(float('-inf'), dtype=attention_scores.dtype)
+        # attention_scores = tf.where(causal_mask, attention_scores, mask_value)

-        if head_mask is not None:
-            attention_scores *= head_mask
+        # print(attention_scores[0].shape, attention_scores[1].shape)
+        # if attention_mask is not None:
+        #     attention_scores += attention_mask

+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = self._attn_dropout_layer(
+            attention_scores, training=training
+        )
+        attention_output = tf.einsum("acbe,aecd->abcd", attention_scores, value)

-        attention_output = tf.matmul(attention_scores, value)
         return attention_output, attention_scores

     def call(
         self,
         hidden_states,
         attention_mask,
-        head_mask,
-        layer_past,
-        return_attention_scores,
+        return_attention_scores=False,
+        training=None,
     ):
-        qkv = self.qkv(hidden_states)
-        new_qkv_shape = tf.shape(hidden_states)[:-1] + [
-            self.num_heads,
-            self.head_dim,
-        ]
-        qkv = tf.reshape(qkv, new_qkv_shape)
+        query = self._query_dense(hidden_states)
+        key = self._key_dense(hidden_states)
+        value = self._value_dense(hidden_states)

-        query = tf.transpose(qkv[..., : self.head_dim], (0, 2, 1, 3))
-        key = tf.transpose(
-            qkv[..., : self.head_dim : 2 * self.head_dim], (0, 2, 1, 3)
-        )
-        value = tf.transpose(qkv[..., self.head_dim :], (0, 2, 1, 3))
+        # query = tf.transpose(query, (0, 2, 1, 3))
+        # key = tf.transpose(key, (0, 2, 1, 3))
+        # value = tf.transpose(value, (0, 2, 1, 3))

         query_rot, query_pass = (
Review comment (mattdangerw): I wonder if we would be better off moving this slice and concat logic into the rotary embedding layer. And the rotary embedding layer could also hold the percentage argument, which would conceptually be quite clean. Looks like falcon is doing this roughly -> https://huggingface.co/tiiuae/falcon-40b/blob/main/modelling_RW.py

Reply: Thanks for this wonderful suggestion!
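A rough sketch of what that refactor could look like: a rotary embedding layer that owns the rotary percentage and performs the slice, rotation, and concat internally. This is illustrative only, not the PR's code; the class name, the `max_wavelength` default, and the single-tensor call signature are assumptions.

```python
import tensorflow as tf
from tensorflow import keras


class RotaryEmbeddingSketch(keras.layers.Layer):
    """Illustrative sketch of a rotary layer that owns the rotary percentage."""

    def __init__(self, rotary_percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.rotary_percentage = rotary_percentage
        self.max_wavelength = max_wavelength

    def call(self, tensor):
        # `tensor` is assumed to be (batch, seq_len, num_heads, head_size).
        head_size = tensor.shape[-1]
        rotary_dim = int(head_size * self.rotary_percentage)
        x_rot, x_pass = tensor[..., :rotary_dim], tensor[..., rotary_dim:]

        # Cos/sin embeddings for the rotary slice only.
        seq_len = tf.shape(tensor)[1]
        positions = tf.range(seq_len, dtype="float32")
        inverse_freq = 1.0 / (
            float(self.max_wavelength)
            ** (tf.range(0, rotary_dim, 2, dtype="float32") / rotary_dim)
        )
        freqs = tf.einsum("i,j->ij", positions, inverse_freq)
        embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
        cos_emb, sin_emb = tf.cos(embedding), tf.sin(embedding)

        # Rotate-half trick, applied only to the rotary slice.
        x1, x2 = tf.split(x_rot, 2, axis=-1)
        half_rot = tf.concat((-x2, x1), axis=-1)
        x_rot = (x_rot * cos_emb) + (half_rot * sin_emb)

        # The caller no longer needs to slice or concat anything.
        return tf.concat((x_rot, x_pass), axis=-1)
```

With a layer like this, the attention layer's `call` would shrink to something like `query = rotary(query); key = rotary(key)`, with no slicing at the call site.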
             query[..., : self.rotary_dim],

@@ -115,26 +157,29 @@ def call(
         )

         query, key = self.rotary_embedding(query_rot, key_rot)
-        query = tf.concat((query, query_pass), dim=-1)
-        key = tf.concat((key, key_pass), dim=-1)
+        query = tf.concat((query, query_pass), axis=-1)
+        key = tf.concat((key, key_pass), axis=-1)

-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = tf.concat((past_key, key), axis=-2)
-            value = tf.concat((past_value, value), axis=-2)
+        # if layer_past is not None:
+        #     past_key, past_value = layer_past
+        #     key = tf.concat((past_key, key), axis=-2)
+        #     value = tf.concat((past_value, value), axis=-2)

         attention_output, attention_scores = self._compute_attention(
-            query, key, value, attention_mask, head_mask
-        )
-        new_shape = tf.shape(attention_output)[:-2] + (
-            self.num_heads * self.head_dim
+            query, key, value, attention_mask, training
         )

+        # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
         attention_output = tf.reshape(
-            tf.transpose(attention_output, (0, 2, 1, 3)), new_shape
+            attention_output,
+            [
+                tf.shape(attention_output)[0],
+                tf.shape(attention_output)[1],
+                self.hidden_dim,
+            ],
         )
-        attention_output = self.dense(attention_output)
+        attention_output = self._output_dense(attention_output)

+        if return_attention_scores:
+            return (attention_output, attention_scores)

-        return attention_output, attention_scores
+        return attention_output
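For readers less familiar with `EinsumDense`, here is a self-contained sketch of the projection-plus-einsum pattern the new code moves to (separate query/key/value projections instead of one fused `qkv` Dense). The shapes, variable names, and the explicit 1/sqrt(head_size) scaling are illustrative assumptions, not lifted from the PR.

```python
import tensorflow as tf
from tensorflow import keras

batch_size, seq_len, hidden_dim, num_heads = 2, 16, 64, 8
head_size = hidden_dim // num_heads

# Separate per-head projections, as in the diff's _query_dense/_key_dense/_value_dense.
query_dense = keras.layers.EinsumDense(
    "abc,cde->abde", output_shape=(None, num_heads, head_size), bias_axes="de"
)
key_dense = keras.layers.EinsumDense(
    "abc,cde->abde", output_shape=(None, num_heads, head_size), bias_axes="de"
)
value_dense = keras.layers.EinsumDense(
    "abc,cde->abde", output_shape=(None, num_heads, head_size), bias_axes="de"
)

hidden_states = tf.random.uniform((batch_size, seq_len, hidden_dim))
query = query_dense(hidden_states)  # (batch, seq, heads, head_size)
key = key_dense(hidden_states)
value = value_dense(hidden_states)

# Scores and output use the same einsum equations as the diff; the scaling
# here is the standard 1/sqrt(head_size) factor, added for illustration.
scores = tf.einsum("aecd,abcd->acbe", key, query)   # (batch, heads, q_len, k_len)
probs = tf.nn.softmax(scores / (head_size**0.5), axis=-1)
output = tf.einsum("acbe,aecd->abcd", probs, value)  # (batch, q_len, heads, head_size)
output = tf.reshape(output, (batch_size, seq_len, hidden_dim))
print(output.shape)  # (2, 16, 64)
```

In the PR, the `_masked_softmax` and attention-dropout steps slot in between the score and output einsums.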
The second changed file is the `RotaryEmbedding` layer:

@@ -27,6 +27,7 @@ def build(self, input_shape):
         self.inverse_freq = self.add_weight(
Review comment (mattdangerw): actually looking at this, this […]

Reply: Hey @mattdangerw! We can definitely do that, but I would like you to take a look at this: https://github.com/huggingface/transformers/blob/17a55534f5e5df10ac4804d4270bf6b8cc24998d/src/transformers/models/esm/modeling_tf_esm.py#L102-L107

Reply (mattdangerw): I'm not sure I follow the comment totally. The issue looks to be a precision one, but if the goal is to keep these explicitly as float32, why not just compute them on the fly with an explicit float32 dtype? I still don't understand the need for a variable. And the fact that this is trainable seems incorrect; looking at the torch implementation, these are not trainable in torch. In general, I would be careful attempting to apply what seems like a fairly technical point about ESM checkpoints to other models. Ideally we would just check how close our forward pass outputs are for the actual Pythia checkpoints under full precision (float32 everywhere) and mixed precision (float32 for variables, float16 for computations), and use that to determine our approach here.
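A minimal sketch of the alternative described above: compute the inverse frequencies on the fly with an explicit float32 dtype instead of storing them in a (trainable) variable. The function and parameter names (`hidden_dim`, `max_wavelength`) are placeholders, not the PR's API.

```python
import tensorflow as tf


def compute_inverse_freq(hidden_dim, max_wavelength=10000):
    # No variable involved: the frequencies are recomputed each call and are
    # always float32, regardless of the global dtype policy.
    range_ = tf.range(0, hidden_dim, 2, dtype=tf.float32)
    return 1.0 / (
        tf.cast(max_wavelength, tf.float32)
        ** (range_ / tf.cast(hidden_dim, tf.float32))
    )
```

If a variable really is wanted (as in the linked ESM code), it could at least be created with `trainable=False`, matching the non-trainable buffer in the torch implementation.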
"inverse_freq", shape=(self.hidden_dim // 2,), dtype=tf.float32 | ||
) | ||
|
||
self.inverse_freq.assign( | ||
kanpuriyanawab marked this conversation as resolved.
Show resolved
Hide resolved
|
||
1.0 | ||
/ ( | ||
|
@@ -43,27 +44,24 @@ def build(self, input_shape): | |
) | ||
) | ||
-    @staticmethod
-    def apply_rotary_pos_emb(cls, tensor, cos_emb, sin_emb):
-        cos_emb = cos_emb[:, :, : tf.shape(tensor)[-2], :]
-        sin_emb = sin_emb[:, :, : tf.shape(tensor)[-2], :]
+    def apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
+        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
+        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
         x1, x2 = tf.split(tensor, 2, axis=-1)
         half_rot_tensor = tf.concat((-x2, x1), axis=-1)
-        # Incompatible shapes: [32,256,8,2] vs. [1,256,1,16] [Op:Mul]
Review comment (mattdangerw): remember to clean up little notes like this

Reply: Done!
-        ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
-        return ret
+        return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

-    def _compute_cos_sin(self, x, seq_dim=2):
+    def _compute_cos_sin_embedding(self, x, seq_dim=1):
         seq_len = tf.shape(x)[seq_dim]
         tensor = tf.range(seq_len, dtype=self.inverse_freq.dtype)
         freqs = tf.einsum("i, j -> ij", tensor, self.inverse_freq)
-        embedding = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
+        embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
         return tf.cos(embedding), tf.sin(embedding)

     def call(self, query, key):
-        cos_emb, sin_emb = self._compute_cos_sin(key, seq_dim=-2)
-        return (
-            self.apply_rotary_pos_emb(query, cos_emb, sin_emb),
-            self.apply_rotary_pos_emb(key, cos_emb, sin_emb),
-        )
+        cos_emb, sin_emb = self._compute_cos_sin_embedding(key, seq_dim=1)
+        q_emb = self.apply_rotary_pos_emb(query, cos_emb, sin_emb)
+        k_emb = self.apply_rotary_pos_emb(key, cos_emb, sin_emb)
+        return q_emb, k_emb
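Following up on the precision discussion above, here is a hedged sketch of the kind of check suggested in review: run the same block under full precision and under mixed precision and compare outputs. A real check would also compare against the actual Pythia checkpoints; `make_layer` is a hypothetical factory for the block under test.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def compare_precisions(make_layer, inputs, atol=1e-3):
    # Full-precision reference.
    keras.mixed_precision.set_global_policy("float32")
    layer_fp32 = make_layer()
    full = layer_fp32(inputs)

    # Same weights, mixed-precision compute (float32 variables, float16 math).
    keras.mixed_precision.set_global_policy("mixed_float16")
    layer_mixed = make_layer()
    layer_mixed(inputs)  # Build the layer, then copy the float32 weights over.
    layer_mixed.set_weights(layer_fp32.get_weights())
    mixed = layer_mixed(inputs)

    keras.mixed_precision.set_global_policy("float32")  # Restore the default.
    np.testing.assert_allclose(
        tf.cast(full, tf.float32).numpy(),
        tf.cast(mixed, tf.float32).numpy(),
        atol=atol,
    )
```

A stricter version of this would load converted Pythia weights into `make_layer()` and compare against the upstream reference outputs as well.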