
Move generate compilation to the task model #804


Merged

Conversation

mattdangerw (Member) commented Mar 4, 2023

This demos a fix for #779 by moving the compilation up to the causal_lm model, where we can most easily control the conditions for recompilation.

This has a few advantages:

  • We only need to tokenize once.
  • All forward passes on the model, including cache seeding, can live in the compiled function.
  • We expose compilation in a similar way to the keras.Model train step. There is an overridable make_generate_function (similar to make_train_function) and an accessible model.generate_function property on the model (sketched below).

Demo -> https://colab.research.google.com/gist/mattdangerw/ea205181ef56d1d95860e8b3f4a9db4d/generate-compile-demo.ipynb
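As a rough sketch of the compile-once pattern described above (illustrative only; the names follow the description, but the details are assumptions rather than the actual keras_nlp implementation):

import tensorflow as tf

class CausalLMSketch:
    """Illustrative sketch of caching a compiled generate function on the task model."""

    generate_function = None

    def make_generate_function(self):
        # Overridable, analogous to `keras.Model.make_train_function`.
        if self.generate_function is not None:
            return self.generate_function

        def generate_step(inputs):
            # Cache seeding and every forward pass live inside this one
            # compiled function, so generate() only retraces when needed.
            ...

        # Optionally pass jit_compile=True for XLA.
        self.generate_function = tf.function(generate_step)
        return self.generate_function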

mattdangerw (Member Author) commented Mar 7, 2023

One major alternative we could consider is moving the sampler to the compile() function. Then things become super parallel to how fit() and predict() work.

gpt_lm = keras_nlp.models.GPT2CausalLM.from_preset(...)
# First call compiles generate with default sampler.
gpt_lm.generate(prompt, length)
# Recompile by passing a new sampler to `compile()`.
gpt_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=10))
# Next call will remake the generate function saved on the model.
gpt_lm.generate(prompt, length)

Edit: we are going with this approach.
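A rough sketch of how that could hang together on the task model (illustrative only; the default sampler name and the details here are assumptions, not the exact keras_nlp code):

from tensorflow import keras

class CausalLMCompileSketch(keras.Model):
    def compile(self, *args, sampler="top_k", **kwargs):
        super().compile(*args, **kwargs)
        self.sampler = sampler
        # Drop any previously compiled generate function so the next
        # generate() call recompiles with the new sampler.
        self.generate_function = None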

@mattdangerw mattdangerw changed the title Move compilation to the task model Move generate compilation to the task model Mar 8, 2023
@mattdangerw mattdangerw force-pushed the move-compilation-for-generate branch from a2266bf to ba982e2 Compare March 10, 2023 00:07
@mattdangerw mattdangerw marked this pull request as ready for review March 10, 2023 00:47
@mattdangerw mattdangerw force-pushed the move-compilation-for-generate branch from 334b655 to 0f116cd Compare March 10, 2023 01:06
chenmoneygithub (Contributor) commented:

@mattdangerw Thanks Matt! I took a quick pass, and at a high level this looks okay to me. I will need to dig into the details more to think about different use cases; I'll do that next week after I get back. Thanks!


Review thread on the `get_next_token` signature:

def get_next_token(self, next_token_probs):
def get_next_token(self, probs):

Collaborator:

Generic feedback -- use fully spelled out argument names for anything that isn't a super well established convention. Probabilities / logits here (which one is it?)

mattdangerw (Member Author):

This is a good comment. I wonder if the thing to do is to remove the from_logits=True argument in our samplers. Then things get really obvious: you are either working with logits or probabilities, never both. Sampling is so complex that I think anything we can do to remove cognitive load is worth it.

This would also have the advantage of making it really easy to add a temperature argument to all our samplers that scales the logits pre-softmax (tightening or loosening the distribution).
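For illustration, a minimal sketch of temperature scaling (the `temperature` argument and `scale_logits` helper are assumptions for this example, not something the PR adds):

import tensorflow as tf

def scale_logits(logits, temperature=1.0):
    # Divide logits by temperature before softmax: temperature < 1.0
    # tightens the distribution, temperature > 1.0 loosens it.
    return logits / temperature

logits = tf.constant([[2.0, 1.0, 0.1]])
sharper = tf.nn.softmax(scale_logits(logits, temperature=0.5))
flatter = tf.nn.softmax(scale_logits(logits, temperature=2.0))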

mattdangerw (Member Author):

Added a commit trying this out.

@mattdangerw mattdangerw force-pushed the move-compilation-for-generate branch from 3dbe15e to 4e60829 Compare March 10, 2023 23:31
chenmoneygithub (Contributor) left a comment:

Thanks! Finally took a full pass and dropped some comments.

Overall this looks good, and beam search is impressively clean! Two main comments:

  • The state variable has lots of freedom, but it seems its main usage is still the cache.
  • We are forcing users to clear self.generate_function via compile(), which is a strong contract. As a comparison, I can adjust model.optimizer by directly setting the field, or I can write a custom training loop as an alternative, but generate() is tightly coupled with compile().

Review thread on the sampler docstring:

prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
tensor will be iteratively updated column by column with new sampled
values.
state: Optional. A tensor or nested structure of tensors that will be
chenmoneygithub (Contributor):

Okay... I finally understand how this state works in general. Correct me if I am wrong; sharing my understanding: this state is a free variable, and how to use it is entirely decided by users of samplers. Most often it holds the cache, but it is in fact a backdoor people can use in the next function or when they want to override the __call__ method.

My thought is that "state" is too broad for users to learn. I am also not clear on how to let it hold two or more things; e.g., in contrastive search we need both the cache and the previous logits. Since we cannot assume the existence of a cache, how do we retrieve the previous logits robustly?

mattdangerw (Member Author):

Contrastive search is not covered here. IIUC contrastive search does not need a cache per se, or previous logits. What contrastive search needs is the hidden representation of every token to compute a cosine similarity metric. So to make contrastive search work, we will need to update the signature of next to something like this:

def next(prompt, state, index):
    # `dense` would carry the per-token hidden representations.
    return logits, dense, state

I can take a closer look at existing contrastive search implementations to understand what is out there. I was thinking of covering it as a follow-up, but it's definitely worth some thought.
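For reference, a rough sketch of the degeneration penalty contrastive search computes from those hidden representations (illustrative only; not code from this PR):

import tensorflow as tf

def degeneration_penalty(candidate_hidden, previous_hidden):
    # candidate_hidden: (batch_size, hidden_dim) for the candidate token.
    # previous_hidden: (batch_size, seq_len, hidden_dim) for prior tokens.
    candidate = tf.math.l2_normalize(candidate_hidden, axis=-1)
    previous = tf.math.l2_normalize(previous_hidden, axis=-1)
    # Cosine similarity of the candidate against every previous token.
    similarity = tf.einsum("bd,bsd->bs", candidate, previous)
    return tf.reduce_max(similarity, axis=-1)

Contrastive search would then score each candidate as roughly (1 - alpha) * model probability - alpha * penalty, which is why an extra `dense` output would be needed.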

mattdangerw (Member Author) commented Mar 14, 2023:

I also think we might want to pull apart the user journeys here.

Someone writing a sampler is writing a drop-in replacement for sampler="beam" or sampler="greedy". state is not an important backdoor in this case; it's just more tensor variables that should be treated as loop variables. But from the perspective of the sampler writer, the model and its forward pass are a black box.

In contrast, from the perspective of a model writer, the sampler should be a black box: the user complies with the next contract and doesn't worry about how the sampling actually happens. Here, state can be used to introduce arbitrary extra updatable state needed solely to compute the probability distribution of the next token. This could be the cache for a transformer decoder, the hidden state of a recurrent network, etc. (a sketch follows below).

Overall, I am most interested in keeping the sampler simple and useful for our own purposes, but worth chatting through all this!
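As a rough illustration of that split, a sketch of the model-writer side of the contract (`forward_with_cache` and the exact return values are assumptions for this example, not the actual keras_nlp signatures):

import tensorflow as tf

def make_next_fn(forward_with_cache):
    def next(prompt, state, index):
        # From the sampler's point of view, `state` is just more loop
        # variables; from the model's point of view, it is the decoder
        # cache (or, say, the hidden state of a recurrent network).
        cache = state
        # Forward pass for the single column at `index`, reusing the cache.
        logits, cache = forward_with_cache(prompt[:, index : index + 1], cache, index)
        probs = tf.nn.softmax(logits[:, -1, :])
        return probs, cache
    return next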

chenmoneygithub (Contributor) left a comment:

Thanks! Took another pass; the functionality looks good to me, and I dropped some comments on style. One thing we may want to do is compare performance before and after this PR, since at head we don't have recompilation either.

Review thread on the prompt padding code:

# Pad ragged to dense tensors.
padded_shape = (None, max_length)
min_length = tf.cast(tf.reduce_min(prompt.row_lengths()), "int32")
chenmoneygithub (Contributor):

Why do we need an explicit dtype int32 here?

mattdangerw (Member Author) commented Mar 20, 2023:

I moved this down into the sampler for now, but essentially tf.shape/tf.range default to int32, while tf.ragged row lengths default to int64.

This is slightly awkward, and we need to cast to make sure our index comparisons are all using the same type.
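For context, a small snippet showing the mismatch being described (the shapes here are made up for illustration):

import tensorflow as tf

prompt = tf.ragged.constant([[1, 2, 3], [4, 5]])

row_lengths = prompt.row_lengths()                  # int64 by default
index = tf.range(tf.shape(prompt.to_tensor())[1])   # int32 by default

# Comparing int64 with int32 tensors raises an error, so cast first.
min_length = tf.cast(tf.reduce_min(row_lengths), "int32")
mask = index < min_length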

mattdangerw (Member Author) commented:

> One thing we may want to do is compare performance before and after this PR, since at head we don't have recompilation either.

Yeah, I have been doing some rough benchmarking while developing this branch. Here's a rough breakdown.

Tested with batch_size=2, max_length=256, num_trials=25, so ~12,800 tokens generated, on a 3090 GPU.

  • On master branch: 22s
  • This branch: 17s
  • Huggingface with XLA: 14s

I have some ideas for cutting our performance gap down further, but this is a definite speedup over master (moving the cache seeding into the compiled function is important).

mattdangerw (Member Author) commented:

Comments addressed.

chenmoneygithub (Contributor) left a comment:

Thanks, approved!

@mattdangerw mattdangerw merged commit c74f9da into keras-team:master Mar 21, 2023
kanpuriyanawab pushed a commit to kanpuriyanawab/keras-nlp that referenced this pull request Mar 26, 2023
* Move compilation to the task model

This demos a fix for keras-team#779
by moving the compilation up to the causal_lm model, where we can most
easily control the conditions for recompilation.

This has a few advantages:

 - We only need to tokenize once.
 - All forward passes on the model, including cache seeding, can live
   in the compiled function.
 - We expose compilation in a similar way to the `keras.Model` train step.
   There is an overridable make_generate_function (similar to
   make_train_function) and an accessible `model.generate_function`
   property on the model.

* Fix beam search for new interface

* Fix tests and docstrings

* Minor fixups

* Remove from_logits; clarify logits vs probabilities

* Readability fixes

* Minor fixes

* Fix test failures

* Address comments

* Address comments