Skip to content

Commit ffc8682

Browse files
committed
Add benchmarking code for Alibi (from Sanghun Cho)
1 parent 204c3c6 commit ffc8682

File tree

1 file changed

+275
-0
lines changed

1 file changed

+275
-0
lines changed

benchmarks/benchmark_alibi.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# Copyright (c) 2024, Sanghun Cho, Tri Dao.
2+
3+
import pickle
4+
import math
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
9+
from einops import rearrange, repeat
10+
from flash_attn.layers.rotary import apply_rotary_emb
11+
12+
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
13+
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
14+
15+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
16+
17+
try:
18+
import xformers.ops as xops
19+
except ImportError:
20+
xops = None
21+
22+
23+
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
24+
assert rotary_dim % 2 == 0
25+
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
26+
cos = torch.cos(angle).to(dtype=dtype)
27+
sin = torch.sin(angle).to(dtype=dtype)
28+
return cos, sin
29+
30+
31+
def flash_rotary(q, k, v, cos, sin, causal=False):
32+
# corrected by @tridao comments
33+
q = apply_rotary_emb(
34+
q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
35+
)
36+
k = apply_rotary_emb(
37+
k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
38+
)
39+
40+
return flash_attn_func(q, k, v, causal=causal)
41+
42+
43+
def attn_bias_from_alibi_slopes(
44+
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
45+
):
46+
batch, nheads = slopes.shape
47+
device = slopes.device
48+
slopes = rearrange(slopes, "b h -> b h 1 1")
49+
if causal:
50+
return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
51+
else:
52+
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
53+
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
54+
sk = (
55+
seqlen_k
56+
if key_padding_mask is None
57+
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
58+
)
59+
sq = (
60+
seqlen_q
61+
if query_padding_mask is None
62+
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
63+
)
64+
relative_pos = torch.abs(row_idx + sk - sq - col_idx)
65+
return -slopes * relative_pos.to(dtype=slopes.dtype)
66+
67+
68+
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
69+
assert mode in ["fwd", "bwd", "fwd_bwd"]
70+
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
71+
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
72+
73+
74+
def efficiency(flop, time):
75+
return (flop / time / 10**12) if not math.isnan(time) else 0.0
76+
77+
78+
def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
79+
"""
80+
Arguments:
81+
q, k, v: (batch_size, seqlen, nheads, head_dim)
82+
dropout_p: float
83+
attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
84+
Output:
85+
output: (batch_size, seqlen, nheads, head_dim)
86+
"""
87+
batch_size, seqlen, nheads, d = q.shape
88+
q = rearrange(q, 'b t h d -> (b h) t d')
89+
k = rearrange(k, 'b s h d -> (b h) d s')
90+
softmax_scale = 1.0 / math.sqrt(d)
91+
# Preallocate attn_weights for `baddbmm`
92+
if attn_bias is not None:
93+
scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
94+
else:
95+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
96+
scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
97+
'(b h) t s -> b h t s', h=nheads)
98+
if causal:
99+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
100+
# So we have to construct the mask in float
101+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
102+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
103+
scores = scores + causal_mask.to(dtype=scores.dtype)
104+
attention = torch.softmax(scores, dim=-1)
105+
attention_drop = F.dropout(attention, dropout_p)
106+
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
107+
return output.to(dtype=q.dtype)
108+
109+
110+
def time_fwd_bwd(func, *args, **kwargs):
111+
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
112+
return time_f[1].mean, time_b[1].mean
113+
114+
115+
repeats = 30
116+
device = 'cuda'
117+
dtype = torch.float16
118+
119+
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
120+
causal_vals = [False, True]
121+
headdim_vals = [64, 128]
122+
dim = 2048
123+
dropout_p = 0.0
124+
125+
methods = (["fa2_alibi", "torch"]
126+
+ (["xformers"] if xops is not None else [])
127+
+ ["sdpa"]
128+
+ ["fa2_baseline"]
129+
+ ["fa2_rotary"])
130+
131+
time_f = {}
132+
time_b = {}
133+
time_f_b = {}
134+
speed_f = {}
135+
speed_b = {}
136+
speed_f_b = {}
137+
for causal in causal_vals:
138+
for headdim in headdim_vals:
139+
for batch_size, seqlen in bs_seqlen_vals:
140+
config = (causal, headdim, batch_size, seqlen)
141+
nheads = dim // headdim
142+
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
143+
requires_grad=True) for _ in range(3)]
144+
# alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
145+
alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
146+
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
147+
attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
148+
f, b = time_fwd_bwd(
149+
flash_attn_func,
150+
q, k, v,
151+
dropout_p,
152+
causal=causal,
153+
# alibi_slopes=alibi_slopes,
154+
alibi_slopes=None,
155+
repeats=repeats,
156+
verbose=False
157+
)
158+
time_f[config, "fa2_baseline"] = f
159+
time_b[config, "fa2_baseline"] = b
160+
161+
q = q.detach().requires_grad_(True)
162+
k = k.detach().requires_grad_(True)
163+
v = v.detach().requires_grad_(True)
164+
f, b = time_fwd_bwd(
165+
flash_attn_func,
166+
q, k, v,
167+
dropout_p,
168+
causal=causal,
169+
alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
170+
# alibi_slopes=None,
171+
repeats=repeats,
172+
verbose=False
173+
)
174+
time_f[config, "fa2_alibi"] = f
175+
time_b[config, "fa2_alibi"] = b
176+
177+
try:
178+
q = q.detach().requires_grad_(True)
179+
k = k.detach().requires_grad_(True)
180+
v = v.detach().requires_grad_(True)
181+
f, b = time_fwd_bwd(
182+
attention_pytorch,
183+
q, k, v,
184+
dropout_p,
185+
causal=causal,
186+
attn_bias=attn_bias,
187+
repeats=repeats,
188+
verbose=False
189+
)
190+
except: # Skip if OOM
191+
f, b = float('nan'), float('nan')
192+
time_f[config, "torch"] = f
193+
time_b[config, "torch"] = b
194+
195+
# F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
196+
with torch.backends.cuda.sdp_kernel(enable_flash=False):
197+
q_pt = q.detach().requires_grad_(True).transpose(1, 2)
198+
k_pt = k.detach().requires_grad_(True).transpose(1, 2)
199+
v_pt = v.detach().requires_grad_(True).transpose(1, 2)
200+
f, b = time_fwd_bwd(
201+
F.scaled_dot_product_attention,
202+
q_pt, k_pt, v_pt,
203+
attn_mask=attn_bias,
204+
dropout_p=dropout_p,
205+
is_causal=causal,
206+
repeats=repeats,
207+
verbose=False
208+
)
209+
time_f[config, "sdpa"] = f
210+
time_b[config, "sdpa"] = b
211+
212+
if xops is not None:
213+
q = q.detach().requires_grad_(True)
214+
k = k.detach().requires_grad_(True)
215+
v = v.detach().requires_grad_(True)
216+
if causal:
217+
attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
218+
# NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
219+
# `[email protected]` is not supported because:
220+
# attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
221+
# `cutlassB` is not supported because:
222+
# attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
223+
attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
224+
else:
225+
attn_bias_xops = attn_bias.to(dtype=q.dtype)
226+
f, b = time_fwd_bwd(
227+
xops.memory_efficient_attention,
228+
q, k, v,
229+
attn_bias_xops,
230+
dropout_p,
231+
repeats=repeats,
232+
verbose=False
233+
)
234+
time_f[config, "xformers"] = f
235+
time_b[config, "xformers"] = b
236+
237+
q = q.detach().requires_grad_(True)
238+
k = k.detach().requires_grad_(True)
239+
v = v.detach().requires_grad_(True)
240+
cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
241+
f, b = time_fwd_bwd(
242+
flash_rotary,
243+
q, k, v,
244+
cos, sin,
245+
causal,
246+
repeats=repeats,
247+
verbose=False
248+
)
249+
time_f[config, "fa2_rotary"] = f
250+
time_b[config, "fa2_rotary"] = b
251+
252+
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
253+
csv_output = ""
254+
csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
255+
for method in methods:
256+
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
257+
speed_f[config, method] = efficiency(
258+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
259+
time_f[config, method]
260+
)
261+
speed_b[config, method] = efficiency(
262+
flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
263+
time_b[config, method]
264+
)
265+
speed_f_b[config, method] = efficiency(
266+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
267+
time_f_b[config, method]
268+
)
269+
print(
270+
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
271+
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
272+
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
273+
)
274+
csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
275+
print(csv_output)

0 commit comments

Comments
 (0)