
Adding GPTNeoXBackbone #1056


Merged
merged 27 commits on Jun 25, 2023
Changes from 1 commit
Commits (27)
9412a83
added gpt-neo attention+decoder+backbone
kanpuriyanawab May 29, 2023
99a8296
fixed formatting + added backbone test
kanpuriyanawab May 29, 2023
afb7e1f
fixed rotary embedding and gpt neo attention layer
kanpuriyanawab Jun 6, 2023
f0f6383
updating decoder and backbone to current version
kanpuriyanawab Jun 6, 2023
bfd56fa
fixed decoder + backbone
kanpuriyanawab Jun 7, 2023
97a347d
fix forward pass
kanpuriyanawab Jun 10, 2023
5ead767
formatting + add checkpoint script
kanpuriyanawab Jun 10, 2023
5776ac1
fix tpu_test, formatting
kanpuriyanawab Jun 10, 2023
e0d343b
removed unnecessary layernorms, correct arguments, fix unit tests (te…
kanpuriyanawab Jun 12, 2023
451cdbc
fix dropout
kanpuriyanawab Jun 12, 2023
e37fb22
matching outputs with hf
kanpuriyanawab Jun 14, 2023
ead11c5
fix formating
kanpuriyanawab Jun 14, 2023
c7117a4
resolving few comments
kanpuriyanawab Jun 14, 2023
c72e629
fixed unit tests + formatting
kanpuriyanawab Jun 16, 2023
2341d0e
refactored rotary embedding
kanpuriyanawab Jun 16, 2023
6112357
revamped checkpoint conversion script
kanpuriyanawab Jun 16, 2023
66afa7c
code format
kanpuriyanawab Jun 16, 2023
f363f24
putting old checkpoint script back until preset
kanpuriyanawab Jun 16, 2023
7a66052
incorporated comments
kanpuriyanawab Jun 17, 2023
6f6f41e
code format
kanpuriyanawab Jun 17, 2023
f34ec47
resolved comments + fixed formatting
kanpuriyanawab Jun 17, 2023
34db7f7
added gpt neo x tokenizer
kanpuriyanawab Jun 17, 2023
1ecfe51
added docstrings
kanpuriyanawab Jun 21, 2023
b3f06e4
formatting fix
kanpuriyanawab Jun 21, 2023
a9f2230
addressing comments
kanpuriyanawab Jun 23, 2023
122a3fb
added tokenizer output verification
kanpuriyanawab Jun 23, 2023
e10ea50
Minor style fixes
mattdangerw Jun 24, 2023
fixed rotary embedding and gpt neo attention layer
kanpuriyanawab committed Jun 6, 2023
commit afb7e1f53400c13412a5d6abf90d3a00a7a7e13c
181 changes: 113 additions & 68 deletions keras_nlp/models/gpt_neox/gpt_neox_attention.py
@@ -14,96 +14,138 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
from keras_nlp.models.gpt_neox.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


class GPTNeoXAttention(keras.layers.Layer):
def __init__(
self,
num_heads,
hidden_dim,
rotary_pct=0.25,
max_position_embeddings=2048,
dropout=0.1,
max_position_embeddings=512,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
**kwargs
):

super().__init__()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // num_heads
self.rotary_dim = self.head_dim * rotary_pct
# self.rotary_pct = 4
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_dim = self.attn_head_size * self.hidden_dim
self.max_position_embeddings = max_position_embeddings
self.rotary_embedding = RotaryEmbedding(self.rotary_pct)
self.qkv = keras.layers.Dense(3 * self.hidden_dim)
self.dense = keras.layers.Dense(self.hidden_dim)
self.rotary_embedding = RotaryEmbedding(self.attn_head_size)

def _compute_attention(
self, query, key, value, attention_mask=None, head_mask=None
):

batch_size, _, query_len, _ = tf.shape(query)
key_len = tf.shape(key)[-2]
# causal_mask = self.bias[:, :, key_len - query_len : key_len, :key_len]
causal_mask = compute_causal_mask(batch_size, key_len, key_len)
self._kernel_initializer = keras.initializers.get(kernel_initializer)
self._bias_initializer = keras.initializers.get(bias_initializer)

query = tf.reshape(
query, [batch_size * self.num_heads, query_len, self.head_dim]
self._query_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, self.attn_head_size),
bias_axes="de",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="attention_output",
)
key = tf.reshape(
key, [batch_size * self.num_heads, query_len, self.head_dim]
self._key_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, self.attn_head_size),
bias_axes="de",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="key",
)
attention_scores = tf.zeros(
[batch_size * self.num_heads, query_len, self.head_dim],
dtype=query.dtype,
self._value_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, self.attn_head_size),
bias_axes="de",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="value",
)

attention_scores = tf.linalg.matmul(
attention_scores,
query,
tf.transpose(key, perm=[0, 2, 1]),
beta=1.0,
alpha=(tf.constant(1.0)),
self._attn_dropout_layer = keras.layers.Dropout(
self.dropout, name="attention_dropout"
)
attention_scores = tf.reshape(
attention_scores, [batch_size, self.num_heads, query_len, key_len]

self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

# Output.
self._output_dense = keras.layers.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, self.hidden_dim),
bias_axes="d",
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="attention_output",
)
mask_value = tf.constant(float("-inf"), dtype=attention_scores.dtype)
attention_scores = tf.where(causal_mask, attention_scores, mask_value)

def _get_common_kwargs_for_sublayer(self, use_bias=True):
common_kwargs = {}

kernel_initializer = clone_initializer(self._kernel_initializer)
bias_initializer = clone_initializer(self._bias_initializer)

common_kwargs["kernel_initializer"] = kernel_initializer
if use_bias:
common_kwargs["bias_initializer"] = bias_initializer

return common_kwargs

def _masked_softmax(self, attention_scores, attention_mask=None):
# print(attention_scores[0].shape, attention_scores[1].shape)
# print(attention_mask.shape, attention_scores.shape)
if attention_mask is not None:
attention_scores += attention_mask
mask_expansion_axis = -3
for _ in range(
attention_scores.shape.rank - attention_mask.shape.rank
):
attention_mask = tf.expand_dims(
attention_mask, axis=mask_expansion_axis
)
return self._softmax(attention_scores, attention_mask)

attention_scores = tf.cast(
tf.nn.softmax(attention_scores, axis=-1), dtype=value.dtype
)
def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):

attention_scores = tf.einsum("aecd,abcd->acbe", key, query)

# batch_size, _, key_len, _ = tf.shape(key)
# causal_mask = compute_causal_mask(batch_size, key_len, key_len)
# attention_mask = attention_mask & causal_mask
# mask_value = tf.constant(float('-inf'), dtype=attention_scores.dtype)
# attention_scores = tf.where(causal_mask, attention_scores, mask_value)

if head_mask is not None:
attention_scores *= head_mask
# print(attention_scores[0].shape, attention_scores[1].shape)
# if attention_mask is not None:
# attention_scores += attention_mask

attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._attn_dropout_layer(
attention_scores, training=training
)
attention_output = tf.einsum("acbe,aecd->abcd", attention_scores, value)

attention_output = tf.matmul(attention_scores, value)
return attention_output, attention_scores

def call(
self,
hidden_states,
attention_mask,
head_mask,
layer_past,
return_attention_scores,
return_attention_scores=False,
training=None,
):

qkv = self.qkv(hidden_states)
new_qkv_shape = tf.shape(hidden_states)[:-1] + [
self.num_heads,
self.head_dim,
]
qkv = tf.reshape(qkv, new_qkv_shape)
query = self._query_dense(hidden_states)
key = self._key_dense(hidden_states)
value = self._value_dense(hidden_states)

query = tf.transpose(qkv[..., : self.head_dim], (0, 2, 1, 3))
key = tf.transpose(
qkv[..., : self.head_dim : 2 * self.head_dim], (0, 2, 1, 3)
)
value = tf.transpose(qkv[..., self.head_dim :], (0, 2, 1, 3))
# query = tf.transpose(query, (0, 2, 1, 3))
# key = tf.transpose(key, (0, 2, 1, 3))
# value = tf.transpose(value, (0, 2, 1, 3))

query_rot, query_pass = (
Member

I wonder if we would be better off moving this slice and concat logic into the RotaryEmbedding call. Then our usage here could look a little more like...

query = self.rotary_embedding(query)
key = self.rotary_embedding(key)

And the rotary embedding layer could also hold the percentage argument, which would conceptually be quite clean. Looks like falcon is doing this roughly -> https://huggingface.co/tiiuae/falcon-40b/blob/main/modelling_RW.py
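
A rough sketch of that shape of API (hypothetical, not this PR's code), assuming inputs of shape (batch, sequence, heads, head_dim), an even rotary slice, and a made-up rotary_percentage argument on the layer:

import tensorflow as tf
from tensorflow import keras


class RotaryEmbeddingSketch(keras.layers.Layer):
    """Sketch of a rotary embedding layer that owns the rotary percentage and
    the slice/rotate/concat logic, so callers can simply write `x = layer(x)`."""

    def __init__(self, rotary_percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.rotary_percentage = rotary_percentage
        self.max_wavelength = max_wavelength

    def call(self, inputs):
        # inputs: (batch, seq_len, num_heads, head_dim); head_dim must be static.
        head_dim = inputs.shape[-1]
        rotary_dim = int(head_dim * self.rotary_percentage)  # assumed even
        x_rot, x_pass = inputs[..., :rotary_dim], inputs[..., rotary_dim:]

        # The frequencies are deterministic, so compute them on the fly in float32.
        seq_len = tf.shape(inputs)[1]
        freq_range = tf.range(0, rotary_dim, 2, dtype="float32") / float(rotary_dim)
        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
        positions = tf.range(seq_len, dtype="float32")
        freqs = tf.einsum("i,j->ij", positions, inverse_freq)
        embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
        cos_emb = tf.cast(tf.cos(embedding), inputs.dtype)
        sin_emb = tf.cast(tf.sin(embedding), inputs.dtype)

        # Rotate half of the rotary slice, combine, then re-attach the
        # pass-through slice.
        x1, x2 = tf.split(x_rot, 2, axis=-1)
        half_rot = tf.concat((-x2, x1), axis=-1)
        x_rot = x_rot * cos_emb + half_rot * sin_emb
        return tf.concat((x_rot, x_pass), axis=-1)

With something along these lines, the attention layer would reduce to query = self.rotary_embedding(query) and key = self.rotary_embedding(key).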

Collaborator Author

Thanks for this wonderful suggestion!

query[..., : self.rotary_dim],
@@ -115,26 +157,29 @@ def call(
)

query, key = self.rotary_embedding(query_rot, key_rot)
query = tf.concat((query, query_pass), dim=-1)
key = tf.concat((key, key_pass), dim=-1)
query = tf.concat((query, query_pass), axis=-1)
key = tf.concat((key, key_pass), axis=-1)

if layer_past is not None:
past_key, past_value = layer_past
key = tf.concat((past_key, key), axis=-2)
value = tf.concat((past_value, value), axis=-2)
# if layer_past is not None:
# past_key, past_value = layer_past
# key = tf.concat((past_key, key), axis=-2)
# value = tf.concat((past_value, value), axis=-2)

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, head_mask
)
new_shape = tf.shape(attention_output)[:-2] + (
self.num_heads * self.head_dim
query, key, value, attention_mask, training
)

# Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
attention_output = tf.reshape(
tf.transpose(attention_output, (0, 2, 1, 3)), new_shape
attention_output,
[
tf.shape(attention_output)[0],
tf.shape(attention_output)[1],
self.hidden_dim,
],
)
attention_output = self.dense(attention_output)
attention_output = self._output_dense(attention_output)

if return_attention_scores:
return (attention_output, attention_scores)

return attention_output, attention_scores
return attention_output
28 changes: 13 additions & 15 deletions keras_nlp/models/gpt_neox/rotary_embedding.py
@@ -27,6 +27,7 @@ def build(self, input_shape):
self.inverse_freq = self.add_weight(
Member

Actually, looking at this, this inverse_freq should all be static, right? If we don't need this trainable, instead of having this be a weight, let's move this into call somewhere; we can just compute it on the fly.

Collaborator Author

Member

I'm not sure I follow the comment totally. The issue looks to be a precision one, but if the goal is to keep these explicitly as float32, why not just compute them on the fly with an explicit float32 dtype? I still don't understand the need for a variable. And the fact that this is a trainable seems incorrect looking at the torch implementation, these are not trainable in torch.

In general, I would be careful attempting to apply what seems like a fairly technical point about ESM checkpoints to other models. Ideally we would just check how close our forward pass outputs are for the actual Pythia checkpoints under full precision (float32 everywhere) and mixed precision (float32 for variables, float16 for computations), and use that to determine our approach here.
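
For illustration of that suggestion (a sketch only, not the code that landed; rotary_dim and max_wavelength are placeholder names), the frequencies could be computed per call with an explicit float32 dtype instead of being stored as a trainable weight:

import tensorflow as tf


def compute_cos_sin_embedding(x, rotary_dim, seq_dim=1, max_wavelength=10000):
    # inverse_freq is deterministic, so recompute it in float32 on every call
    # rather than keeping it as a (trainable) variable.
    freq_range = tf.range(0, rotary_dim, 2, dtype="float32") / float(rotary_dim)
    inverse_freq = 1.0 / (max_wavelength**freq_range)
    positions = tf.range(tf.shape(x)[seq_dim], dtype="float32")
    freqs = tf.einsum("i,j->ij", positions, inverse_freq)
    embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
    # Cast at the end so mixed-precision (float16) activations see matching dtypes.
    return tf.cast(tf.cos(embedding), x.dtype), tf.cast(tf.sin(embedding), x.dtype)

Checking the forward-pass outputs against the actual Pythia checkpoints under both full and mixed precision, as suggested above, would then settle which variant to keep.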

"inverse_freq", shape=(self.hidden_dim // 2,), dtype=tf.float32
)

self.inverse_freq.assign(
1.0
/ (
@@ -43,27 +44,24 @@
)
)

@staticmethod
def apply_rotary_pos_emb(cls, tensor, cos_emb, sin_emb):

cos_emb = cos_emb[:, :, : tf.shape(tensor)[-2], :]
sin_emb = sin_emb[:, :, : tf.shape(tensor)[-2], :]

def apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
x1, x2 = tf.split(tensor, 2, axis=-1)
half_rot_tensor = tf.concat((-x2, x1), axis=-1)
# Incompatible shapes: [32,256,8,2] vs. [1,256,1,16] [Op:Mul]
Member

Remember to clean up little notes like this.

Collaborator Author

Done!

ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
return ret

return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

def _compute_cos_sin(self, x, seq_dim=2):
def _compute_cos_sin_embedding(self, x, seq_dim=1):
seq_len = tf.shape(x)[seq_dim]
tensor = tf.range(seq_len, dtype=self.inverse_freq.dtype)
freqs = tf.einsum("i, j -> ij", tensor, self.inverse_freq)
embedding = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
return tf.cos(embedding), tf.sin(embedding)

def call(self, query, key):
cos_emb, sin_emb = self._compute_cos_sin(key, seq_dim=-2)
return (
self.apply_rotary_pos_emb(query, cos_emb, sin_emb),
self.apply_rotary_pos_emb(key, cos_emb, sin_emb),
)
cos_emb, sin_emb = self._compute_cos_sin_embedding(key, seq_dim=1)
q_emb = self.apply_rotary_pos_emb(query, cos_emb, sin_emb)
k_emb = self.apply_rotary_pos_emb(key, cos_emb, sin_emb)
return q_emb, k_emb