
Commit 5018ac6

Fp8 kernel with "in-kernel" transpose of V in producer (Dao-AILab#1100)
* base version
* restructure pipelines, add special fp8 epilogue
* add variants
* add fp8 causal and modify dynamic tile scheduler
* better causal schedule
* maintain two schedules for non causal and causal
* removing macros
* fix regression
* clean up unneeded methods and variants
* fix mistake with NumProducerThreads
* use seqlen traits
* add fp8 .cu files and benchmark script
* fix merge issue
* remove duplicate code
* fix regression with varseqlen
* move varseqlen init in constexpr
* fix test script
* more constexpr on varseqlen and add max offset
* add back test cases
1 parent c4b9015 commit 5018ac6

16 files changed (+1540 -102 lines)
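The title refers to the second GEMM of FP8 attention: roughly, Hopper's FP8 tensor-core MMAs want both operands contiguous along the contraction dimension, so the P·V product needs a V tile laid out with the sequence dimension contiguous, and this change moves that transpose into the producer phase of the kernel. Below is a rough PyTorch sketch of the layout change only; names and tile shapes are illustrative, and the real transpose operates on shared-memory tiles inside the CUDA kernel, not on global-memory tensors like this.

import torch

# Illustrative tile shapes only.
seqlen_k, headdim, block_m = 256, 128, 128
v_tile = torch.randn(seqlen_k, headdim)          # V tile as stored: headdim-contiguous
p_tile = torch.rand(block_m, seqlen_k)           # softmax'd score tile for one row block

# For O = P @ V the contraction dimension is seqlen_k, so the FP8 MMA wants a
# (headdim, seqlen_k) operand, i.e. V transposed so that seqlen_k is contiguous.
v_tile_t = v_tile.transpose(0, 1).contiguous()   # what the producer would write to smem

# The math is unchanged; only the in-memory layout of the V operand differs.
o_ref = p_tile @ v_tile
o_new = p_tile @ v_tile_t.transpose(0, 1)
assert torch.allclose(o_ref, o_new, atol=1e-5)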
Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func
from flash_attn_interface import flash_attn_func

try:
    from triton_fused_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None

try:
    import cudnn
except ImportError:
    cudnn = None

def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    elif torch_type == torch.float8_e4m3fn:
        return cudnn.data_type.FP8_E4M3
    elif torch_type == torch.float8_e5m2:
        return cudnn.data_type.FP8_E5M2
    else:
        raise ValueError("Unsupported tensor data type.")

def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
    b, _, _, nheads, headdim = qkv.shape
    assert cudnn is not None, 'CUDNN is not available'
    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
    o_gpu_transposed = torch.as_strided(
        o_gpu,
        [b, nheads, seqlen_q, headdim],
        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
    )
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(qkv.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    new_q = torch.as_strided(
        qkv,
        [b, nheads, seqlen_q, headdim],
        [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=0,
    )
    q = graph.tensor(
        name="Q",
        dim=list(new_q.shape),
        stride=list(new_q.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype),
    )
    new_k = torch.as_strided(
        qkv,
        [b, nheads, seqlen_k, headdim],
        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=nheads * headdim,
    )
    k = graph.tensor(
        name="K",
        dim=list(new_k.shape),
        stride=list(new_k.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype),
    )
    new_v = torch.as_strided(
        qkv,
        [b, nheads, seqlen_k, headdim],
        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=nheads * headdim * 2,
    )
    v = graph.tensor(
        name="V",
        dim=list(new_v.shape),
        stride=list(new_v.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype),
    )

    def get_default_scale_tensor():
        return graph.tensor(
            dim=[1, 1, 1, 1],
            stride=[1, 1, 1, 1],
            data_type=cudnn.data_type.FLOAT,
        )

    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
    descale_q = get_default_scale_tensor()
    descale_k = get_default_scale_tensor()
    descale_v = get_default_scale_tensor()
    descale_s = get_default_scale_tensor()
    scale_s = get_default_scale_tensor()
    scale_o = get_default_scale_tensor()

    o, _, amax_s, amax_o = graph.sdpa_fp8(
        q=q,
        k=k,
        v=v,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
        descale_s=descale_s,
        scale_s=scale_s,
        scale_o=scale_o,
        is_inference=True,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        name="sdpa",
    )

    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())

    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: new_q,
        k: new_k,
        v: new_v,
        descale_q: default_scale_gpu,
        descale_k: default_scale_gpu,
        descale_v: default_scale_gpu,
        descale_s: default_scale_gpu,
        scale_s: default_scale_gpu,
        scale_o: default_scale_gpu,
        o: o_gpu_transposed,
        amax_s: amax_s_gpu,
        amax_o: amax_o_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu, amax_o_gpu

    return run

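# Usage note (illustrative): the helper above builds and compiles the cuDNN graph once
# and returns a closure; the benchmark loop below then calls it as, e.g.,
#   run = cudnn_spda_setup(qkv_fp8, seqlen, seqlen, causal=causal)
#   out_fp8, amax_o = run()   # executes the prebuilt graph, returns o_gpu and amax_o_gpu
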
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0

def time_fwd(func, *args, **kwargs):
    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
    time_f = benchmark_forward(func, *args, **kwargs)
    return time_f[1].mean

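# Worked example of the accounting above (illustrative arithmetic, not a measurement):
# for batch=4, seqlen=4224, nheads=16, headdim=128 and causal=False,
#   flops(4, 4224, 128, 16, False) = 4 * 4 * 4224**2 * 16 * 128 ≈ 5.85e11 FLOPs,
# so a hypothetical 1.0 ms forward pass would report
#   efficiency(5.85e11, 1e-3) ≈ 585 TFLOPs/s.
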
torch.manual_seed(0)

repeats = 30
device = 'cuda'
# dtype = torch.float16
dtype = torch.float8_e4m3fn

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
causal_vals = [False, True]
headdim_vals = [128]
dim = 2048
# dim = 256
dropout_p = 0.0

methods = (["Pytorch", "Flash3", "cuDNN"]
           # + (["Triton"] if attention_triton is not None else [])
           # + (["xformers.c"] if xops is not None else [])
           # + (["xformers.f"] if xops is not None else [])
           )

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            torch.cuda.empty_cache()
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]

            qkv = torch.stack([q, k, v], dim=2)
            qkv = qkv.to(torch.float16)
            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
            time_f[config, "Pytorch"] = f
            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)

            if attention_triton is not None:
                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
                scale = 1 / math.sqrt(headdim)
                f = time_fwd(
                    attention_triton, q_transposed, k_transposed, v_transposed,
                    causal, scale, repeats=5, verbose=False, desc='Triton'
                )
                f = time_fwd(
                    attention_triton, q_transposed, k_transposed, v_transposed,
                    causal, scale, repeats=repeats, verbose=False, desc='Triton'
                )
                time_f[config, "Triton"] = f
                res = attention_triton(
                    q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
                    causal, scale
                ).half().transpose(1, 2)
                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)

            # out = torch.empty_like(q)
            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
            f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)

            # res = flash_attn_func(q, k, v, causal=causal)
            # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)

            time_f[config, "Flash3"] = f

            if cudnn is not None:
                qkv_fp8 = qkv.to(dtype)
                time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                f = time_fwd(
                    cudnn_spda_setup(
                        qkv_fp8, seqlen, seqlen,
                        causal=causal
                    ),
                    repeats=repeats, verbose=False
                )
                time_f[config, "cuDNN"] = f
                # res, amax_o = cudnn_spda_setup(
                #     qkv_fp8, seqlen, seqlen,
                #     causal=causal
                # )()
                # res = res.half()
                # TODO: CUDNN has numerics issues when
                # num_heads=16, dim=128, seq_len=1024, batch_size=2
                # or larger sizes.
                # res_cpu = res.cpu().reshape(-1)
                # res_baseline_cpu = res_baseline.cpu().reshape(-1)
                # print(amax_o)
                # print(res)
                # print(res_baseline)
                # for i in range(len(res_cpu)):
                #     item = res_cpu[i]
                #     item_baseline = res_baseline_cpu[i]
                #     if abs(item - item_baseline) > 0.5:
                #         print(i)
                #         print(item)
                #         print(item_baseline)
                # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)

            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                # print(time_f[config, method])
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
                )


# with open('flash3_attn_time.plk', 'wb') as fp:
#     pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
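
For context, a minimal standalone sketch of the FP8 forward call that this benchmark times, assuming a Hopper GPU and that flash_attn_interface.flash_attn_func accepts torch.float8_e4m3fn inputs as the script above does; shapes are illustrative and the tuple-vs-tensor return handling is defensive rather than definitive:

import torch
from flash_attn_interface import flash_attn_func

# Illustrative shapes; headdim=128 matches the configurations benchmarked above.
batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    .to(torch.float8_e4m3fn)
    for _ in range(3)
]
out = flash_attn_func(q, k, v, causal=False)
# Some versions return (out, softmax_lse); handle both defensively.
out = out[0] if isinstance(out, tuple) else out
print(out.shape)  # (batch, seqlen, nheads, headdim)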
