diff --git a/torchchat/generate.py b/torchchat/generate.py
index 8555b85bd..53d9d8f8c 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -535,7 +535,6 @@ def decode_n_tokens(
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
-        new_tokens, new_probs = [], []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -553,12 +552,10 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-            if need_probs or next_prob is None:
+            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            if not need_probs or next_prob is None:
                 yield out_token, None
             else:
-                new_probs.append(next_prob.clone())
                 yield out_token, next_prob.clone()
             cur_token = next_token
 
@@ -585,7 +582,6 @@ def decode_n_tokens(
                 dtype=cur_token.dtype,
                 device=cur_token.device,
             )
-            new_tokens.append(eos_token.clone())
             eos_token, next_prob = self.decode_one_token(
                 model,
                 eos_token.view(1, -1),
@@ -788,7 +784,6 @@ def generate(
                 input_pos = input_pos + num_added
                 next_token = next_tokens[-1]
             else:
-                generated_tokens = []
                 for generated_token, _ in self.decode_n_tokens(
                     model,
                     next_token,
@@ -806,7 +801,6 @@ def generate(
                     attention_backend=attention_backend,
                     **sampling_kwargs,
                 ):
-                    generated_tokens.append(generated_token.view(-1))
                     yield generated_token, None
 
         generate_stats = {
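
Taken together, these hunks drop the `new_tokens`, `new_probs`, and `generated_tokens` accumulators: every token is already yielded to the caller, so buffering a second copy inside the generator only grows memory with sequence length. The second hunk also flips the probability guard to `not need_probs`, so `next_prob` is yielded exactly when probabilities were requested. Below is a minimal sketch of the streaming pattern the diff moves to; `decode_n_tokens_stream` and its internals are hypothetical stand-ins for illustration, not the torchchat API.

```python
# Sketch only: a generator that yields (token, prob) pairs without keeping
# per-step history, mirroring the accumulator removal in the diff above.
# decode_n_tokens_stream and its body are hypothetical, not torchchat code.
from typing import Iterator, Optional, Tuple

import torch


def decode_n_tokens_stream(
    num_new_tokens: int, need_probs: bool
) -> Iterator[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Yield tokens one at a time; nothing is accumulated inside the generator."""
    for _ in range(num_new_tokens):
        token = torch.randint(0, 32000, (1,))          # stand-in for decode_one_token
        prob = torch.rand(1) if need_probs else None   # prob only when requested
        yield token, prob


# The caller decides whether to retain tokens; since the generator holds no
# history of its own, peak memory stays flat over long generations.
collected = [tok.view(-1) for tok, _ in decode_n_tokens_stream(8, need_probs=False)]
print(torch.cat(collected))
```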