@@ -2567,6 +2567,31 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
         'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
 
+def test_max_num_imprecise_acc(device):
+    capability = torch.cuda.get_device_capability()
+    if capability != (9, 0):
+        return
+
+    @triton.jit
+    def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+               MAX_NUM_IMPRECISE_ACC: tl.constexpr):
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        off_k = tl.arange(0, BLOCK_K)
+        x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :])
+        y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :])
+        z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :])
+        z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC)
+        tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z)
+
+    M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64
+    x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device)
+    y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device)
+    z = torch.zeros((M, N), dtype=torch.float32, device=device)
+    h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps)
+    assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC)
+
+
 @pytest.mark.parametrize('in_dtype', ['float32'])
 def test_dot_mulbroadcastred(in_dtype, device):
     capability = torch.cuda.get_device_capability()
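
Note on the assert at the end of the new test (an editorial sketch, not part of the commit): with 32 * num_warps = 128 threads cooperating on the 128x128 output tile, each thread would own (M * N) // (32 * num_warps) = 128 accumulator elements, and max_num_imprecise_acc=64 presumably triggers a full-precision accumulation every 64 steps along K, i.e. K / MAX_NUM_IMPRECISE_ACC = 2 times per element, for an expected 128 * 2 = 256 add.f32 instructions in the PTX. A minimal recomputation under that assumed per-thread layout:

    # Editorial illustration of the expected "add.f32" count in the assert above;
    # the per-thread accumulator layout is an assumption, not taken from the commit.
    M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64
    elems_per_thread = (M * N) // (32 * num_warps)   # 128 accumulator elements per thread
    k_flushes = K // MAX_NUM_IMPRECISE_ACC           # 2 precise flushes along K
    assert elems_per_thread * k_flushes == 256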