68 changes: 59 additions & 9 deletions src/liger_kernel/ops/cross_entropy.py
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
loss_ptr,
z_loss_ptr,
loss_stride,
token_accuracy_ptr,
token_accuracy_stride,
n_cols,
n_non_ignore,
sum_non_ignore_weight,
@@ -42,6 +44,7 @@ def liger_cross_entropy_kernel(
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_TOKEN_ACCURACY: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +63,8 @@ def liger_cross_entropy_kernel(
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +74,8 @@ def liger_cross_entropy_kernel(
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +98,17 @@ def liger_cross_entropy_kernel(
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
# For ignored tokens, set token accuracy to 0
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride
tl.store(token_accuracy_ptr, 0.0)
return

loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +119,7 @@ def liger_cross_entropy_kernel(
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +140,16 @@ def liger_cross_entropy_kernel(
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)

# Track argmax for accuracy computation
if RETURN_TOKEN_ACCURACY and block_max > m:
# Find the index of the maximum value in this block
is_max_mask = X_block == block_max
# Mask out invalid indices with a value larger than n_cols
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
# Get the first (smallest) index where max occurs
argmax_idx = tl.min(masked_offsets)

if label_smoothing > 0:
# scale X beforehand to avoid overflow
if HAS_WEIGHT:
@@ -256,6 +279,10 @@ def liger_cross_entropy_kernel(
tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)
if RETURN_TOKEN_ACCURACY:
# Store 1.0 if prediction is correct, 0.0 otherwise
is_correct = 1.0 if argmax_idx == y else 0.0
tl.store(token_accuracy_ptr, is_correct)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +301,12 @@ def cross_entropy_forward(
reduction,
softcap,
return_z_loss,
return_token_accuracy=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)

BT, V = _input.shape
n_rows = BT
@@ -285,6 +316,9 @@ def cross_entropy_forward(
# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = (
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
)

target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
@@ -321,6 +355,10 @@ def cross_entropy_forward(
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d,
token_accuracy_stride=token_accuracy_1d.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
n_cols=V,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +369,7 @@ def cross_entropy_forward(
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +382,14 @@ def cross_entropy_forward(
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

return loss, z_loss, _input
return loss, z_loss, token_accuracy, _input


def cross_entropy_backward(_input, grad_output):
@@ -395,6 +437,7 @@ def forward(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -409,14 +452,15 @@ def forward(
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

Returns:
tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
"""
input_requires_grad = _input.requires_grad

loss, z_loss, _input = cross_entropy_forward(
loss, z_loss, token_accuracy, _input = cross_entropy_forward(
_input,
target,
weight,
@@ -426,30 +470,35 @@ def forward(
reduction,
softcap,
return_z_loss,
return_token_accuracy,
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
if input_requires_grad:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_token_accuracy = return_token_accuracy

return loss, z_loss
return loss, z_loss, token_accuracy

@staticmethod
def backward(ctx, grad_output, grad_ouput2):
def backward(ctx, grad_output, grad_output2, grad_output3):
"""
The backward pass of the Liger Cross Entropy loss.

Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tenosr): No use.
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_ouput2 # z_loss is only for logging
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics

(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
Expand All @@ -463,4 +512,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
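The per-token accuracy is computed inside the online-softmax pass itself: the running argmax is updated whenever a block yields a new maximum, so no full softmax and no separate argmax over the vocabulary is ever materialized. As a sanity check on the semantics, a minimal eager-mode PyTorch sketch of the quantity the kernel produces (the helper name and signature are illustrative, not part of this PR):

import torch

def token_accuracy_reference(logits, target, ignore_index=-100, reduction="mean"):
    # logits: (BT, V) float tensor, target: (BT,) long tensor
    pred = logits.argmax(dim=-1)                      # counterpart of the blockwise argmax tracking
    correct = ((pred == target) & (target != ignore_index)).float()  # 1.0 if correct, ignored tokens -> 0.0
    if reduction == "none":
        return correct                                # per-token values, like token_accuracy_1d
    n_non_ignore = (target != ignore_index).sum().clamp(min=1)
    return correct.sum() / n_non_ignore               # mean over non-ignored tokens

On ties the kernel keeps the smallest class index, and softcapping does not change the argmax because softcap * tanh(x / softcap) is monotonic.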
31 changes: 27 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
return_z_loss=False,
accum_dtype=None,
use_token_scaling=False,
return_token_accuracy=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)
device = _input.device

input_requires_grad = _input.requires_grad
@@ -61,6 +65,7 @@ def fused_linear_cross_entropy_forward(

loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
target_mask = target != ignore_index
@@ -126,6 +131,7 @@ def fused_linear_cross_entropy_forward(
# unreduced loss
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

# ensure _input and target are contiguous
logits_chunk = logits_chunk.contiguous()
@@ -141,6 +147,10 @@ def fused_linear_cross_entropy_forward(
loss_ptr=loss_1d_slice,
z_loss_ptr=z_loss_1d_slice,
loss_stride=loss_1d_slice.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d_slice,
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
n_cols=V,
n_non_ignore=total_n_non_ignore,
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -151,6 +161,7 @@ def fused_linear_cross_entropy_forward(
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=input_requires_grad,
@@ -167,6 +178,8 @@ def fused_linear_cross_entropy_forward(
loss_1d[start_idx:end_idx] = loss_1d_slice
if return_z_loss:
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
if return_token_accuracy:
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
grad_logits_chunk = logits_chunk # chunk_size x V

# Apply token scaling to gradients if requested
@@ -198,15 +211,18 @@ def fused_linear_cross_entropy_forward(
# Return per-token losses
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

# Cast back to original dtype
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

return loss, z_loss, grad_input, grad_weight, grad_bias
return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -274,6 +290,7 @@ def forward(
return_z_loss: bool = False,
accum_dtype=None,
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -297,9 +314,10 @@ def forward(
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
Default: False.
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
"""

loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input=_input,
weight=weight,
target=target,
@@ -313,6 +331,7 @@ def forward(
return_z_loss=return_z_loss,
accum_dtype=accum_dtype,
use_token_scaling=use_token_scaling,
return_token_accuracy=return_token_accuracy,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -321,13 +340,16 @@ def forward(
grad_bias.detach() if bias is not None else None,
)
ctx.return_z_loss = return_z_loss
return loss, z_loss
ctx.return_token_accuracy = return_token_accuracy
return loss, z_loss, token_accuracy

@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output, grad_output2):
def backward(ctx, grad_output, grad_output2, grad_output3):
if ctx.return_z_loss:
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
@@ -346,4 +368,5 @@ def backward(ctx, grad_output, grad_output2):
None,
None,
None, # use_token_scaling
None, # return_token_accuracy
)
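The fused path keeps the same chunked structure as the loss: only one chunk of logits exists at a time, the kernel fills the matching slice of token_accuracy_1d, and the reduced value is the mean over non-ignored tokens. A rough eager-mode sketch of that bookkeeping, with illustrative names and without the gradient computation the real code also performs:

import torch

def chunked_token_accuracy_sketch(hidden, lm_head_weight, target, ignore_index=-100, chunk_size=1024):
    # hidden: (BT, H), lm_head_weight: (V, H), target: (BT,)
    BT = hidden.shape[0]
    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=hidden.device)
    for start in range(0, BT, chunk_size):
        end = min(start + chunk_size, BT)
        logits_chunk = hidden[start:end] @ lm_head_weight.t()   # logits materialized per chunk only
        pred = logits_chunk.argmax(dim=-1)
        tgt = target[start:end]
        token_accuracy_1d[start:end] = ((pred == tgt) & (tgt != ignore_index)).float()
    n_non_ignore = (target != ignore_index).sum().clamp(min=1)
    return token_accuracy_1d.sum() / n_non_ignore               # matches the reduction above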
11 changes: 8 additions & 3 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -3,6 +3,7 @@
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.transformers.functional import CrossEntropyOutput


class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ def __init__(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
):
super().__init__()
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ def __init__(
self.reduction = reduction
self.softcap = softcap
self.return_z_loss = return_z_loss
self.return_token_accuracy = return_token_accuracy

def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
@@ -45,7 +48,9 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
self.reduction,
self.softcap,
self.return_z_loss,
self.return_token_accuracy,
)
if not self.return_z_loss:
if not self.return_z_loss and not self.return_token_accuracy:
return loss
return loss, z_loss

return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
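A minimal usage sketch of the module-level API, assuming the remaining constructor arguments keep their defaults and that CrossEntropyOutput exposes loss, z_loss, and token_accuracy attributes as the return statement above suggests:

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(return_token_accuracy=True)

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

out = ce(logits, target)       # CrossEntropyOutput because an extra output was requested
out.loss.backward()            # only the loss participates in autograd
print(out.token_accuracy)      # fraction of non-ignored tokens predicted correctly
print(out.z_loss)              # None unless return_z_loss=True

When neither return_z_loss nor return_token_accuracy is set, the forward still returns a bare loss tensor, so existing callers are unaffected.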