@@ -49,10 +49,10 @@ def update_step_context(cls, step_context):
         is_unpaged_prefill = False
         q_start_loc = torch.cat((torch.tensor([0], device=device),
                                  step_context.q_seqlens.cumsum(0))).int()
-        q_seqlens_cpu = step_context.q_seqlens.cpu()
-        kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-        max_q_seq_len = torch.max(q_seqlens_cpu).item()
-        max_kv_seq_len = torch.max(kv_seqlens_cpu).item()
+        q_seqlens = step_context.q_seqlens.int()
+        kv_seqlens = step_context.kv_seqlens.int()
+        max_q_seq_len = torch.max(q_seqlens).item()
+        max_kv_seq_len = torch.max(kv_seqlens).item()
 
         if not step_context.is_decoding:
             is_unpaged_prefill = \
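The first hunk keeps `q_seqlens` and `kv_seqlens` on the device as int32 tensors instead of copying them to the CPU. A minimal sketch of the two paths, assuming only `torch` and using placeholder tensors in place of the real `step_context` fields:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder per-request lengths; in the diff these come from step_context.
q_seqlens_src = torch.tensor([3, 5, 2], device=device)
kv_seqlens_src = torch.tensor([7, 9, 4], device=device)

# Old path: copy to host, then reduce over the CPU tensors.
q_seqlens_cpu = q_seqlens_src.cpu()
kv_seqlens_cpu = kv_seqlens_src.cpu()
max_q_seq_len_old = torch.max(q_seqlens_cpu).item()

# New path: stay on the device and only narrow to int32.
q_seqlens = q_seqlens_src.int()
kv_seqlens = kv_seqlens_src.int()
max_q_seq_len = torch.max(q_seqlens).item()
max_kv_seq_len = torch.max(kv_seqlens).item()

assert max_q_seq_len == max_q_seq_len_old  # same result, no explicit host copy
```

Note that `.item()` still synchronizes with the device to read back the scalar maxima; what the new path avoids is materializing full host copies of the length tensors, and it hands downstream consumers int32 tensors that stay on the device.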
@@ -95,10 +95,10 @@ def update_step_context(cls, step_context):
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
-            step_context.block_offsets,
+            step_context.block_offsets.int(),
             q_start_loc=q_start_loc,
-            q_seqlens=q_seqlens_cpu,
-            kv_seqlens=kv_seqlens_cpu,
+            q_seqlens=q_seqlens,
+            kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
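The second hunk threads the device-resident int32 tensors into the attention metadata and narrows `block_offsets` to int32 as well. A sketch under stated assumptions, with a hypothetical stand-in for the class returned by `cls.get_attention_metadata_cls()` (the real class also carries at least `kv_start_indices`, `block_size`, and `attention_mask` per the diff):

```python
from dataclasses import dataclass

import torch


@dataclass
class AttnMetadata:  # hypothetical stand-in, not the library's class
    is_decoding: bool
    block_offsets: torch.Tensor  # expected int32 after this change
    q_start_loc: torch.Tensor
    q_seqlens: torch.Tensor
    kv_seqlens: torch.Tensor


# Block-table rows per request; torch.tensor defaults to int64 here.
block_offsets = torch.tensor([[0, 1, 2], [3, 4, 5]])

attn_metadata = AttnMetadata(
    False,                # is_decoding
    block_offsets.int(),  # narrowed to int32, as in the diff
    q_start_loc=torch.tensor([0, 3, 8], dtype=torch.int32),
    q_seqlens=torch.tensor([3, 5], dtype=torch.int32),
    kv_seqlens=torch.tensor([7, 9], dtype=torch.int32),
)
assert attn_metadata.block_offsets.dtype == torch.int32
```

Casting once here keeps every index-like tensor in the metadata at a consistent int32 dtype, rather than leaving the cast to each attention kernel call site.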