decode_n_tokens clean up #1532
Changes from all commits — torchchat/generate.py

```diff
@@ -535,7 +535,6 @@ def decode_n_tokens(
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
-        new_tokens, new_probs = [], []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
```
```diff
@@ -553,12 +552,10 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-            if need_probs or next_prob is None:
+            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            if not need_probs or next_prob is None:
                 yield out_token, None
             else:
-                new_probs.append(next_prob.clone())
                 yield out_token, next_prob.clone()
             cur_token = next_token
```

Review comment (on the old `if need_probs or next_prob is None:`): This was backwards. It should be `if not need_probs or next_prob is None:`.

Review comment: Everything else is just cleaning up unused code.

Reply (referencing torchchat/torchchat/generate.py Line 799 in 701d826): I think we should drop the check instead of negating it here, so it becomes easier to rip spec decoding out completely. The returned prob doesn't get used either way (torchchat/torchchat/generate.py Line 792 in 701d826).

Reply: Honestly, this is a nit; we can just merge.
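To make the cleaned-up shape concrete, here is a minimal, self-contained sketch of the yield-based pattern this hunk leaves behind. It is not the torchchat implementation: `fake_decode_one_token` is a hypothetical stand-in for `self.decode_one_token`, the signature is heavily simplified, and it yields `next_token` directly where the real code yields `out_token`.

```python
import torch

def fake_decode_one_token(cur_token, need_probs):
    # Hypothetical stand-in for self.decode_one_token: returns (next_token, next_prob).
    next_token = cur_token + 1
    next_prob = torch.rand(8) if need_probs else None
    return next_token, next_prob

def decode_n_tokens(cur_token, num_new_tokens, need_probs=False,
                    callback=lambda tok, done_generating: None):
    # Yield each (token, prob) pair as it is produced; no new_tokens/new_probs lists.
    for _i in range(num_new_tokens - 1):
        next_token, next_prob = fake_decode_one_token(cur_token, need_probs)
        callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
        if not need_probs or next_prob is None:
            # Corrected condition: probabilities were not requested (or not produced).
            yield next_token, None
        else:
            yield next_token, next_prob.clone()
        cur_token = next_token

# The caller, not the generator, decides whether to accumulate anything:
tokens = [tok for tok, _ in decode_n_tokens(torch.tensor([0]), 5)]
```

The point of the change is that the generator never needs its own lists; whatever consumes it can collect what it needs.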
```diff
@@ -585,7 +582,6 @@ def decode_n_tokens(
                 dtype=cur_token.dtype,
                 device=cur_token.device,
             )
-            new_tokens.append(eos_token.clone())
             eos_token, next_prob = self.decode_one_token(
                 model,
                 eos_token.view(1, -1),
```
```diff
@@ -788,7 +784,6 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
         else:
-            generated_tokens = []
```

Review comment (on the removed `generated_tokens = []`): Not used. Generated tokens are appended after the call to …

```diff
             for generated_token, _ in self.decode_n_tokens(
                 model,
                 next_token,
```
```diff
@@ -806,7 +801,6 @@ def generate(
                 attention_backend=attention_backend,
                 **sampling_kwargs,
             ):
-                generated_tokens.append(generated_token.view(-1))
                 yield generated_token, None

         generate_stats = {
```
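On the `generate` side, the same pattern applies. A hedged sketch (reusing the toy `decode_n_tokens` above, not the real method) of forwarding tokens without a local `generated_tokens` list:

```python
def generate(prompt_token, num_new_tokens):
    # Re-yield each token from the inner generator; no generated_tokens list here.
    for generated_token, _ in decode_n_tokens(prompt_token, num_new_tokens):
        yield generated_token, None
```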
Review comment: Not really used. This function yields individual tokens and probabilities, not the arrays.
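Since both functions yield individual `(token, prob)` pairs, any accumulation lives at the call site. A small usage sketch continuing the toy code above (a hypothetical caller, not torchchat's):

```python
import torch

collected = []
for tok, _ in generate(torch.tensor([0]), 6):
    collected.append(tok.view(-1))  # the caller appends, mirroring the removed in-function lists
print(torch.cat(collected))         # tensor([1, 2, 3, 4, 5])
```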