
Commit 4a67d6b

fix update_step_context.
1 parent 3633de8 commit 4a67d6b

1 file changed (+7, -7 lines)

lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py

Lines changed: 7 additions & 7 deletions
@@ -49,10 +49,10 @@ def update_step_context(cls, step_context):
         is_unpaged_prefill = False
         q_start_loc = torch.cat((torch.tensor([0], device=device),
                                  step_context.q_seqlens.cumsum(0))).int()
-        q_seqlens_cpu = step_context.q_seqlens.cpu()
-        kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-        max_q_seq_len = torch.max(q_seqlens_cpu).item()
-        max_kv_seq_len = torch.max(kv_seqlens_cpu).item()
+        q_seqlens = step_context.q_seqlens.int()
+        kv_seqlens = step_context.kv_seqlens.int()
+        max_q_seq_len = torch.max(q_seqlens).item()
+        max_kv_seq_len = torch.max(kv_seqlens).item()
 
         if not step_context.is_decoding:
             is_unpaged_prefill = \
@@ -95,10 +95,10 @@ def update_step_context(cls, step_context):
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
-            step_context.block_offsets,
+            step_context.block_offsets.int(),
             q_start_loc=q_start_loc,
-            q_seqlens=q_seqlens_cpu,
-            kv_seqlens=kv_seqlens_cpu,
+            q_seqlens=q_seqlens,
+            kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
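
For context, the patch keeps q_seqlens and kv_seqlens as int32 tensors on their original device instead of copying them to the CPU, and casts block_offsets to int32 before building the attention metadata. The snippet below is a minimal standalone sketch of that before/after pattern, not code from the repository; the example tensor values and the device selection are made up for illustration.

import torch

# Assumed example inputs; in lmdeploy these come from step_context.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
q_seqlens_raw = torch.tensor([3, 5, 2], dtype=torch.int64, device=device)

# Old behaviour (removed by this commit): copy the lengths to the host.
q_seqlens_cpu = q_seqlens_raw.cpu()

# New behaviour: keep the tensor on `device`, only narrow the dtype to int32.
q_seqlens = q_seqlens_raw.int()
max_q_seq_len = torch.max(q_seqlens).item()

print(q_seqlens.device, q_seqlens.dtype, max_q_seq_len)  # stays on `device`, dtype torch.int32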

0 commit comments
