
Add activation sparsity (24 + fp8 dynamic quant) subclass #2213


Open · wants to merge 28 commits into base: main
Changes from 1 commit
wip
jcaip committed May 27, 2025
commit 2a68435e726f40b0992595d8006a8b0b53c89204
78 changes: 73 additions & 5 deletions test/sparsity/test_activation24.py
@@ -121,11 +121,11 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024):
)

# define reference implementation
def srelu_linear(x):
def reference_srelu(x):
x = F.relu(x) ** 2
return reference_linear(x)

reference_srelu = torch.compile(srelu_linear, fullgraph=True)
# reference_srelu = torch.compile(reference_srelu, fullgraph=True)

# this only works with fullgraph=True, errors in eager
# TODO figure out exactly why this happens
@@ -134,16 +134,55 @@ def srelu_linear(x):
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
)
# (reference_linear_copy)
reference_linear_copy.forward = torch.compile(
reference_linear_copy.forward, fullgraph=True
)
# reference_linear_copy.forward = torch.compile(
# reference_linear_copy.forward, fullgraph=True
# )

reference_output = reference_srelu(input_tensor)
custom_output = reference_linear_copy(input_tensor)

torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)


from torchao.sparsity.sparse_api import ActivationSparseLinearConfig
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_asdf(M=512, K=2048, N=1024):
with torch.no_grad():
torch.manual_seed(0)
input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
# we have to wrap in a sequential block for quantize_ to work properly
reference_linear = torch.nn.Sequential(
torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16)
)
reference_linear_copy = copy.deepcopy(reference_linear)

quantize_(
reference_linear,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False)
),
)

# this only works with fullgraph=True, errors in eager
# TODO figure out exactly why this happens
sparsify_(
reference_linear_copy,
ActivationSparseLinearConfig(),
)
# (reference_linear_copy)
# reference_linear_copy.forward = torch.compile(
# reference_linear_copy.forward, fullgraph=True
# )

reference_output = reference_linear(input_tensor)
custom_output = reference_linear_copy(input_tensor)

print(reference_output)
print(custom_output)

torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)


@unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run")
def test_splitk_sparse_gemv():
torch.manual_seed(0)
@@ -224,3 +263,32 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype):
A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
)
assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)

@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor_compile(
M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
) -> None:
def _to_fp8_rowwise(x: torch.Tensor, dtype):
max_v = torch.finfo(dtype).max
x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
x = (x / x_scale).to(dtype)
return x, x_scale

torch.manual_seed(0)
A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
A, a_scale = _to_fp8_rowwise(A_dense, dtype)

B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
B, b_scale = _to_fp8_rowwise(B_dense, dtype)

B = B.T
b_scale = b_scale.T

A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
)
out_ref = torch._scaled_mm(
A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
)
assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
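
Note: the _to_fp8_rowwise helper duplicated in these tests is plain rowwise dynamic fp8 quantization: one scale per row so the scaled row fits in the fp8 range. A minimal standalone sketch, assuming only stock PyTorch with torch.float8_e4m3fn support (demo_rowwise_fp8 is a hypothetical name, not part of the PR):

import torch

def demo_rowwise_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # scale = max(|row|) / fp8_max, computed per row
    fp8_max = torch.finfo(dtype).max
    x_scale = (x.abs().max(dim=1, keepdim=True)[0] / fp8_max).float()
    x_fp8 = (x / x_scale).to(dtype)
    return x_fp8, x_scale

x = torch.randn(512, 2048, dtype=torch.bfloat16)
x_fp8, x_scale = demo_rowwise_fp8(x)
# Dequantizing recovers the input up to fp8 rounding error.
x_hat = x_fp8.to(torch.float32) * x_scale
assert torch.allclose(x_hat, x.float(), rtol=0.1, atol=0.1)
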
10 changes: 5 additions & 5 deletions torchao/csrc/cuda/activation24/sparse_gemm.cu
@@ -343,9 +343,9 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm<false>));
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm<true>));
}
// TORCH_LIBRARY_IMPL(torchao, Meta, m) {
// m.impl(
// TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
// TORCH_FN(torchao::_sparse24_fp8_sm90_cutlass_gemm<true>));
// }
#endif
10 changes: 5 additions & 5 deletions torchao/csrc/cuda/activation24/sparsify24.cu
@@ -412,8 +412,8 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
TORCH_FN(sparse24_sm90_sparsify<false>));
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"),
TORCH_FN(sparse24_sm90_sparsify<true>));
}
// TORCH_LIBRARY_IMPL(torchao, Meta, m) {
// m.impl(
// TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"),
// TORCH_FN(sparse24_sm90_sparsify<true>));
// }
20 changes: 6 additions & 14 deletions torchao/ops.py
@@ -843,26 +843,18 @@ def sparse24_sm90_sparsify(
)


def sparse24_fp8_sm90_cutlass_gemm(
@register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm")
def _(
a: Tensor,
meta: Tensor,
b: Tensor,
a_scale: Optional[Tensor],
b_scale: Optional[Tensor],
swizzle_size: int,
swizzle_axis: str,
sm_count: int,
swizzle_size: int = 8,
swizzle_axis: str = 'n',
sm_count: int = 128,
) -> Tensor:
return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
a,
meta,
b,
a_scale=a_scale,
b_scale=b_scale,
swizzle_size=swizzle_size,
swizzle_axis=swizzle_axis,
sm_count=sm_count,
)
return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device)


def swizzle_mm(
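
Note: this ops.py change replaces the eager Python wrapper with a fake (meta) implementation: the registered function does no compute and only reports the output shape/dtype/device, which is what FakeTensor tracing and torch.compile need to trace the custom op; the matching TORCH_LIBRARY_IMPL(torchao, Meta, m) blocks in the .cu files are commented out in this commit. A minimal sketch of the same pattern on a toy op, assuming PyTorch >= 2.4's torch.library.custom_op / register_fake (the demo:: namespace and names are hypothetical; torchao's register_custom_op presumably wraps an equivalent registration):

import torch
from torch import Tensor

@torch.library.custom_op("demo::toy_gemm", mutates_args=())
def toy_gemm(a: Tensor, b: Tensor) -> Tensor:
    # Real implementation (stand-in for the CUTLASS kernel).
    return (a.float() @ b.float()).to(torch.bfloat16)

@toy_gemm.register_fake
def _(a: Tensor, b: Tensor) -> Tensor:
    # Fake/meta implementation: no compute, just output metadata,
    # mirroring the torch.empty(...) returned by the torchao fake impl above.
    return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device)

fn = torch.compile(lambda a, b: torch.ops.demo.toy_gemm(a, b), fullgraph=True)
out = fn(torch.randn(4, 8), torch.randn(8, 16))
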
41 changes: 30 additions & 11 deletions torchao/prototype/sparsity/activation/srelu_linear.py
@@ -38,6 +38,12 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
):
return FP8SemiSparseActivationLinear.from_dense(module, config)

def _to_fp8_rowwise(x: torch.Tensor, dtype):
max_v = torch.finfo(dtype).max
x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
x = (x / x_scale).to(dtype)
return x, x_scale


class FP8SemiSparseActivationLinear(nn.Module):
"""
@@ -48,12 +54,16 @@ def __init__(self, weight, config) -> None:
super().__init__()
self.config = config

W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype)
self.Wq = W_aqt.tensor_impl.float8_data
self.W_scale = W_aqt.tensor_impl.scale
# W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype)
# self.Wq = W_aqt.tensor_impl.float8_data
# self.W_scale = W_aqt.tensor_impl.scale
W, W_scale = _to_fp8_rowwise(weight, self.config.weight_dtype)
self.W = W
self.W_scale = W_scale

def forward(self, x):
X_scale = torch.empty([x.shape[0], 1], device=x.device, dtype=torch.float32)
# X_scale = _float8_cutlass_quant(x, self.config.activation_dtype).tensor_impl.scale.repeat([x.shape[0], 1])
Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify(
x,
"cutlass",
@@ -62,16 +72,25 @@ def forward(self, x):
dtype=self.config.activation_dtype,
scale=X_scale,
)

result = rowwise_scaled_linear_sparse_cutlass_f8f8(
self.Wq,
self.W_scale,
breakpoint()
result = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
Xq_sparse,
X_meta,
X_scale,
bias=None,
out_dtype=torch.bfloat16,
).t()
self.W.T,
a_scale=X_scale,
b_scale=self.W_scale.T,
)


# result = rowwise_scaled_linear_sparse_cutlass_f8f8(
# self.Wq,
# self.W_scale,
# Xq_sparse,
# X_meta,
# X_scale,
# bias=None,
# out_dtype=torch.bfloat16,
# ).t()

return result

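
Note: a dense-math reference for the new forward path above, as a sketch only: it reuses _to_fp8_rowwise from this file, ignores the 2:4 sparsification of the activation and fp8 accumulation details, and so only approximates the CUTLASS kernel's output (dense_reference is a hypothetical helper, not part of the PR):

def dense_reference(x: torch.Tensor, W: torch.Tensor, W_scale: torch.Tensor) -> torch.Tensor:
    # Rowwise-quantize the activation the same way the weight was quantized
    # in __init__, then undo both scalings around a plain matmul:
    #   y ~= (Xq * X_scale) @ (Wq * W_scale).T
    Xq, X_scale = _to_fp8_rowwise(x, torch.float8_e4m3fn)
    return ((Xq.float() * X_scale) @ (W.float() * W_scale).T).to(torch.bfloat16)
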
Empty file.