Skip to content

Commit a3a257c

Browse files
committed
Fix out-of-bound writes for var-seq-len zero-length KVs
1 parent bcd918f commit a3a257c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

hopper/epilogue_fwd_sm90_tma.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
285285
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
286286
// Clear_OOB_K must be false since we don't want to write zeros to gmem
287287
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
288-
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
288+
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
289289
);
290290
static_assert(kBlockM <= NumMmaThreads);
291-
if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
291+
if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
292292
}
293293

294294
};

hopper/flash_fwd_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
123123
}
124124
int n_block_max = collective_mainloop.get_n_block_max(
125125
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
126-
if (Is_causal && n_block_max <= 0) {
126+
if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
127127
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
128128
scheduler.broadcast_next_work(work_tile_info);
129129
continue;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
169169
}
170170
int n_block_max = collective_mainloop.get_n_block_max(
171171
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
172-
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
172+
if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
173173
collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
174174
continue;
175175
}

0 commit comments

Comments
 (0)