
Commit 10e8c39

[Cute] Do manual f32->f16x2 conversion for fwd_sm90
1 parent 312bb9b commit 10e8c39

4 files changed: +73 -19 lines

flash_attn/cute/blackwell_helpers.py

Lines changed: 5 additions & 10 deletions
@@ -308,15 +308,10 @@ def gemm_ptx_partial(
     smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo)
     smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi)
 
-    if cutlass.const_expr(not is_ts):
-        offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
-                    for k in range(cute.size(tCrA.shape[2]))]
-    else:
-        offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
-                    for k in range(cute.size(tCrA.shape[2]))]
+    tCrA_layout = tCrA.layout if cutlass.const_expr(not is_ts) else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
+    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
     offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
-    offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
-                for k in range(cute.size(tCrB.shape[2]))]
+    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
     offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]
 
     if cutlass.const_expr(not is_ts):
@@ -330,8 +325,8 @@ def gemm_ptx_partial(
         None,
         [
             # acc.iterator.toint().ir_value(),
-            cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
-            cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
+            cutlass.Int32(smem_desc_start_a_lo).ir_value(),
+            cutlass.Int32(smem_desc_start_b_lo).ir_value(),
             cutlass.Int32(not zero_init).ir_value(),
         ],
         "{\n\t"

flash_attn/cute/flash_fwd.py

Lines changed: 9 additions & 3 deletions
@@ -1637,7 +1637,11 @@ def scoremod_premask_fn(acc_S):
         softmax.online_softmax(acc_S, is_first=True)
         tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
         tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
-        tOrP.store(tOrP_acc.load().to(self.dtype))
+        # tOrP.store(tOrP_acc.load().to(self.dtype))
+        # the "to(self.dtype)" conversion fails to vectorize for block sizes other
+        # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
+        # 2 elements. So we just call ptx directly.
+        utils.cvt_f16(tOrP_acc, tOrP)
         if const_expr(not self.mma_pv_is_rs):
             tPrP = smem_thr_copy_P.retile(tOrP)
             cute.copy(smem_thr_copy_P, tPrP, tPsP)
@@ -1749,7 +1753,8 @@ def mma_one_n_block(
         # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
         tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
         tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
-        tOrP.store(tOrP_acc.load().to(self.dtype))
+        # tOrP.store(tOrP_acc.load().to(self.dtype))
+        utils.cvt_f16(tOrP_acc, tOrP)
         if const_expr(not self.mma_pv_is_rs):
             tPrP = smem_copy_params.smem_thr_copy_P.retile(mma_params.tOrP)
             cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
@@ -1817,7 +1822,8 @@ def mma_one_n_block_intrawg_overlap(
         pipeline_v.consumer_release(smem_pipe_read_v)
         tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
         tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
-        tOrP.store(tOrP_acc.load().to(self.dtype))
+        # tOrP.store(tOrP_acc.load().to(self.dtype))
+        utils.cvt_f16(tOrP_acc, tOrP)
         if const_expr(not self.mma_pv_is_rs):
             tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP)
             cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)

flash_attn/cute/interface.py

Lines changed: 4 additions & 3 deletions
@@ -133,9 +133,9 @@ def _flash_attn_fwd(
     assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x"
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-    # if compute_capability == 9:  # TODO: tune block size according to hdim
-    #     if not causal and not local:
-    #         n_block_size = 176
+    if compute_capability == 9:  # TODO: tune block size according to hdim
+        if not causal and not local:
+            n_block_size = 192
 
     compile_key = (
         dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None,
@@ -154,6 +154,7 @@ def _flash_attn_fwd(
         qhead_per_kvhead,
         is_causal=causal,
         is_local=local,
+        pack_gqa=False,
         m_block_size=m_block_size,
         n_block_size=n_block_size,
         # num_stages=1,

flash_attn/cute/utils.py

Lines changed: 55 additions & 3 deletions
@@ -257,9 +257,21 @@ def fmax_reduce(
     x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
 ) -> Float32:
     if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
-        if cutlass.const_expr(init_val is None):
-            init_val = -cutlass.Float32.inf
-        return x.reduce(cute.ReductionOp.MAX, init_val, 0)
+        # if cutlass.const_expr(init_val is None):
+        #     init_val = -cutlass.Float32.inf
+        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
+        res = cute.make_fragment(x.shape, Float32)
+        res.store(x)
+        local_max = [res[0], res[1], res[2], res[3]]
+        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+            local_max[0] = fmax(local_max[0], res[i + 0])
+            local_max[1] = fmax(local_max[1], res[i + 1])
+            local_max[2] = fmax(local_max[2], res[i + 2])
+            local_max[3] = fmax(local_max[3], res[i + 3])
+        local_max[0] = fmax(local_max[0], local_max[1])
+        local_max[2] = fmax(local_max[2], local_max[3])
+        local_max[0] = fmax(local_max[0], local_max[2])
+        return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val)
     else:
         # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
         # We instead force the 3-input max.
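The new fmax_reduce branch keeps four independent running maxima so consecutive max instructions have no dependency on each other, and only folds them together at the end. A host-side reference of the same reduction shape in plain Python (the helper name and the example values below are illustrative, not part of the kernel):

def fmax_reduce_ref(xs, init_val=None):
    # Mirrors the 4-accumulator tree above: four independent running maxima,
    # folded pairwise at the end. The kernel path implicitly assumes a length
    # that is a multiple of 4, so the reference asserts the same.
    assert len(xs) >= 4 and len(xs) % 4 == 0
    local_max = [xs[0], xs[1], xs[2], xs[3]]
    for i in range(4, len(xs), 4):
        for j in range(4):
            local_max[j] = max(local_max[j], xs[i + j])
    result = max(max(local_max[0], local_max[1]), max(local_max[2], local_max[3]))
    return result if init_val is None else max(result, init_val)

assert fmax_reduce_ref([3.0, -1.0, 7.5, 2.0, 0.5, 9.0, -4.0, 1.0]) == 9.0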
@@ -290,6 +302,18 @@ def fadd_reduce(
         if cutlass.const_expr(init_val is None):
             init_val = Float32.zero
         return x.reduce(cute.ReductionOp.ADD, init_val, 0)
+        # res = cute.make_fragment(x.shape, Float32)
+        # res.store(x)
+        # local_sum = [res[0], res[1], res[2], res[3]]
+        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+        #     local_sum[0] += res[i + 0]
+        #     local_sum[1] += res[i + 1]
+        #     local_sum[2] += res[i + 2]
+        #     local_sum[3] += res[i + 3]
+        # local_sum[0] += local_sum[1]
+        # local_sum[2] += local_sum[3]
+        # local_sum[0] += local_sum[2]
+        # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val
     else:
         res = cute.make_fragment(x.shape, Float32)
         res.store(x)
@@ -440,3 +464,31 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
     val += partial_sum
     # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
     return val
+
+
+@dsl_user_op
+def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32:
+    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
+    return cutlass.Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
+            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
+            "=r,f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@cute.jit
+def cvt_f16(src: cute.Tensor, dst: cute.Tensor):
+    assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
+    assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
+    assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16"
+    assert src.element_type is Float32, "src must be Float32"
+    dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
+    assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
+    for i in cutlass.range_constexpr(cute.size(dst_i32)):
+        dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
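For reference, this is what the packed word produced by cvt_f16x2_f32 (and stored by cvt_f16 through the Int32 recast) contains. The sketch below is host-side NumPy, not kernel code: the helper name cvt_f16x2_f32_ref and the sample values are illustrative, a little-endian host is assumed, and only the Float16 path is modeled (BFloat16 is analogous). Because the asm string uses the "$0, $2, $1" operand order, the first argument is converted into the low 16 bits of the result and the second into the high 16 bits.

import numpy as np

def cvt_f16x2_f32_ref(a: float, b: float) -> int:
    # Packed result as an unsigned 32-bit pattern: `a` in the low 16 bits,
    # `b` in the high 16 bits, each rounded to nearest-even like cvt.rn.
    lo = int(np.float16(a).view(np.uint16))
    hi = int(np.float16(b).view(np.uint16))
    return lo | (hi << 16)

# The packed word is exactly what a consecutive f16 pair looks like when the
# destination is recast to Int32 on a little-endian target, which is why cvt_f16
# can write one fully converted pair per loop iteration.
pair = np.array([1.5, -2.25], dtype=np.float16)
assert cvt_f16x2_f32_ref(1.5, -2.25) == int(pair.view(np.uint32)[0])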
