
Returning all Beams and Probs and adding a Testing Unit #908


Merged
13 commits merged into keras-team:master on Mar 29, 2023

Conversation

TheAthleticCoder (Contributor):

Attempting to resolve #770 and #776.

Modifications were made to return all beams and their scores. A unit test is also included.

TheAthleticCoder (Contributor Author):

@mattdangerw
Done with the requested changes. I had to change the names of a few variables in the BeamSampler function. Do let me know if I need to rename them to something else 🤔

Also added a unit test function. The tests are quite simple: I just check the dimensionality of all the returned prompts and the log_probs. Are there any more tests you would like me to add?
Thanks!

mattdangerw (Member) left a comment:

Thanks! A few comments.

@@ -77,7 +79,13 @@ def __call__(
index=0,
mask=None,
end_token_id=None,
return_all_beams=None,
mattdangerw (Member):

Do we need this in both places? I think we can probably just take this in at init time.

Is there a workflow that would need this at call time?
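
A minimal sketch of what taking the flag at init time could look like (the signature details here are assumed for illustration, not the PR's final code):

class BeamSampler(Sampler):
    def __init__(self, num_beams=5, return_all_beams=False, **kwargs):
        super().__init__(**kwargs)
        self.num_beams = num_beams
        # Fixed once at construction, so __call__ no longer needs the argument.
        self.return_all_beams = return_all_beams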

TheAthleticCoder (Contributor Author):

Fixed in the latest commit

):
if return_all_beams is None:
mattdangerw (Member):

If we remove the call time argument, we can remove this whole if/else block too

@@ -65,9 +65,11 @@ def next(prompt, state, index):
def __init__(
mattdangerw (Member):

Let's make sure to document this above!

TheAthleticCoder (Contributor Author):

Sure!

top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)

if return_all_beams:
mattdangerw (Member):

I wonder if instead we should do something like this...

  • By default, return the top beam's output sequences with shape (batch_size, length), as we do currently.
  • If return_all_beams==True, return an (outputs, log_probs) tuple, with shapes (batch_size, num_beams, length) and (batch_size, num_beams) respectively, where each beam is ordered so the most likely is first.

This would add a little complexity to the implementation (we would probably need to do an argsort and gather for the return_all_beams branch), but it would make the return type more useful and less redundant to the end user.

Want the top beams? That's outputs[:, 0, :]. Want the second to top? That's outputs[:, 1, :].
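
A rough sketch of the argsort-and-gather ordering described above (variable names follow the excerpts in this thread; this is not the exact merged code):

# Order beams so the most likely comes first, then gather prompts in that order.
sorted_indices = tf.argsort(all_log_probs, axis=-1, direction="DESCENDING")
sorted_log_probs = tf.gather(all_log_probs, sorted_indices, axis=1, batch_dims=1)
sorted_prompts = tf.gather(all_prompts, sorted_indices, axis=1, batch_dims=1)
# Shapes: (batch_size, num_beams, length) and (batch_size, num_beams).
return sorted_prompts, sorted_log_probs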

TheAthleticCoder (Contributor Author):

Sounds good! One concern is how to handle beams that have the same probability. It would be good to ensure the output order stays the same no matter how many times we run generation over the same input sequence.

mattdangerw (Member):

I think argmax will break ties by going with whatever came first, so I think we could just use it without adding any randomness to the process.
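
A quick way to sanity-check that tie-breaking (worth confirming on the TF version in use, since I am not certain this behavior is documented as a guarantee):

import tensorflow as tf

log_probs = tf.constant([[0.5, 0.5, 0.1]])
# In practice argmax picks the first of the tied entries, so repeated runs select the same beam.
print(tf.math.argmax(log_probs, axis=-1))  # expected: [0]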

TheAthleticCoder (Contributor Author):

@mattdangerw Made the requested changes to the BeamSampler() and to the unit tests. Also edited the documentation. Do let me know if any changes need to be made. Thanks!

chenmoneygithub (Contributor) left a comment:

Thanks for the PR! Overall looks good to me, left some comments on style.


Call Args:
{{call_args}}

Examples:
Example 1:
chenmoneygithub (Contributor):

Instead of Example 1, 2, 3..., let's document what each example does. For example, you can use "Only return the beam of largest accumulated probability" here, and "Return all beams and their probability" in the next example.

top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis]
prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1)
return tf.squeeze(prompt, axis=1)
all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams(
chenmoneygithub (Contributor):

Now the line is a bit long, so we can break it into 2 lines:

all_prompts = unflatten_beams(prompt)
all_log_probs = unflatten_beams(log_probs)

all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams(
log_probs
)
top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
chenmoneygithub (Contributor):

These 2 lines are only useful in the "if not self.return_all_beams:" branch, so we can fold them into that branch to slightly improve performance.
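
A minimal sketch of the suggested fold (variable names taken from the excerpts above; not the exact merged code):

if not self.return_all_beams:
    # Only computed when a single top beam is returned.
    top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
    prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)
    return tf.squeeze(prompt, axis=1)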

self.assertEqual(output[0].shape, (self.batch_size, 5, self.length))
self.assertEqual(output[1].shape, (self.batch_size, 5))
self.assertTrue(tf.reduce_all(output[1][:, 1:] <= output[1][:, :-1]))
self.assertEqual(
chenmoneygithub (Contributor):

Let's also test self.join_as_string(output[0][:, 0, :]) == ["sequentially"] since we are testing returning all beams.

state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
output = self.sampler_all_beams(
chenmoneygithub (Contributor):

Let's use explicit names, sorted_prompts and sorted_log_probs, here so that it's clearer to readers what we are testing.
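
A hypothetical sketch of that renaming in the test (the shapes follow the assertions quoted earlier in this review):

sorted_prompts, sorted_log_probs = output
self.assertEqual(sorted_prompts.shape, (self.batch_size, 5, self.length))
self.assertEqual(sorted_log_probs.shape, (self.batch_size, 5))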

TheAthleticCoder (Contributor Author):

@chenmoneygithub I have made the required changes. Do let me know what else needs to be done. Thanks!

chenmoneygithub (Contributor) left a comment:

Thanks! Only one minor comment.


Call Args:
{{call_args}}

Examples:
1. Return only the beam with the highest accumulated probability.
chenmoneygithub (Contributor):

We can remove the number here.

TheAthleticCoder (Contributor Author):

Hey @chenmoneygithub,
Made the required changes. Do let me know if there are any other changes we should make. Thanks!

chenmoneygithub merged commit 58c7e1d into keras-team:master on Mar 29, 2023.

Successfully merging this pull request may close these issues.

Return All Beams from Beam Search Utility