[TPU] Fix tpu structured decoding in mixed batches #24458
Conversation
Signed-off-by: Chenyaaang <[email protected]>
Code Review
This pull request fixes an indexing issue for structured decoding on TPUs when mixed batches of guided and non-guided requests are present. The change correctly maps the compact `grammar_bitmask` to the requests in the batch. However, I've identified a critical issue with how this function handles chunked prefills (sub-batching), which can lead to incorrect behavior.
+ sorted_struct_requests = sorted(
+     scheduler_output.structured_output_request_ids.items(),
+     key=lambda item: item[1])
+ cumulative_mask_idx = 0
+ for req_id, _ in sorted_struct_requests:
+     if req_id not in self.input_batch.req_id_to_index:
+         continue
      batch_index = self.input_batch.req_id_to_index[req_id]
-     struct_out_indices.append(batch_index)
-     mask_indices.append(mask_index)
- self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
-     grammar_bitmask[mask_indices])
- # It's not guaranteed that all requests in this batch require
- # structured output, so create a bool tensor to represent
- # the requests that need structured output.
- struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
- self.require_structured_out_cpu[struct_out_indices] = True
+     self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
+         grammar_bitmask[cumulative_mask_idx])
+     # It's not guaranteed that all requests in this batch require
+     # structured output, so create a bool tensor to represent
+     # the requests that need structured output.
+     self.require_structured_out_cpu[batch_index] = True
+     cumulative_mask_idx += 1
This implementation does not correctly handle chunked prefills (sub-batching). The `execute_model` method processes requests in chunks, but this function seems to assume it's always processing the first chunk.

The `batch_index` is an absolute index within the full `self.input_batch`, but the `logits` tensor and the returned bitmask tensors are relative to the current sub-batch. When processing a sub-batch that is not the first one (i.e., when `start_index > 0` in `execute_model`), `batch_index` will be larger than the sub-batch size. However, the function returns `self.grammar_bitmask_cpu[:num_reqs]`, which is a slice from the beginning of the buffer. This will result in an incorrect (likely all-zero) bitmask for requests in subsequent chunks, leading to incorrect structured decoding.

To fix this, the function needs to be aware of the current sub-batch's `start_index` and `end_index` and use relative indexing for the output tensors.
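A minimal sketch of that suggestion (the free-function form, the `start_index` parameter, and the argument names are assumptions for illustration; the real method lives on the model runner and works on `self.` buffers):

```python
import torch

def prepare_structured_decoding_input(logits, grammar_bitmask,
                                      structured_output_request_ids,
                                      req_id_to_index,
                                      grammar_bitmask_cpu,
                                      require_structured_out_cpu,
                                      start_index):
    """Sub-batch-aware sketch: `start_index` is the offset of the current
    chunk within the full input batch; `logits` covers only the chunk."""
    num_reqs, _ = logits.shape
    end_index = start_index + num_reqs

    # Reset the pre-allocated buffers before filling them.
    grammar_bitmask_cpu.zero_()
    require_structured_out_cpu.fill_(False)

    sorted_struct_requests = sorted(structured_output_request_ids.items(),
                                    key=lambda item: item[1])
    cumulative_mask_idx = 0
    for req_id, _ in sorted_struct_requests:
        if req_id not in req_id_to_index:
            continue
        batch_index = req_id_to_index[req_id]
        # Each guided request in the batch owns one row of the compact
        # bitmask, whether or not it falls inside the current chunk.
        mask_idx = cumulative_mask_idx
        cumulative_mask_idx += 1
        if not (start_index <= batch_index < end_index):
            continue  # belongs to another chunk
        # Index relative to the chunk so the [:num_reqs] slice returned
        # below lines up with `logits`.
        rel_index = batch_index - start_index
        grammar_bitmask_cpu[rel_index] = torch.from_numpy(
            grammar_bitmask[mask_idx])
        require_structured_out_cpu[rel_index] = True

    return (grammar_bitmask_cpu[:num_reqs].to(logits.device),
            require_structured_out_cpu[:num_reqs].to(logits.device))
```

The key design point is that the cumulative index into the compact bitmask still advances for every guided request in the batch, while the write only happens for requests that fall inside the current chunk.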
LGTM, thanks for the fix!
[TPU] Fix structured decoding in mixed batches. This should fix the TPU v1 test part 2 soft fail.

When there are guided and non-guided requests in the same batch, `scheduler_output.grammar_bitmask` contains entries only for the guided requests. Sort the guided requests and use a cumulative index into the compact bitmask to avoid an index-out-of-bounds error.
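As a toy illustration of why the cumulative index is needed (a self-contained sketch; the request ids, batch layout, and mask width are made up):

```python
import numpy as np
import torch

# Mixed batch: requests at batch indices 0 and 3 are guided, indices
# 1 and 2 are not. The scheduler's compact bitmask therefore has only
# 2 rows, one per guided request.
structured_output_request_ids = {"req-a": 0, "req-b": 1}  # req_id -> mask order
req_id_to_index = {"req-a": 0, "plain-1": 1, "plain-2": 2, "req-b": 3}
vocab_mask_words = 4  # vocab_size / 32, toy value
grammar_bitmask = np.arange(2 * vocab_mask_words,
                            dtype=np.int32).reshape(2, vocab_mask_words)

num_reqs = 4
grammar_bitmask_cpu = torch.zeros(num_reqs, vocab_mask_words,
                                  dtype=torch.int32)
require_structured_out_cpu = torch.zeros(num_reqs, dtype=torch.bool)

# Indexing the compact bitmask with the batch index would try to read
# grammar_bitmask[3] for "req-b" -> index out of bounds. Instead, walk
# the guided requests in mask order and use a cumulative index.
sorted_struct_requests = sorted(structured_output_request_ids.items(),
                                key=lambda item: item[1])
cumulative_mask_idx = 0
for req_id, _ in sorted_struct_requests:
    if req_id not in req_id_to_index:
        continue
    batch_index = req_id_to_index[req_id]
    grammar_bitmask_cpu[batch_index] = torch.from_numpy(
        grammar_bitmask[cumulative_mask_idx])
    require_structured_out_cpu[batch_index] = True
    cumulative_mask_idx += 1

print(require_structured_out_cpu)  # tensor([ True, False, False,  True])
```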