decode_n_tokens clean up #1532
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1532
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3816bca with merge base 701d826.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -535,7 +535,6 @@ def decode_n_tokens(
     attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
     **sampling_kwargs,
 ):
-    new_tokens, new_probs = [], []
```
Not really used. This function yields individual tokens and probabilities one at a time; it never returns the accumulated lists.
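As a minimal sketch of the point above (all names here are illustrative stand-ins, not code from the PR): because the function is a generator, callers collect the tokens themselves, so a list accumulated inside the generator body is dead code.

```python
# Hypothetical sketch of a yield-based token decoder; the real
# decode_n_tokens samples from a model, this just yields integers.
def decode_n_tokens_sketch(num_new_tokens):
    for i in range(num_new_tokens):
        next_token, next_prob = i, None  # stand-ins for model sampling
        # Tokens are yielded one at a time; no internal list is needed.
        yield next_token, next_prob

# Callers accumulate the results themselves, which is why a
# `new_tokens` list inside the generator would go unused:
generated = [tok for tok, _ in decode_n_tokens_sketch(3)]
print(generated)  # [0, 1, 2]
```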
```diff
@@ -553,12 +552,10 @@ def decode_n_tokens(
         **sampling_kwargs,
     )
     input_pos += 1
-    new_tokens.append(next_token.clone())
-    callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-    if need_probs or next_prob is None:
```
This was backwards. It should be `if not need_probs or next_prob is None:`. Otherwise you are saying: if you need the probabilities, you get None, and if you don't need them, you get them.
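To make the inversion concrete, here is a small truth-table check (pure illustration; `old_check`/`fixed_check` are hypothetical names mirroring the conditions in the diff):

```python
def old_check(need_probs, next_prob):
    # Pre-fix condition from the diff
    return need_probs or next_prob is None

def fixed_check(need_probs, next_prob):
    # Corrected condition
    return (not need_probs) or next_prob is None

prob = 0.9  # stand-in for an available probability value

# When probabilities are needed and available, the old check still fired:
print(old_check(True, prob), fixed_check(True, prob))    # True False
# When they are not needed, the old check refused to fire:
print(old_check(False, prob), fixed_check(False, prob))  # False True
```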
```diff
@@ -788,7 +784,6 @@ def generate(
         input_pos = input_pos + num_added
         next_token = next_tokens[-1]
     else:
-        generated_tokens = []
```
Not used. Generated tokens are appended after the call to `decode_n_tokens`, but then nothing happens with them.
```diff
-    callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-    if need_probs or next_prob is None:
+    callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+    if not need_probs or next_prob is None:
```
Everything else is just cleaning up unused code. `not need_probs` is the only real change, and in the non-speculative path `need_probs` is always false, so the old check is effectively just `if next_prob is None`:

torchchat/generate.py, line 799 in 701d826:
`need_probs=False,`
I think we should drop the check instead of negating here, so it becomes easier to rip spec decoding out completely. The returned prob doesn't get used either way:

torchchat/generate.py, line 792 in 701d826:
`for generated_token, _ in self.decode_n_tokens(`
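A sketch of that reasoning (not code from the PR): assuming the linked call site always passes `need_probs=False` on the non-speculative path, the pre-fix condition collapses to testing `next_prob` alone.

```python
need_probs = False  # as at the linked non-speculative call site

def pre_fix_check(next_prob):
    # The pre-fix condition, with need_probs pinned to False:
    return need_probs or next_prob is None

def reduced_check(next_prob):
    return next_prob is None

# The two agree for any next_prob, so the need_probs term can be dropped:
for p in (None, 0.5):
    assert pre_fix_check(p) == reduced_check(p)
```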
```diff
-    if not need_probs or next_prob is None:
+    if next_prob is None:
```
Honestly this is a nit we can just merge
Deletes unused Python lists.