[DRAFT] Generation refactor #1425

Draft · wants to merge 1 commit into master

26 changes: 25 additions & 1 deletion keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -251,6 +251,28 @@ def build(
# Create layers based on input shape.
self.built = True

def compute_self_attention_cache(
self,
decoder_sequence,
):
x = decoder_sequence
if self.normalize_first:
x = self._self_attention_layer_norm(x)
key = self._self_attention_layer._key_dense(x)
value = self._self_attention_layer._value_dense(x)
return ops.stack((key, value), axis=1)

def compute_cross_attention_cache(
self,
encoder_sequence,
):
x = encoder_sequence
if self.normalize_first:
x = self._cross_attention_layer_norm(x)
key = self._cross_attention_layer._key_dense(x)
value = self._cross_attention_layer._value_dense(x)
return ops.stack((key, value), axis=1)

def call(
self,
decoder_sequence,
@@ -314,7 +336,9 @@ def call(
the layer has cross-attention.
"""

has_encoder_sequence = encoder_sequence is not None
has_encoder_sequence = (
encoder_sequence is not None or cross_attention_cache is not None
)

has_cross_attention = self._cross_attention_layer is not None
if not has_cross_attention and has_encoder_sequence:
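Below is a minimal standalone sketch (not part of the diff) of how the two new `compute_*_cache` hooks on `TransformerDecoder` could be exercised once this draft is applied; the layer sizes and dummy inputs are illustrative assumptions.

```python
import numpy as np
import keras_nlp

# Build a TransformerDecoder with cross-attention by calling it once with
# both a decoder and an encoder sequence.
decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=4)
decoder_sequence = np.random.uniform(size=(2, 10, 16)).astype("float32")
encoder_sequence = np.random.uniform(size=(2, 12, 16)).astype("float32")
decoder(decoder_sequence, encoder_sequence)

# Each hook stacks key and value along axis=1, so the returned cache has
# shape (batch_size, 2, sequence_length, num_heads, key_dim).
self_cache = decoder.compute_self_attention_cache(decoder_sequence)
cross_cache = decoder.compute_cross_attention_cache(encoder_sequence)
```
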
2 changes: 1 addition & 1 deletion keras_nlp/src/layers/preprocessing/start_end_packer.py
@@ -193,7 +193,7 @@ def call(
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
mask = tf.ones_like(x, dtype="int32")
mask = mask.to_tensor(shape=(batch_size, sequence_length))
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
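For context on the one-line change above, here is a small usage sketch of `StartEndPacker` with `return_padding_mask=True`; the values are illustrative, and the intermediate mask dtype is exactly what this line adjusts.

```python
import keras_nlp

packer = keras_nlp.layers.StartEndPacker(
    sequence_length=6,
    start_value=1,
    end_value=2,
    pad_value=0,
    return_padding_mask=True,
)
token_ids, padding_mask = packer([5, 6, 7])
# token_ids    -> [1, 5, 6, 7, 2, 0]
# padding_mask -> marks the non-padded positions, i.e. the first five entries.
```
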
1 change: 0 additions & 1 deletion keras_nlp/src/models/bart/bart_backbone.py
@@ -257,5 +257,4 @@ def get_config(self):
"max_sequence_length": self.max_sequence_length,
}
)

return config
310 changes: 43 additions & 267 deletions keras_nlp/src/models/bart/bart_seq_2_seq_lm.py
@@ -21,7 +21,6 @@
BartSeq2SeqLMPreprocessor,
)
from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_nlp.src.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -200,291 +199,68 @@ def __init__(
**kwargs,
)

def call_decoder_with_cache(
def build_cache(self, batch_size, max_length):
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
return ops.zeros(shape, dtype=self.compute_dtype)

def compute_cross_attention_cache(
self, encoder_token_ids, encoder_padding_mask
):
"""Does a forward pass on the encoder and returns the encoder output."""
# Embedding layers.
tokens = self.backbone.token_embedding(encoder_token_ids)
positions = self.backbone.encoder_position_embedding(tokens)
# Sum, normalize and apply dropout to embeddings.
x = self.backbone.encoder_embeddings_add((tokens, positions))
x = self.backbone.encoder_embeddings_layer_norm(x)
x = self.backbone.encoder_embeddings_dropout(x)
# Transformer encoder layers.
for layer in self.backbone.encoder_transformer_layers:
x = layer(x, padding_mask=encoder_padding_mask)
# Compute a cross-attention cache for each decoder layer.
caches = []
for layer in self.backbone.decoder_transformer_layers:
caches.append(layer.compute_cross_attention_cache(x))
return ops.stack(caches, axis=1)

def call_with_cache(
self,
encoder_hidden_states,
token_ids,
cache,
index,
*,
encoder_padding_mask,
decoder_token_ids,
self_attention_cache=None,
self_attention_cache_update_index=None,
cross_attention_cache=None,
cross_attention_cache_update_index=None,
cross_attention_cache,
):
"""Forward pass with a key/value caches for generative decoding..

`call_with_cache` adds an additional inference-time forward pass
to the model for seq2seq text generation. Unlike calling the model
directly, this method does two things to optimize text generation:

- Allows caching previous key/value tensors in the decoder's
self-attention layer to avoid recomputing the outputs of seen tokens.
- Allows caching key/value tensors in the decoder's cross-attention
layer to avoid recomputing the encoder outputs.

Args:
encoder_hidden_states: a dense float Tensor of shape
`(batch_size, encoder_sequence_length, hidden_dim)`. The
sequence of hidden states at the output of the encoder's last
layer.
encoder_padding_mask: a dense float Tensor of shape
`(batch_size, encoder_sequence_length)`. The padding mask for
the encoder input.
decoder_token_ids: a dense int Tensor of shape
`(batch_size, max_length)`. Input token ids to be fed to
the decoder.
self_attention_cache: a dense float Tensor of shape
`(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
The cached key/value tensors of previously seen tokens in the
decoder's self-attention layer.
self_attention_cache_update_index: an int or int Tensor, the index
at which to update the `self_attention_cache`. Usually, this is
the index of the current token being processed during decoding.
cross_attention_cache: a dense float Tensor of shape
`(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
The cached key/value tensors of the encoder outputs in the
decoder's cross-attention layer.
cross_attention_cache_update_index: an int or int Tensor, the index
at which to update the `cross_attention_cache`. Usually, this is
either `0` (compute the entire `cross_attention_cache`), or
`None` (reuse a previously computed `cross_attention_cache`).

Returns:
A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
tuple, where `logits` is the language model logits for the input
`decoder_token_ids`, `hidden_states` is the final hidden
representation of the input tokens, `self_attention_cache` is the
key/value cache in the decoder's self-attention layer and
`cross_attention_cache` is the key/value cache in the decoder's
cross-attention layer.
"""
# Embedding layers.
tokens = self.backbone.token_embedding(decoder_token_ids)
tokens = self.backbone.token_embedding(token_ids)
positions = self.backbone.decoder_position_embedding(
tokens,
start_index=self_attention_cache_update_index,
tokens, start_index=index
)
# Sum, normalize and apply dropout to embeddings.
x = self.backbone.decoder_embeddings_add((tokens, positions))
x = self.backbone.decoder_embeddings_layer_norm(x)
x = self.backbone.decoder_embeddings_dropout(x)

# Every decoder layer has a separate cache for the self-attention layer
# and the cross-attention layer. We update all of them separately.
self_attention_caches = []
cross_attention_caches = []
# Each decoder layer has a cache; we update them separately.
caches = []
for i, layer in enumerate(self.backbone.decoder_transformer_layers):
current_self_attention_cache = self_attention_cache[:, i, ...]
current_self_attention_cache = cache[:, i, ...]
current_cross_attention_cache = cross_attention_cache[:, i, ...]
(
x,
next_self_attention_cache,
next_cross_attention_cache,
) = layer(
x, next_cache, _ = layer(
decoder_sequence=x,
encoder_sequence=encoder_hidden_states,
encoder_padding_mask=encoder_padding_mask,
self_attention_cache=current_self_attention_cache,
self_attention_cache_update_index=self_attention_cache_update_index,
self_attention_cache_update_index=index,
cross_attention_cache=current_cross_attention_cache,
cross_attention_cache_update_index=cross_attention_cache_update_index,
)
if self_attention_cache_update_index is not None:
self_attention_caches.append(next_self_attention_cache)
if cross_attention_cache_update_index is not None:
cross_attention_caches.append(next_cross_attention_cache)

if self_attention_cache_update_index is not None:
self_attention_cache = ops.stack(self_attention_caches, axis=1)
if cross_attention_cache_update_index is not None:
cross_attention_cache = ops.stack(cross_attention_caches, axis=1)

caches.append(next_cache)
cache = ops.stack(caches, axis=1)
hidden_states = x
logits = self.backbone.token_embedding(hidden_states, reverse=True)
return (
logits,
hidden_states,
self_attention_cache,
cross_attention_cache,
cache,
)
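
Taken together, here is a rough sketch (not part of the diff) of how the refactored cache API appears intended to drive greedy decoding once this draft is applied; `lm`, the preset name, and the manual loop are assumptions for illustration, not the library's sampler.

```python
from keras import ops
import keras_nlp

lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=None
)

batch_size, max_length = 1, 16
encoder_token_ids = ops.ones((batch_size, 8), dtype="int32")
encoder_padding_mask = ops.ones((batch_size, 8), dtype="bool")
# Seed the decoder with a single start token id.
start_id = 2
decoder_token_ids = ops.zeros((batch_size, max_length), dtype="int32")
decoder_token_ids = ops.slice_update(
    decoder_token_ids, (0, 0), ops.full((batch_size, 1), start_id, dtype="int32")
)

# One self-attention cache for the whole decode, one fixed cross-attention
# cache computed from the encoder inputs.
cache = lm.build_cache(batch_size, max_length)
cross_cache = lm.compute_cross_attention_cache(
    encoder_token_ids, encoder_padding_mask
)

# Feed one token at a time, updating the self-attention cache as we go.
for index in range(max_length - 1):
    token = decoder_token_ids[:, index : index + 1]
    logits, _, cache = lm.call_with_cache(
        token, cache, index, cross_attention_cache=cross_cache
    )
    next_token = ops.cast(ops.argmax(logits[:, -1, :], axis=-1), "int32")
    decoder_token_ids = ops.slice_update(
        decoder_token_ids, (0, index + 1), ops.expand_dims(next_token, axis=1)
    )
```

In the task itself the sampler would own this loop; the sketch only shows where each new method slots in.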

def call_encoder(self, token_ids, padding_mask):
"""Does a forward pass on the encoder and returns the encoder output."""
tokens = self.backbone.token_embedding(token_ids)
positions = self.backbone.encoder_position_embedding(tokens)
x = self.backbone.encoder_embeddings_add((tokens, positions))
x = self.backbone.encoder_embeddings_layer_norm(x)
x = self.backbone.encoder_embeddings_dropout(x)
for transformer_layer in self.backbone.encoder_transformer_layers:
x = transformer_layer(x, padding_mask=padding_mask)
return x

def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
"""Initializes empty self-attention cache and cross-attention cache."""
batch_size = ops.shape(encoder_token_ids)[0]
encoder_max_length = ops.shape(encoder_token_ids)[1]
decoder_max_length = ops.shape(decoder_token_ids)[1]

num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads

shape = [
batch_size,
num_layers,
2,
decoder_max_length,
num_heads,
head_dim,
]
self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)

shape[3] = encoder_max_length
cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)

return (self_attention_cache, cross_attention_cache)

def _build_cache(
self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
):
"""Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
encoder_hidden_states = self.call_encoder(
token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
)
self_attention_cache, cross_attention_cache = self._initialize_cache(
encoder_token_ids, decoder_token_ids
)

# Seed the self-attention cache and the cross-attention cache.
(
_,
hidden_states,
self_attention_cache,
cross_attention_cache,
) = self.call_decoder_with_cache(
encoder_hidden_states=encoder_hidden_states,
encoder_padding_mask=encoder_padding_mask,
decoder_token_ids=decoder_token_ids,
self_attention_cache=self_attention_cache,
self_attention_cache_update_index=0,
cross_attention_cache=cross_attention_cache,
cross_attention_cache_update_index=0,
)
return (
hidden_states,
encoder_hidden_states,
self_attention_cache,
cross_attention_cache,
)

def generate_step(
self,
inputs,
stop_token_ids=None,
):
"""A compilable generation function for a batch of inputs.

This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"` and
`"decoder_padding_mask"`.

Args:
inputs: A dictionary with four keys - `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"` and
`"decoder_padding_mask"`, with batched tensor values.
stop_token_ids: Tuple of ids of stop tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
(
encoder_token_ids,
encoder_padding_mask,
decoder_token_ids,
decoder_padding_mask,
) = (
inputs["encoder_token_ids"],
inputs["encoder_padding_mask"],
inputs["decoder_token_ids"],
inputs["decoder_padding_mask"],
)

batch_size = ops.shape(encoder_token_ids)[0]

# Create and seed cache with a single forward pass.
(
hidden_states,
encoder_hidden_states,
self_attention_cache,
cross_attention_cache,
) = self._build_cache(
encoder_token_ids, encoder_padding_mask, decoder_token_ids
)
# Compute the lengths of all user-provided token ids.
row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
# Start at the first index that has no user-provided id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_index = index - 1
num_samples = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])

def repeat_tensor(x):
"""Repeats tensors along batch axis to match dim for beam search."""
if ops.shape(x)[0] == num_samples:
return x
return ops.repeat(x, repeats=num_samples // batch_size, axis=0)

logits, hidden_states, cache, _ = self.call_decoder_with_cache(
encoder_hidden_states=repeat_tensor(encoder_hidden_states),
encoder_padding_mask=repeat_tensor(encoder_padding_mask),
decoder_token_ids=prompt,
self_attention_cache=cache,
self_attention_cache_update_index=cache_index,
cross_attention_cache=repeat_tensor(cross_attention_cache),
cross_attention_cache_update_index=None,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

decoder_token_ids = self.sampler(
next=next,
prompt=decoder_token_ids,
cache=self_attention_cache,
index=index,
mask=decoder_padding_mask,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of `stop_token_ids` locations not in the original
# prompt (not in locations where `decoder_padding_mask` is True).
end_locations = any_equal(
decoder_token_ids,
stop_token_ids,
ops.logical_not(decoder_padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after `end_locations`.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
decoder_padding_mask = ops.ones_like(
decoder_token_ids, dtype="bool"
)

return {
"decoder_token_ids": decoder_token_ids,
"decoder_padding_mask": decoder_padding_mask,
}
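
For reference, a standalone sketch of the stop-token masking trick used in the removed `generate_step` above; `any_equal` (a keras_nlp tensor util) is spelled out with plain ops here, and the token ids are made up.

```python
from keras import ops

stop_token_id = 2
# One generated sequence; only position 0 came from the user prompt.
decoder_token_ids = ops.convert_to_tensor([[1, 5, 7, 2, 9, 9]])
decoder_padding_mask = ops.convert_to_tensor(
    [[True, False, False, False, False, False]]
)

# Stop tokens only count where they were generated, not where they were
# already part of the prompt.
end_locations = ops.logical_and(
    ops.equal(decoder_token_ids, stop_token_id),
    ops.logical_not(decoder_padding_mask),
)
end_locations = ops.cast(end_locations, "int32")  # [[0, 0, 0, 1, 0, 0]]
# Cumsum marks the first stop token and everything after it...
cumsum = ops.cumsum(end_locations, axis=-1)       # [[0, 0, 0, 1, 1, 1]]
# ...and subtracting `end_locations` keeps the stop token itself unmasked.
overflow = cumsum - end_locations                 # [[0, 0, 0, 0, 1, 1]]
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
# padding_mask -> [[True, True, True, True, False, False]]
```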