Returns All Beams from Beam Search Utility #776
Some initial comments
keras_nlp/utils/text_generation.py
Outdated
@@ -218,6 +218,7 @@ def beam_search(
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
    return_all_beams_and_probs=False,
We can probably just leave the old version as is, right @chenmoneygithub? This will eventually be deprecated in favor of the new sampler API.
keras_nlp/samplers/beam_sampler.py
Outdated
@@ -85,7 +85,13 @@ def get_next_token(self, next_token_probs):
        pass

    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
        self,
Let's just call this return_all_beams.
I think rather than updating the sample() API, you should add return_all_beams as an init arg, and add self.return_all_beams to the config.
I also don't think this will work quite yet. The sampler API will generally be called via __call__, which will expect only a single output from sample. I will think more about how we can best handle this.
Hey!
Sure, I'll make the changes. Should I make them now, regardless of the sampler API call issue, or should we leave this to be resolved later?
I think we may need to pause on this for a second, though I will draft a PR shortly that should allow us to have more dynamic return types from sampler layers. Will keep you posted!
I think this should be quite doable after #804 lands.
@mattdangerw Since #804 has landed, should I work on this issue further now? Are there any new things you would like me to take care of and handle? Thanks!
Yeah! Let's do it.
Now that beam search just overrides call, it should be very easy to return a different output based on a boolean parameter.
Let's take in return_all_beams in __init__, add it to the config for the model, and use that parameter to control the return type during __call__.
We should also make sure to add a unit test for this. Let us know when this is ready for review!
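For illustration, the shape of that change could look roughly like the sketch below. This is not the keras_nlp implementation: ToyBeamSampler, toy_probability_fn, and the num_beams argument are hypothetical names made up for the example, and the beam search itself is a simplified pure-Python version. Only the pattern is the point: return_all_beams is an init arg, it is stored in the config, and __call__ switches its return type on it.

```python
import math


class ToyBeamSampler:
    """Hypothetical stand-in for a beam sampler; not the keras_nlp class."""

    def __init__(self, num_beams, return_all_beams=False):
        self.num_beams = num_beams
        self.return_all_beams = return_all_beams

    def get_config(self):
        # Keep the flag in the config so the sampler can be rebuilt from it.
        return {
            "num_beams": self.num_beams,
            "return_all_beams": self.return_all_beams,
        }

    def __call__(self, prompt, token_probability_fn, num_steps):
        # Each beam is a (token list, cumulative log-probability) pair.
        beams = [(list(prompt), 0.0)]
        for _ in range(num_steps):
            candidates = []
            for tokens, score in beams:
                probs = token_probability_fn(tokens)
                for token_id, p in enumerate(probs):
                    if p > 0.0:
                        candidates.append(
                            (tokens + [token_id], score + math.log(p))
                        )
            # Keep only the highest-scoring num_beams candidates.
            candidates.sort(key=lambda c: c[1], reverse=True)
            beams = candidates[: self.num_beams]
        if self.return_all_beams:
            # Return every surviving beam along with its log-probability.
            sequences = [tokens for tokens, _ in beams]
            log_probs = [score for _, score in beams]
            return sequences, log_probs
        # Default behavior: return only the best beam.
        return beams[0][0]


def toy_probability_fn(tokens):
    # Assumed toy model: a fixed distribution over a 3-token vocabulary.
    return [0.1, 0.6, 0.3]
```

A unit test along the lines requested would then cover both return types and check that the flag survives a get_config() round trip.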
Sounds like this can be closed in favor of #908
Resolves #770