Returning all Beams and Probs and adding a Testing Unit #908
Conversation
@mattdangerw I also added a unit test function. The tests are quite simple: they just check the dimensionality of all the returned outputs.
Thanks! A few comments.
keras_nlp/samplers/beam_sampler.py (Outdated)
@@ -77,7 +79,13 @@ def __call__(
        index=0,
        mask=None,
        end_token_id=None,
        return_all_beams=None,
Do we need this in both places? I think we can probably just take this in at init time.
Is there a workflow that would need this at call time?
Fixed in the latest commit
keras_nlp/samplers/beam_sampler.py (Outdated)
    ):
        if return_all_beams is None:
If we remove the call time argument, we can remove this whole if/else block too
@@ -65,9 +65,11 @@ def next(prompt, state, index):
    def __init__(
Let's make sure to document this above!
Sure!
keras_nlp/samplers/beam_sampler.py (Outdated)
        top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
        prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)

        if return_all_beams:
I wonder if instead we should do something like this...

- By default, return the top beam output sequences with shape `(batch_size, length)` (as we do currently).
- If `return_all_beams==True`, return an `(outputs, log_probs)` tuple, with shapes `(batch_size, num_beams, length)` and `(batch_size, num_beams)` respectively, where the beams are ordered so the most likely is first.

This would add a little complexity to the implementation (we would probably need to do an `argsort` and `gather` for the `return_all_beams` branch), but it would make the return type more useful and less redundant to the end user.

Want the top beam? That's `outputs[:, 0, :]`. Want the second to top? That's `outputs[:, 1, :]`.
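For illustration, here is a minimal sketch of that `argsort` + `gather` idea. The tensor names (`all_prompts`, `all_log_probs`) follow the diff in this thread, while the helper name `sort_beams` is made up for the example and is not part of the actual implementation:

```python
import tensorflow as tf


def sort_beams(all_prompts, all_log_probs):
    # all_prompts: (batch_size, num_beams, length)
    # all_log_probs: (batch_size, num_beams)
    # Indices that order each example's beams by descending log-probability.
    beam_order = tf.argsort(all_log_probs, axis=-1, direction="DESCENDING")
    # Gather with batch_dims=1 so prompts and log-probs stay aligned per example.
    sorted_log_probs = tf.gather(all_log_probs, beam_order, axis=1, batch_dims=1)
    sorted_prompts = tf.gather(all_prompts, beam_order, axis=1, batch_dims=1)
    return sorted_prompts, sorted_log_probs
```

With this ordering, `sorted_prompts[:, 0, :]` is the top beam for every example in the batch.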
Sounds good! One concern is how to handle beams that have the same probability; it would be good to ensure the output order stays the same no matter how many times we generate over the same input sequence.
I think argmax will break ties by taking whichever came first, so we could just use it without adding any randomness to the process.
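As a side note, `tf.argsort` also takes a `stable` flag that keeps tied entries in their original order, which makes the sorted-beams output deterministic across runs; a tiny illustration with made-up values:

```python
import tensorflow as tf

# Beams 1 and 2 tie at log-prob 0.7. With stable=True they keep their
# original relative order, so repeated runs sort identically.
log_probs = tf.constant([[0.1, 0.7, 0.7]])
beam_order = tf.argsort(log_probs, axis=-1, direction="DESCENDING", stable=True)
print(beam_order.numpy())  # [[1 2 0]]
```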
@mattdangerw Made the requested changes to the …
Thanks for the PR! Overall looks good to me, left some comments on style.
keras_nlp/samplers/beam_sampler.py (Outdated)
    Call Args:
        {{call_args}}

    Examples:

    Example 1:
Instead of Example 1, 2, 3..., let's document what each example does. For example, you can use "Only return the beam of largest accumulated probability" here, and "Return all beams and their probabilities" in the next example (a rough sketch follows below).
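A rough sketch of how the two documented examples could be distinguished; `return_all_beams` is the argument under discussion in this PR, so treat this as an illustration rather than the final docstring:

```python
import keras_nlp

# Only return the beam with the largest accumulated probability.
sampler = keras_nlp.samplers.BeamSampler(num_beams=5)

# Return all beams and their probabilities.
sampler_all_beams = keras_nlp.samplers.BeamSampler(num_beams=5, return_all_beams=True)
```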
keras_nlp/samplers/beam_sampler.py (Outdated)
        top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis]
        prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1)
        return tf.squeeze(prompt, axis=1)
        all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams(
Now the line is a bit long, so we can break it down to 2 lines:
all_prompts = unflatten_beams(prompt)
all_log_probs = unflatten_beams(log_probs)
keras_nlp/samplers/beam_sampler.py (Outdated)
        all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams(
            log_probs
        )
        top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
These two lines are only useful inside the `if not self.return_all_beams:` branch, so we can fold them below it to slightly improve performance (see the sketch after this comment).
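A minimal sketch of the folded version, using the names from the diff above; the wrapper function `finish` is hypothetical and only frames the branch for illustration:

```python
import tensorflow as tf


def finish(all_prompts, all_log_probs, return_all_beams):
    # all_prompts: (batch_size, num_beams, length)
    # all_log_probs: (batch_size, num_beams)
    if not return_all_beams:
        # Only computed when a single top sequence is requested.
        top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
        prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)
        return tf.squeeze(prompt, axis=1)
    return all_prompts, all_log_probs
```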
        self.assertEqual(output[0].shape, (self.batch_size, 5, self.length))
        self.assertEqual(output[1].shape, (self.batch_size, 5))
        self.assertTrue(tf.reduce_all(output[1][:, 1:] <= output[1][:, :-1]))
        self.assertEqual(
Let's also test `self.join_as_string(output[0][:, 0, :]) == ["sequentially"]`, since we are testing returning all beams (see the sketch below).
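Written as an assertion, and assuming the `join_as_string` helper and the `output` tuple from the test above, that could look something like:

```python
# The first beam should decode to the most likely string, "sequentially".
self.assertEqual(self.join_as_string(output[0][:, 0, :]), ["sequentially"])
```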
        state_chars = list("sequentially")
        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
        output = self.sampler_all_beams(
Let's use explicit names, `sorted_prompts` and `sorted_log_probs`, here so that it's clearer to readers what we are testing (e.g. as in the sketch below).
@chenmoneygithub I have made the required changes. Do let me know what else needs to be done. Thanks!
Thanks! Only one minor comment.
keras_nlp/samplers/beam_sampler.py (Outdated)
    Call Args:
        {{call_args}}

    Examples:

    1. Return only the beam with the highest accumulated probability.
We can remove the number here.
Hey @chenmoneygithub,
Made the required changes. Do let me know if there are any other changes we should make. Thanks!
Attempting to resolve #770 and #776.
Modifications were made to return all beams and their scores. A unit test is also included.