|
| 1 | +# Install the newest triton version with |
| 2 | +# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" |
| 3 | +import pickle |
| 4 | +import math |
| 5 | +import time |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | +import torch.nn.functional as F |
| 9 | + |
| 10 | +from einops import rearrange, repeat |
| 11 | + |
| 12 | +from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward |
| 13 | +from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined |
| 14 | + |
| 15 | +from flash_attn import flash_attn_qkvpacked_func |
| 16 | +from flash_attn_interface import flash_attn_func |
| 17 | + |
| 18 | +try: |
| 19 | + from triton_fused_attention import attention as attention_triton |
| 20 | +except ImportError: |
| 21 | + attention_triton = None |
| 22 | + |
| 23 | +try: |
| 24 | + import xformers.ops as xops |
| 25 | +except ImportError: |
| 26 | + xops = None |
| 27 | + |
| 28 | +try: |
| 29 | + import cudnn |
| 30 | +except ImportError: |
| 31 | + cudnn = None |
| 32 | + |
| 33 | + |
| 34 | +def convert_to_cudnn_type(torch_type): |
| 35 | + if torch_type == torch.float16: |
| 36 | + return cudnn.data_type.HALF |
| 37 | + elif torch_type == torch.bfloat16: |
| 38 | + return cudnn.data_type.BFLOAT16 |
| 39 | + elif torch_type == torch.float32: |
| 40 | + return cudnn.data_type.FLOAT |
| 41 | + elif torch_type == torch.int32: |
| 42 | + return cudnn.data_type.INT32 |
| 43 | + elif torch_type == torch.int64: |
| 44 | + return cudnn.data_type.INT64 |
| 45 | + elif torch_type == torch.float8_e4m3fn: |
| 46 | + return cudnn.data_type.FP8_E4M3 |
| 47 | + elif torch_type == torch.float8_e4m3fn: |
| 48 | + return cudnn.data_type.FP8_E5M2 |
| 49 | + else: |
| 50 | + raise ValueError("Unsupported tensor data type.") |
| 51 | + |
| 52 | +def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False): |
| 53 | + b, _, _, nheads, headdim = qkv.shape |
| 54 | + assert cudnn is not None, 'CUDNN is not available' |
| 55 | + o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) |
| 56 | + o_gpu_transposed = torch.as_strided( |
| 57 | + o_gpu, |
| 58 | + [b, nheads, seqlen_q, headdim], |
| 59 | + [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], |
| 60 | + ) |
| 61 | + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device) |
| 62 | + amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) |
| 63 | + amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) |
| 64 | + graph = cudnn.pygraph( |
| 65 | + io_data_type=convert_to_cudnn_type(qkv.dtype), |
| 66 | + intermediate_data_type=cudnn.data_type.FLOAT, |
| 67 | + compute_data_type=cudnn.data_type.FLOAT, |
| 68 | + ) |
| 69 | + new_q = torch.as_strided( |
| 70 | + qkv, |
| 71 | + [b, nheads, seqlen_q, headdim], |
| 72 | + [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], |
| 73 | + storage_offset=0, |
| 74 | + ) |
| 75 | + q = graph.tensor( |
| 76 | + name = "Q", |
| 77 | + dim = list(new_q.shape), |
| 78 | + stride = list(new_q.stride()), |
| 79 | + data_type=convert_to_cudnn_type(qkv.dtype) |
| 80 | + ) |
| 81 | + new_k = torch.as_strided( |
| 82 | + qkv, |
| 83 | + [b, nheads, seqlen_k, headdim], |
| 84 | + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], |
| 85 | + storage_offset=nheads * headdim, |
| 86 | + ) |
| 87 | + k = graph.tensor( |
| 88 | + name = "K", |
| 89 | + dim = list(new_k.shape), |
| 90 | + stride = list(new_k.stride()), |
| 91 | + data_type=convert_to_cudnn_type(qkv.dtype) |
| 92 | + ) |
| 93 | + new_v = torch.as_strided( |
| 94 | + qkv, |
| 95 | + [b, nheads, seqlen_k, headdim], |
| 96 | + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], |
| 97 | + storage_offset=nheads * headdim * 2, |
| 98 | + ) |
| 99 | + v = graph.tensor( |
| 100 | + name = "V", |
| 101 | + dim = list(new_v.shape), |
| 102 | + stride = list(new_v.stride()), |
| 103 | + data_type=convert_to_cudnn_type(qkv.dtype) |
| 104 | + ) |
| 105 | + |
| 106 | + def get_default_scale_tensor(): |
| 107 | + return graph.tensor( |
| 108 | + dim = [1, 1, 1, 1], |
| 109 | + stride = [1, 1, 1, 1], |
| 110 | + data_type=cudnn.data_type.FLOAT |
| 111 | + ) |
| 112 | + |
| 113 | + default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") |
| 114 | + descale_q = get_default_scale_tensor() |
| 115 | + descale_k = get_default_scale_tensor() |
| 116 | + descale_v = get_default_scale_tensor() |
| 117 | + descale_s = get_default_scale_tensor() |
| 118 | + scale_s = get_default_scale_tensor() |
| 119 | + scale_o = get_default_scale_tensor() |
| 120 | + |
| 121 | + o, _, amax_s, amax_o = graph.sdpa_fp8( |
| 122 | + q=q, |
| 123 | + k=k, |
| 124 | + v=v, |
| 125 | + descale_q=descale_q, |
| 126 | + descale_k=descale_k, |
| 127 | + descale_v=descale_v, |
| 128 | + descale_s=descale_s, |
| 129 | + scale_s=scale_s, |
| 130 | + scale_o=scale_o, |
| 131 | + is_inference=True, |
| 132 | + attn_scale=1.0 / math.sqrt(headdim), |
| 133 | + use_causal_mask=causal, |
| 134 | + name="sdpa", |
| 135 | + ) |
| 136 | + |
| 137 | + o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) |
| 138 | + |
| 139 | + amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) |
| 140 | + amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) |
| 141 | + # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) |
| 142 | + |
| 143 | + graph.validate() |
| 144 | + graph.build_operation_graph() |
| 145 | + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) |
| 146 | + graph.check_support() |
| 147 | + graph.build_plans() |
| 148 | + |
| 149 | + variant_pack = { |
| 150 | + q: new_q, |
| 151 | + k: new_k, |
| 152 | + v: new_v, |
| 153 | + descale_q: default_scale_gpu, |
| 154 | + descale_k: default_scale_gpu, |
| 155 | + descale_v: default_scale_gpu, |
| 156 | + descale_s: default_scale_gpu, |
| 157 | + scale_s: default_scale_gpu, |
| 158 | + scale_o: default_scale_gpu, |
| 159 | + o: o_gpu_transposed, |
| 160 | + amax_s: amax_s_gpu, |
| 161 | + amax_o: amax_o_gpu, |
| 162 | + } |
| 163 | + |
| 164 | + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) |
| 165 | + |
| 166 | + def run(*args, **kwargs): |
| 167 | + graph.execute(variant_pack, workspace) |
| 168 | + return o_gpu, amax_o_gpu |
| 169 | + |
| 170 | + return run |
| 171 | + |
| 172 | + |
| 173 | +def attention_pytorch(qkv, dropout_p=0.0, causal=True): |
| 174 | + """ |
| 175 | + Arguments: |
| 176 | + qkv: (batch_size, seqlen, 3, nheads, head_dim) |
| 177 | + dropout_p: float |
| 178 | + Output: |
| 179 | + output: (batch_size, seqlen, nheads, head_dim) |
| 180 | + """ |
| 181 | + batch_size, seqlen, _, nheads, d = qkv.shape |
| 182 | + q, k, v = qkv.unbind(dim=2) |
| 183 | + q = rearrange(q, 'b t h d -> (b h) t d') |
| 184 | + k = rearrange(k, 'b s h d -> (b h) d s') |
| 185 | + softmax_scale = 1.0 / math.sqrt(d) |
| 186 | + # Preallocate attn_weights for `baddbmm` |
| 187 | + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) |
| 188 | + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), |
| 189 | + '(b h) t s -> b h t s', h=nheads) |
| 190 | + if causal: |
| 191 | + # "triu_tril_cuda_template" not implemented for 'BFloat16' |
| 192 | + # So we have to construct the mask in float |
| 193 | + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) |
| 194 | + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) |
| 195 | + scores = scores + causal_mask.to(dtype=scores.dtype) |
| 196 | + attention = torch.softmax(scores, dim=-1) |
| 197 | + attention_drop = F.dropout(attention, dropout_p) |
| 198 | + output = torch.einsum('bhts,bshd->bthd', attention_drop , v) |
| 199 | + return output.to(dtype=qkv.dtype) |
| 200 | + |
| 201 | +def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): |
| 202 | + assert mode in ["fwd", "bwd", "fwd_bwd"] |
| 203 | + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) |
| 204 | + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) |
| 205 | + |
| 206 | +def efficiency(flop, time): |
| 207 | + return (flop / time / 10**12) if not math.isnan(time) else 0.0 |
| 208 | + |
| 209 | +def time_fwd(func, *args, **kwargs): |
| 210 | + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark |
| 211 | + time_f = benchmark_forward(func, *args, **kwargs) |
| 212 | + return time_f[1].mean |
| 213 | + |
| 214 | + |
| 215 | +torch.manual_seed(0) |
| 216 | + |
| 217 | +repeats = 30 |
| 218 | +device = 'cuda' |
| 219 | +# dtype = torch.float16 |
| 220 | +dtype = torch.float8_e4m3fn |
| 221 | + |
| 222 | +bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)] |
| 223 | +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] |
| 224 | +# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)] |
| 225 | +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)] |
| 226 | +causal_vals = [False, True] |
| 227 | +headdim_vals = [128] |
| 228 | +dim = 2048 |
| 229 | +# dim = 256 |
| 230 | +dropout_p = 0.0 |
| 231 | + |
| 232 | +methods = (["Pytorch", "Flash3", "cuDNN"] |
| 233 | + # + (["Triton"] if attention_triton is not None else []) |
| 234 | + # + (["xformers.c"] if xops is not None else []) |
| 235 | + # + (["xformers.f"] if xops is not None else []) |
| 236 | + ) |
| 237 | + |
| 238 | +time_f = {} |
| 239 | +time_b = {} |
| 240 | +time_f_b = {} |
| 241 | +speed_f = {} |
| 242 | +speed_b = {} |
| 243 | +speed_f_b = {} |
| 244 | +for causal in causal_vals: |
| 245 | + for headdim in headdim_vals: |
| 246 | + for batch_size, seqlen in bs_seqlen_vals: |
| 247 | + torch.cuda.empty_cache() |
| 248 | + config = (causal, headdim, batch_size, seqlen) |
| 249 | + nheads = dim // headdim |
| 250 | + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)] |
| 251 | + |
| 252 | + qkv = torch.stack([q, k, v], dim=2) |
| 253 | + qkv = qkv.to(torch.float16) |
| 254 | + f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False) |
| 255 | + time_f[config, "Pytorch"] = f |
| 256 | + res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) |
| 257 | + |
| 258 | + if attention_triton is not None: |
| 259 | + q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) |
| 260 | + k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) |
| 261 | + v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn) |
| 262 | + scale = 1 / math.sqrt(headdim) |
| 263 | + f = time_fwd( |
| 264 | + attention_triton, q_transposed, k_transposed, v_transposed, |
| 265 | + causal, scale, repeats=5, verbose=False, desc='Triton' |
| 266 | + ) |
| 267 | + f = time_fwd( |
| 268 | + attention_triton, q_transposed, k_transposed, v_transposed, |
| 269 | + causal, scale, repeats=repeats, verbose=False, desc='Triton' |
| 270 | + ) |
| 271 | + time_f[config, "Triton"] = f |
| 272 | + res = attention_triton( |
| 273 | + q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2), |
| 274 | + causal, scale |
| 275 | + ).half().transpose(1, 2) |
| 276 | + torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5) |
| 277 | + |
| 278 | + # out = torch.empty_like(q) |
| 279 | + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) |
| 280 | + f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False) |
| 281 | + |
| 282 | + # res = flash_attn_func(q, k, v, causal=causal) |
| 283 | + # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05) |
| 284 | + |
| 285 | + time_f[config, "Flash3"] = f |
| 286 | + |
| 287 | + if cudnn is not None: |
| 288 | + qkv_fp8 = qkv.to(dtype) |
| 289 | + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark |
| 290 | + f = time_fwd( |
| 291 | + cudnn_spda_setup( |
| 292 | + qkv_fp8, seqlen, seqlen, |
| 293 | + causal=causal |
| 294 | + ), |
| 295 | + repeats=repeats, verbose=False |
| 296 | + ) |
| 297 | + time_f[config, "cuDNN"] = f |
| 298 | + # res, amax_o = cudnn_spda_setup( |
| 299 | + # qkv_fp8, seqlen, seqlen, |
| 300 | + # causal=causal |
| 301 | + # )() |
| 302 | + # res = res.half() |
| 303 | + # TODO: CUDNN has numerics issues when |
| 304 | + # num_heads=16, dim=128, seq_len=1024, batch_size=2 |
| 305 | + # or larger sizes. |
| 306 | + # res_cpu = res.cpu().reshape(-1) |
| 307 | + # res_baseline_cpu = res_baseline.cpu().reshape(-1) |
| 308 | + # print(amax_o) |
| 309 | + # print(res) |
| 310 | + # print(res_baseline) |
| 311 | + # for i in range(len(res_cpu)): |
| 312 | + # item = res_cpu[i] |
| 313 | + # item_baseline = res_baseline_cpu[i] |
| 314 | + # if abs(item - item_baseline) > 0.5: |
| 315 | + # print(i) |
| 316 | + # print(item) |
| 317 | + # print(item_baseline) |
| 318 | + # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05) |
| 319 | + |
| 320 | + print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") |
| 321 | + for method in methods: |
| 322 | + speed_f[config, method] = efficiency( |
| 323 | + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), |
| 324 | + time_f[config, method] |
| 325 | + ) |
| 326 | + #print (time_f[config,method]) |
| 327 | + print( |
| 328 | + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, " |
| 329 | + ) |
| 330 | + |
| 331 | + |
| 332 | +# with open('flash3_attn_time.plk', 'wb') as fp: |
| 333 | +# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) |
0 commit comments