Move generate compilation to the task model #804
Conversation
One major alternative we could consider is moving the sampler to the `compile()` method:

```python
gpt_lm = keras_nlp.models.GPT2CausalLM.from_preset(...)
# First call compiles generate with default sampler.
gpt_lm.generate(prompt, length)
# Recompile by passing a new sampler to `compile()`.
gpt_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=10))
# Next call will remake the generate function saved on the model.
gpt_lm.generate(prompt, length)
```

Edit: we are going with this approach.
@mattdangerw Thanks Matt! I took a quick pass, and at a high level this looks okay to me. I will need to dig into the details more to think about different use cases, will do it next week after I get back, thanks!
keras_nlp/samplers/top_k_sampler.py (Outdated)

```diff
-    def get_next_token(self, next_token_probs):
+    def get_next_token(self, probs):
```
Generic feedback -- use fully spelled out argument names for anything that isn't a super well established convention. Probabilities / logits here (which one is it?)
This is a good comment. I wonder if the thing to do is to remove the `from_logits=True` argument in our samplers. Then things get really obvious: you are either working with logits or probabilities, and never both. Sampling is so complex I think anything we can do to remove cognitive load is worth it.

This would also have the advantage of making it really easy to add a `temperature` argument to all our samplers, which scales the logits pre-softmax (and tightens or loosens the distribution).
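For illustration, a temperature knob on a top-k style sampler could look something like this. This is a hypothetical sketch, not current keras-nlp API; the `temperature` argument and `sample_next_token` name are my assumptions:

```python
import tensorflow as tf

def sample_next_token(logits, temperature=1.0, k=5, seed=None):
    # Hypothetical sketch. Temperature < 1.0 sharpens the distribution,
    # > 1.0 flattens it. Inputs are always logits, never probabilities,
    # so there is no `from_logits` ambiguity to document.
    scaled_logits = logits / temperature
    top_k_logits, top_k_indices = tf.math.top_k(scaled_logits, k=k)
    # `tf.random.categorical` accepts unnormalized logits directly.
    sample = tf.random.categorical(top_k_logits, num_samples=1, seed=seed)
    return tf.gather(top_k_indices, sample, batch_dims=1)
```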
Added a commit trying this out.
This demos a fix for keras-team#779 by moving the compilation up to the causal_lm model, where we can most easily control the conditions for recompilation. This has a few advantages:
- We only need to tokenize once.
- All forward passes on the model, including cache seeding, can live in the compiled function.
- We expose compilation in a similar way to the `keras.Model` train step. There is an overridable `make_generate_function` (similar to `make_train_function`) and an accessible `model.generate_function` property on the model.
Thanks! Finally took a full pass and dropped some comments. Overall looks good, and beam search is impressively clean! My two main comments are:

- The `state` variable has lots of freedom, and it seems the main usage is still the cache.
- We are forcing users to clear `self.generate_function` via `compile()`, which is a strong contract (see the sketch below). As a comparison, I can adjust `model.optimizer` by directly setting the field, or I can use a custom training loop as an alternative. But `generate()` is tightly coupled with `compile()`.
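To make the coupling concrete, here is a rough sketch of the contract as described in the PR. The class and method bodies are illustrative stand-ins, not the keras-nlp implementation:

```python
import tensorflow as tf
from tensorflow import keras

class ToyCausalLM(keras.Model):
    """Illustrative stand-in for the task model, not the keras-nlp code."""

    def compile(self, *args, sampler="top_k", **kwargs):
        super().compile(*args, **kwargs)
        self.sampler = sampler
        # Recompiling clears the traced generate function, forcing a
        # retrace with the new sampler on the next `generate()` call.
        self.generate_function = None

    def generate_step(self, prompt):
        # Stand-in for the real compiled loop (seed the cache, then
        # sample token by token with `self.sampler`).
        return prompt

    def make_generate_function(self):
        # Mirrors `make_train_function`: trace once, reuse until recompile.
        if getattr(self, "generate_function", None) is None:
            self.generate_function = tf.function(self.generate_step)
        return self.generate_function

    def generate(self, prompt):
        return self.make_generate_function()(prompt)
```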
```python
        prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
            tensor will be iteratively updated column by column with new sampled
            values.
        state: Optional. A tensor or nested structure of tensors that will be
```
Okay... I finally understand how this `state` works in general. Correct me if I am wrong, sharing my understanding: this `state` is a free variable, and how to use it is totally decided by users of samplers. Most often this is to hold the cache, but it is in fact a backdoor people can utilize in the `next` function or when they want to override the `__call__` method.

My thought is that "state" is too broad for users to learn, and I am also not clear on how to let it hold >=2 things, e.g., in contrastive search we need both the cache and the previous logits. As we cannot assume the existence of a cache, how do we retrieve the previous logits robustly?
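Since the docstring says `state` can be "a tensor or nested structure of tensors", one hypothetical way to hold two things is a dict. All names and shapes here are made up for illustration:

```python
import tensorflow as tf

batch_size, vocab_size = 2, 100

# Hypothetical: `state` as a dict carrying both a cache and the previous
# step's logits. The sampler just threads the whole structure through.
state = {
    "cache": tf.zeros((batch_size, 8, 64)),            # stand-in decoder cache
    "prev_logits": tf.zeros((batch_size, vocab_size)),
}

def next(prompt, state, index):
    # Stand-in forward pass; a real model would read and update the cache.
    logits = tf.random.uniform((batch_size, vocab_size))
    return logits, {"cache": state["cache"], "prev_logits": logits}
```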
Contrastive search is not covered by this case. IIUC contrastive search does not need a cache per se, or previous logits. What contrastive search needs is the hidden representation of every token, to compute a cosine similarity metric. So to make contrastive search work, we will need to update the signature of `next` to something like this:

```python
def next(prompt, state, index):
    return logits, dense, state
```

I can take a closer look at the contrastive search implementation to understand what is out there. I was thinking to cover it as a follow-up, but it is definitely worth some thought.
I also think we might want to pull apart the user journeys here.

Someone writing a sampler is writing a drop-in replacement for `sampler="beam"` or `sampler="greedy"`. `state` is not an important backdoor in this case; it's just more tensor variables that should be treated as loop variables. But from the perspective of the sampler writer, the model and its forward pass are a black box.

In contrast, from the perspective of a model writer, the sampler should be a black box, where the user complies with the `next` contract and doesn't worry about how the sampling actually happens. Here, `state` can be used to introduce arbitrary extra updatable state needed solely to compute the probability distribution of the next token. This could be the cache for a transformer decoder, the hidden state of a recurrent network, etc.
Overall, I am most interested in keeping the sampler simple and useful for our own purposes, but worth chatting through all this!
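To make the sampler-writer journey concrete, here is a simplified sketch of a greedy-style loop in which `state` is just an opaque loop variable threaded through `tf.while_loop`. This is hypothetical code, not the keras-nlp sampler:

```python
import tensorflow as tf

def greedy_sample(next, prompt, state, index, max_length):
    # `next` and `state` are black boxes to the sampler: `state` is simply
    # one more loop variable carried through the while loop unchanged.
    def cond(prompt, state, index):
        return index < max_length

    def body(prompt, state, index):
        logits, state = next(prompt, state, index)
        next_token = tf.argmax(logits, axis=-1, output_type=prompt.dtype)
        # Write the sampled tokens into column `index` of the prompt.
        prompt = tf.transpose(
            tf.tensor_scatter_nd_update(
                tf.transpose(prompt), [[index]], [next_token]
            )
        )
        return prompt, state, index + 1

    prompt, state, index = tf.while_loop(cond, body, (prompt, state, index))
    return prompt
```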
Thanks! Took another pass; the functionality looks good to me, and I dropped some comments on style. One thing we may want to do is compare the performance before and after this PR, since at head we don't have recompilation either.
```python
        )
        # Pad ragged to dense tensors.
        padded_shape = (None, max_length)
        min_length = tf.cast(tf.reduce_min(prompt.row_lengths()), "int32")
```
Why do we need an explicit dtype `int32` here?
I moved this down into the sampler for now, but essentially by default `tf.shape`/`tf.range` use int32, and `tf.ragged` row lengths default to int64. This is slightly awkward, and we need to cast to make sure our index comparisons are all using the same type.
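A tiny standalone repro of the default-dtype mismatch (not from the PR, just illustrating the behavior):

```python
import tensorflow as tf

prompt = tf.ragged.constant([[1, 2, 3], [4, 5]])
row_lengths = prompt.row_lengths()   # dtype defaults to int64
index = tf.range(10)                 # dtype defaults to int32

# `index < row_lengths[0]` would fail with a dtype mismatch (int32 vs
# int64), so we cast the row lengths down before comparing.
min_length = tf.cast(tf.reduce_min(row_lengths), "int32")
mask = index < min_length
```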
Yeah, I have been doing some rough benchmarking throughout developing this branch. Here's a rough breakdown. Test with

I have some ideas on where we can keep cutting our performance gap down, but this is a definite speed-up over master (moving the cache seeding into the compiled function is important).
Comments addressed.
Thanks, approved!
* Move compilation to the task model

  This demos a fix for keras-team#779 by moving the compilation up to the causal_lm model, where we can most easily control the conditions for recompilation. This has a few advantages:
  - We only need to tokenize once.
  - All forward passes on the model, including cache seeding, can live in the compiled function.
  - We expose compilation in a similar way to the `keras.Model` train step. There is an overridable `make_generate_function` (similar to `make_train_function`) and an accessible `model.generate_function` property on the model.
* Fix beam search for new interface
* Fix tests and docstrings
* Minor fixups
* Remove from_logits; clarify logits vs probabilities
* Readability fixes
* Minor fixes
* Fix test failures
* Address comments
* Address comments
This demos a fix for #779 by moving the compilation up to the causal_lm model, where we can most easily control the conditions for recompilation.

This has a few advantages:
- We only need to tokenize once.
- All forward passes on the model, including cache seeding, can live in the compiled function.
- We expose compilation in a similar way to the `keras.Model` train step. There is an overridable `make_generate_function` (similar to `make_train_function`) and an accessible `model.generate_function` property on the model.

Demo -> https://colab.research.google.com/gist/mattdangerw/ea205181ef56d1d95860e8b3f4a9db4d/generate-compile-demo.ipynb