Returns All Beams from Beam Search Utility #776

Closed
wants to merge 5 commits into from

Conversation

TheAthleticCoder
Contributor

Resolves #770

Member

mattdangerw left a comment

Some initial comments

@@ -218,6 +218,7 @@ def beam_search(
     from_logits=False,
     end_token_id=None,
     pad_token_id=0,
+    return_all_beams_and_probs=False,
Member

We can probably just leave the old version as is, right @chenmoneygithub? This will eventually be deprecated in favor of the new sampler API.

@@ -85,7 +85,13 @@ def get_next_token(self, next_token_probs):
         pass

     def sample(
-        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
+        self,
Member

Let's just call this return_all_beams.

I think rather than updating the sample() API, you should add return_all_beams as an init arg, and add self.return_all_beams to the config.

I also don't think this will work quite yet. The sampler API will generally be called via __call__, which will expect only a single output from sample. I will think more about how we can best handle this.
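
As a rough illustration of the init-arg-plus-config suggestion above, a minimal sketch might look like the following. `ToySampler` and `ToyBeamSampler` are hypothetical stand-ins written for this example, not the actual sampler classes in the library; only the `return_all_beams` name comes from this thread.

```python
class ToySampler:
    """Hypothetical base class standing in for a sampler layer."""

    def __init__(self, end_token_id=None, pad_token_id=0):
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id

    def get_config(self):
        # Everything needed to rebuild the object goes into the config.
        return {
            "end_token_id": self.end_token_id,
            "pad_token_id": self.pad_token_id,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class ToyBeamSampler(ToySampler):
    """Hypothetical beam sampler that takes the flag in __init__."""

    def __init__(self, num_beams, return_all_beams=False, **kwargs):
        super().__init__(**kwargs)
        self.num_beams = num_beams
        self.return_all_beams = return_all_beams

    def get_config(self):
        # Extend the parent config so the flag survives a round trip.
        config = super().get_config()
        config.update(
            {
                "num_beams": self.num_beams,
                "return_all_beams": self.return_all_beams,
            }
        )
        return config


sampler = ToyBeamSampler(num_beams=3, return_all_beams=True)
restored = ToyBeamSampler.from_config(sampler.get_config())
```

Keeping the flag on the object rather than on `sample()` means the method signature stays the same for every sampler, and the setting is preserved when the object is serialized and rebuilt.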

Contributor Author

Hey!
Sure, I shall make the changes. Should I make them regardless of the sampler API call or will we have to leave this issue to be resolved later?

Member

I think we may need to pause on this for a second, though I will draft a PR shortly that should allow us to have more dynamic return types from sampler layers. Will keep you posted!

Member

I think this should be quite doable after #804 lands.

Contributor Author

@mattdangerw Since #804 has landed, should I work on this issue further now? Are there any new things you would like me to take care of and handle? Thanks!

Member

Yeah! Let's do it.

Now that beam search just overrides call, it should be very easy to return a different type based on a boolean parameter.

Let's take in return_all_beams in __init__, add it to the config for the model, and use that parameter to control the return type during __call__.

We should also make sure to add a unit test for this. Let us know when this is ready for review!
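
To make the plan above concrete, here is a small pure-Python sketch of a beam search whose `__call__` return type depends on a boolean set in `__init__`, followed by the kind of checks a unit test could make. `ToyBeamSearch` and `fixed_probs` are hypothetical names invented for this example, and returning a `(sequences, scores)` pair when the flag is set is an assumption; the thread does not specify the exact return structure.

```python
import math


class ToyBeamSearch:
    """Hypothetical beam search; not the library's implementation."""

    def __init__(self, num_beams, return_all_beams=False):
        self.num_beams = num_beams
        self.return_all_beams = return_all_beams

    def __call__(self, prompt, next_token_probs, num_steps):
        # Each beam is a (token list, cumulative log-probability) pair.
        beams = [(list(prompt), 0.0)]
        for _ in range(num_steps):
            candidates = []
            for tokens, score in beams:
                for token, prob in enumerate(next_token_probs(tokens)):
                    if prob > 0.0:
                        candidates.append(
                            (tokens + [token], score + math.log(prob))
                        )
            # Keep only the highest-scoring num_beams candidates.
            candidates.sort(key=lambda beam: beam[1], reverse=True)
            beams = candidates[: self.num_beams]
        if self.return_all_beams:
            # Assumed return structure: all beams plus their scores.
            return [t for t, _ in beams], [s for _, s in beams]
        # Default: only the single best sequence.
        return beams[0][0]


def fixed_probs(tokens):
    # Hypothetical "model": the same distribution over 3 tokens every step.
    return [0.1, 0.6, 0.3]


best = ToyBeamSearch(num_beams=2)([0], fixed_probs, num_steps=2)
all_beams, scores = ToyBeamSearch(num_beams=2, return_all_beams=True)(
    [0], fixed_probs, num_steps=2
)

# Checks a unit test might make for the two return types.
assert best == all_beams[0]
assert len(all_beams) == 2 and len(scores) == 2
assert scores[0] >= scores[1]
```

A test along these lines covers both branches: the default call still returns one sequence, and the flagged call returns every beam, ordered from best to worst.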

@mattdangerw
Member

Sounds like this can be closed in favor of #908

Development

Successfully merging this pull request may close these issues.

Return All Beams from Beam Search Utility
3 participants