@@ -48,14 +48,13 @@ def convert_to_cudnn_type(torch_type):
         raise ValueError("Unsupported tensor data type.")
 
 
-def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
+def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
     b, nheads, seqlen_q, headdim = q.shape
-    _, _, seqlen_k, _ = k.shape
-    assert v.shape == (b, nheads, seqlen_k, headdim)
+    _, nheads_kv, seqlen_k, _ = k.shape
+    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
     assert cudnn is not None, 'CUDNN is not available'
     q_gpu, k_gpu, v_gpu = q, k, v
-    o_gpu = torch.empty_like(q_gpu)
-    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
+    o_gpu, stats_gpu = o, stats
     graph_forward = cudnn.pygraph(
         io_data_type=convert_to_cudnn_type(q.dtype),
         intermediate_data_type=cudnn.data_type.FLOAT,
@@ -65,7 +64,7 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
     k_forward = graph_forward.tensor_like(k_gpu.detach())
     v_forward = graph_forward.tensor_like(v_gpu.detach())
 
-    seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
+    seqlens_reshaped = seqlens if varlen else None
     seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
     seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
 
@@ -193,8 +192,8 @@ def run_bwd(*args, **kwargs):
 # headdim = 64
 headdim = 256
 
-# for mode in ['fwd', 'bwd']:
-for mode in ['fwd']:
+for mode in ['fwd', 'bwd']:
+# for mode in ['bwd']:
     for headdim in [64, 128, 256]:
     # for headdim in [128]:
         for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
@@ -206,39 +205,46 @@ def run_bwd(*args, **kwargs):
             # seqlen = 512
             # nheads = 8
             # headdim = 128
+            # nheads = 16
+            # headdim = 128
             nheads_kv = nheads
+            # nheads_kv = 1
 
             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                               requires_grad=True)
             q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
             q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
             k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
             grad_t = grad.transpose(1, 2).contiguous()
+            o_t = torch.empty_like(q.transpose(1, 2))
+            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
 
             bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
 
             for causal in [False, True]:
             # for causal in [True]:
-                print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                 # For var-seq-len
                 lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                 cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 if headdim <= 128 and cudnn is not None:
-                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
-                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
+                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
+                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                 f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
+                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                 _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                 if mode == 'bwd':
                     ref_dv, v.grad = v.grad.clone(), None
                     ref_dk, k.grad = k.grad.clone(), None
                     ref_dq, q.grad = q.grad.clone(), None
                 # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if triton_attention is not None and nheads_kv == nheads:
                         if mode == 'fwd':
                             time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                             _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
@@ -255,22 +261,31 @@ def run_bwd(*args, **kwargs):
                     if mode == 'fwd':
                         _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                         _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        cudnn_sdpa_fwd()
+                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
+                        cudnn_sdpa_fwd_varlen()
+                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                     else:
                         cudnn_sdpa_fwd()
                         _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                         dq, dk, dv = cudnn_sdpa_bwd()
                         torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                        dq, dk, dv = cudnn_sdpa_bwd_varlen()
+                        torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                     # pytorch_profiler(cudnn_sdpa, backward=False)
-                if headdim == 128 or mode == 'fwd':
+
+                if headdim <= 128 or mode == 'fwd':
                     time.sleep(1)
                     _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                     q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                     k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                     v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                     time.sleep(1)
-                    _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                     if mode == 'bwd':
                         dv, v.grad = v.grad.clone(), None
                         dk, k.grad = k.grad.clone(), None
@@ -279,15 +294,21 @@ def run_bwd(*args, **kwargs):
                         torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
 
+                    bench_var_fn = bench_fn
+                    if mode == 'bwd':
+                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
+                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
+                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
+
                 # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
                 print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                         print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                     if cudnn is not None:
                         print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                         print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
-                if headdim == 128 or mode == 'fwd':
+                if headdim <= 128 or mode == 'fwd':
                     print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                     print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
 
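For reference, below is a minimal sketch (not part of the commit) of the calling pattern this change introduces, assuming the definitions already present in this benchmark script (cudnn_sdpa_setup, a CUDA device, and the cudnn frontend being importable): the caller now preallocates the output and softmax-stats buffers and, for the varlen path, reshapes the per-batch sequence lengths itself before passing them in. The dimensions are placeholder assumptions for illustration; the function and variable names are taken from the diff above.

import torch

# Example dimensions (assumptions for illustration only).
batch_size, nheads, nheads_kv, seqlen, headdim = 2, 16, 16, 4096, 128
device, dtype = 'cuda', torch.bfloat16

q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)

# Output and softmax-stats buffers are now owned by the caller
# (previously they were allocated inside cudnn_sdpa_setup).
o_t = torch.empty_like(q.transpose(1, 2))
stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=device)

# Fixed-length setup: o_t and stats are passed in explicitly.
cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2),
    o_t, stats, causal=True)

# Varlen setup: the (b, 1, 1, 1) seqlens reshape now happens at the call site.
lens = torch.full([batch_size], seqlen, dtype=torch.int32)
seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2),
    o_t, stats, causal=True, varlen=True, seqlens=seqlens_cudnn)

Passing o and stats in lets the same buffers be shared by the fixed-length and varlen graphs and makes the forward output available for the accuracy checks against ref_o that this commit adds.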