Commit d79f9b4

[CrossEntropy] Use online softmax to simplify implementation
1 parent 32792d3 commit d79f9b4

File tree

1 file changed: +27 −34 lines changed

flash_attn/ops/triton/cross_entropy.py

Lines changed: 27 additions & 34 deletions
@@ -34,34 +34,37 @@ def cross_entropy_fwd_kernel(
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
-    n_rows,
     logits_row_stride,  # strides
     BLOCK_SIZE: tl.constexpr,
     HAS_SMOOTHING: tl.constexpr,
     # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
     SPLIT: tl.constexpr,
 ):
     row_idx = tl.program_id(0)
-    col_block_idx = tl.program_id(1)
     logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
-    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    sum_logits = 0.0  # For smoothing
+    # Statistics for online softmax
+    m_i = -float("inf")
+    l_i = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        cols = col_offset + tl.arange(0, BLOCK_SIZE)
+        logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+            tl.float32
+        ) * logit_scale
+        if HAS_SMOOTHING:
+            sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+        m_i_new = tl.maximum(m_i, tl.max(logits))
+        l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+        m_i = m_i_new
+    lse = tl.log(l_i) + m_i
+    tl.store(lse_ptr + row_idx, lse)
     label_idx = tl.load(labels_ptr + row_idx)
-    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
-        tl.float32
-    ) * logit_scale
-    max_logits = tl.max(logits, 0)
-    if HAS_SMOOTHING:
-        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
-    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
-    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
     if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
         label_idx -= class_start_idx
-        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
-            n_cols, (col_block_idx + 1) * BLOCK_SIZE
-        ):
+        if label_idx >= 0 and label_idx < n_cols:
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
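
Aside, not part of the diff: the new loop computes each row's log-sum-exp in a single pass using the online-softmax recurrence, keeping a running max m_i and a running sum l_i of exponentials relative to that max, and rescaling l_i whenever the max grows. A minimal PyTorch sketch of the same recurrence, with illustrative names and an arbitrary block size:

import torch

def streaming_lse(logits: torch.Tensor, block_size: int = 4) -> torch.Tensor:
    # Running statistics, mirroring m_i / l_i in the kernel loop above.
    m_i = torch.tensor(float("-inf"))  # max seen so far
    l_i = torch.tensor(0.0)            # sum of exp(logit - m_i) over elements seen so far
    for start in range(0, logits.numel(), block_size):
        block = logits[start : start + block_size].float()
        m_new = torch.maximum(m_i, block.max())
        # Rescale the old sum to the new max, then add the new block's contribution.
        l_i = torch.exp(m_i - m_new) * l_i + torch.exp(block - m_new).sum()
        m_i = m_new
    return torch.log(l_i) + m_i

x = torch.randn(1000)
assert torch.allclose(streaming_lse(x), torch.logsumexp(x, dim=0), atol=1e-5)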
@@ -82,9 +85,9 @@ def cross_entropy_fwd_kernel(
             loss += z_loss
         else:
             z_loss = 0.0
-    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+    tl.store(loss_ptr + row_idx, loss)
     if not SPLIT:
-        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
+        tl.store(z_loss_ptr + row_idx, z_loss)


 @triton.heuristics(
@@ -161,27 +164,20 @@ def forward(

         if logits.stride(-1) != 1:
             logits = logits.contiguous()
-        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
-        MAX_BLOCK_SIZE = 64 * 1024
+        MAX_BLOCK_SIZE = 16 * 1024
         BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
         num_warps = (
             4
             if BLOCK_SIZE < 2048
             else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
         )
-        # We may split the lse computation across multiple blocks, then do a reduction
-        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
-        # where having just one thread block processing more than 64k elements is slow.
-        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
-        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
-        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
-        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
         with torch.cuda.device(logits.device.index):
-            cross_entropy_fwd_kernel[(n_rows, n_splits)](
+            cross_entropy_fwd_kernel[(n_rows,)](
                 losses,  # data ptrs
                 lse,
                 z_losses,
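
Aside, not part of the diff: with the block-size cap lowered to 16K and the 2D grid gone, a single program per row simply loops over the class dimension. For a GPT-2-sized vocabulary of 50257 classes (an illustrative value), the heuristics above give BLOCK_SIZE = 16384, num_warps = 16, and 4 loop iterations per row. A quick check of that arithmetic:

import triton

n_cols = 50257  # e.g. GPT-2 vocabulary size (illustrative)
MAX_BLOCK_SIZE = 16 * 1024
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
num_warps = (
    4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
)
n_iters = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
print(BLOCK_SIZE, num_warps, n_iters)  # 16384 16 4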
@@ -194,23 +190,19 @@ def forward(
                 total_classes,
                 class_start_idx,
                 n_cols,  # shapes
-                n_rows,
                 logits.stride(0),  # strides
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
                 num_warps=num_warps,
-                SPLIT=split,
+                SPLIT=world_size > 1,
             )

-        if split:
+        if world_size > 1:
             # If there's no smoothing, if labels are in the vocab of this partition, losses contains
             # - predicted logit, and 0 otherwise.
             # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
             # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
             # For labels not in the vocab of this partition, losses contains
             # -0.1 * sum logit / total_classes.
-            if n_splits > 1:
-                lse = torch.logsumexp(lse, dim=0)
-                losses = losses.sum(dim=0)
             if world_size > 1:
                 lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                 torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
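
Aside, not part of the diff: in the tensor-parallel path that remains, each rank now holds one LSE per row over its shard of classes, so the per-block logsumexp reduction above is no longer needed; the gathered per-rank values can then be combined with a logsumexp over the rank dimension (the combining code sits outside the hunk shown). A small single-device sketch of that reduction, with illustrative names and no real process group:

import torch

logits = torch.randn(4, 8)        # 4 rows, 8 classes in total
shards = logits.chunk(2, dim=-1)  # pretend two ranks each own 4 classes
lse_per_rank = torch.stack([s.logsumexp(dim=-1) for s in shards])  # (world_size, n_rows)
lse_global = torch.logsumexp(lse_per_rank, dim=0)
assert torch.allclose(lse_global, logits.logsumexp(dim=-1))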
@@ -243,6 +235,7 @@ def forward(
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward

+
         return losses, z_losses

     @staticmethod
