@@ -257,9 +257,21 @@ def fmax_reduce(
     x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
 ) -> Float32:
     if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
-        if cutlass.const_expr(init_val is None):
-            init_val = -cutlass.Float32.inf
-        return x.reduce(cute.ReductionOp.MAX, init_val, 0)
+        # if cutlass.const_expr(init_val is None):
+        #     init_val = -cutlass.Float32.inf
+        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
+        res = cute.make_fragment(x.shape, Float32)
+        res.store(x)
+        local_max = [res[0], res[1], res[2], res[3]]
+        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+            local_max[0] = fmax(local_max[0], res[i + 0])
+            local_max[1] = fmax(local_max[1], res[i + 1])
+            local_max[2] = fmax(local_max[2], res[i + 2])
+            local_max[3] = fmax(local_max[3], res[i + 3])
+        local_max[0] = fmax(local_max[0], local_max[1])
+        local_max[2] = fmax(local_max[2], local_max[3])
+        local_max[0] = fmax(local_max[0], local_max[2])
+        return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val)
     else:
         # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
         # We instead force the 3-input max.
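The code that actually forces the 3-input max sits further down in this hunk and is not shown here. As a rough, hypothetical sketch only (not the repo's fmax helper), a three-operand max could be emitted with the same llvm.inline_asm pattern that cvt_f16x2_f32 uses later in this diff, assuming the three-input max.f32 form available on sm_100+:

@dsl_user_op
def fmax3_sketch(a: float | Float32, b: float | Float32, c: float | Float32, *, loc=None, ip=None) -> Float32:
    # Hypothetical helper, not part of this commit; assumes arch >= 100 exposes
    # the three-operand max.f32 instruction.
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(v).ir_value(loc=loc, ip=ip) for v in (a, b, c)],
            "max.f32 $0, $1, $2, $3;",
            "=f,f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )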
@@ -290,6 +302,18 @@ def fadd_reduce(
         if cutlass.const_expr(init_val is None):
             init_val = Float32.zero
         return x.reduce(cute.ReductionOp.ADD, init_val, 0)
+        # res = cute.make_fragment(x.shape, Float32)
+        # res.store(x)
+        # local_sum = [res[0], res[1], res[2], res[3]]
+        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+        #     local_sum[0] += res[i + 0]
+        #     local_sum[1] += res[i + 1]
+        #     local_sum[2] += res[i + 2]
+        #     local_sum[3] += res[i + 3]
+        # local_sum[0] += local_sum[1]
+        # local_sum[2] += local_sum[3]
+        # local_sum[0] += local_sum[2]
+        # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val
     else:
         res = cute.make_fragment(x.shape, Float32)
         res.store(x)
@@ -440,3 +464,31 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
             val += partial_sum
     # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
     return val
+
+
+@dsl_user_op
+def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32:
+    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
+    return cutlass.Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
+            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
+            "=r,f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@cute.jit
+def cvt_f16(src: cute.Tensor, dst: cute.Tensor):
+    assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
+    assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
+    assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16"
+    assert src.element_type is Float32, "src must be Float32"
+    dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
+    assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
+    for i in cutlass.range_constexpr(cute.size(dst_i32)):
+        dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
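As a rough usage sketch (the fragment names and the size 8 below are illustrative, not part of this commit): inside a kernel body, cvt_f16 down-converts a Float32 register fragment to BFloat16 or Float16, packing two elements per cvt.rn.bf16x2.f32 / cvt.rn.f16x2.f32 instruction:

# Hedged sketch; assumes this runs inside a @cute.jit kernel where fragments live in registers.
acc = cute.make_fragment(8, Float32)              # hypothetical fp32 accumulator values
out = cute.make_fragment(8, cutlass.BFloat16)     # destination; element count must be even
cvt_f16(acc, out)                                 # dst is recast to Int32 and filled one packed pair at a time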