
Add activation sparsity (24 + fp8 dynamic quant) subclass #2213

Open · wants to merge 28 commits into main
Changes from 1 commit
update
jcaip committed May 30, 2025
commit 61aedfda446b013e1f63933c7d6524ed3fbaffc5
70 changes: 5 additions & 65 deletions test/sparsity/test_activation.py
@@ -1,9 +1,12 @@
import copy
import unittest

import torch
import torch.nn.functional as F
from parameterized import parameterized

from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ActivationFunction
from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8MMConfig,
@@ -14,16 +17,8 @@
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)

import copy
import unittest

from parameterized import parameterized

from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv
from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor
from torchao.utils import is_sm_at_least_90
from torchao.sparsity import ActivationFunction


@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
@@ -156,61 +151,6 @@ def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False):

        torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)

@parameterized.expand(
    [
        # (1, 8192, 1024, True),
        # (64, 8192, 1024, True),
        # (1024, 8192, 1024, True),
        # (1, 8192, 1024, False),
        (64, 8192, 1024, False),
        # (1024, 8192, 1024, False),
    ]
)
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_srelu_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False):
    with torch.no_grad():
        torch.manual_seed(0)
        input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
        # we have to wrap in a sequential block for quantize_ to work properly
        reference_linear = torch.nn.Sequential(
            SquaredReLU(),
            torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16)
        )
        reference_linear_copy = copy.deepcopy(reference_linear[1])
        print(reference_linear_copy)

        quantize_(
            reference_linear,
            Float8DynamicActivationFloat8WeightConfig(
                granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
            ),
        )

        if do_compile:
            reference_linear.forward = torch.compile(
                reference_linear.forward,
                fullgraph=True,
            )

        quantize_(
            reference_linear_copy,
            Float8DynamicSemiSparseActivationFloat8WeightConfig(
                activation_fn=ActivationFunction.SQUARED_RELU,
                granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
            ),
        )
        print(reference_linear_copy)

        if do_compile:
            reference_linear_copy.forward = torch.compile(
                reference_linear_copy.forward, fullgraph=True
            )

        reference_output = reference_linear(input_tensor)
        custom_output = reference_linear_copy(input_tensor)

        torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)


@unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run")
def test_splitk_sparse_gemv():
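For reference, the usage pattern exercised by the removed test reduces to the sketch below: a bf16 linear layer is converted with the new Float8DynamicSemiSparseActivationFloat8WeightConfig, with the squared-ReLU activation folded in via activation_fn rather than applied as a separate module. This is a minimal sketch, assuming the import paths shown in the diff above and that quantize_ and PerRow are exported from torchao.quantization; the sizes and the use_fast_accum flag follow the test's parameterization and are illustrative.

```python
import torch

from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ActivationFunction
from torchao.quantization import Float8MMConfig, PerRow, quantize_
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)
from torchao.sparsity.utils import create_semi_structured_tensor

M, K, N = 64, 8192, 1024  # sizes taken from the test's parameterization (illustrative)

# A plain bf16 linear; the squared-ReLU activation is not a separate module here,
# it is folded into the config below via activation_fn, mirroring the removed test.
linear = torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16)

quantize_(
    linear,
    Float8DynamicSemiSparseActivationFloat8WeightConfig(
        activation_fn=ActivationFunction.SQUARED_RELU,
        granularity=PerRow(),
        mm_config=Float8MMConfig(use_fast_accum=True),
    ),
)

# The test feeds an already 2:4-structured bf16 activation, built with the same helper.
x = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
with torch.no_grad():
    out = linear(x)  # fp8 dynamic quant + semi-sparse activation matmul, shape (M, N)
```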