Skip to content

Commit 01dc504

Browse files
yanbing-j and rusty1s authored
Add spmm bf16 support (#269)
* Add spmm bf16 support and add ut

* Add bf16 support for spspmm

* Disable bf16 test before torch_scatter 2.0.9

* Refactor version compare

* Update test/utils.py

Co-authored-by: Matthias Fey <[email protected]>
1 parent f7c74ec commit 01dc504

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

csrc/cpu/spmm_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
4444
auto K = mat.size(-1);
4545
auto B = mat.numel() / (N * K);
4646

47-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
47+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "spmm_cpu", [&] {
4848
scalar_t *value_data = nullptr;
4949
auto mat_data = mat.data_ptr<scalar_t>();
5050
auto out_data = out.data_ptr<scalar_t>();
@@ -123,7 +123,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
123123
auto row_data = row.data_ptr<int64_t>();
124124
auto rowptr_data = rowptr.data_ptr<int64_t>();
125125
auto col_data = col.data_ptr<int64_t>();
126-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
126+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "spmm_value_bw_cpu", [&] {
127127
auto mat_data = mat.data_ptr<scalar_t>();
128128
auto grad_data = grad.data_ptr<scalar_t>();
129129
auto out_data = out.data_ptr<scalar_t>();

csrc/cpu/spspmm_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
5555
torch::Tensor colC;
5656
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
5757

58-
AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
58+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, scalar_type, "spspmm", [&] {
5959
AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
6060
scalar_t *valA_data = nullptr, *valB_data = nullptr;
6161
if (HAS_VALUE) {
@@ -77,7 +77,7 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
7777
if (HAS_VALUE)
7878
tmp_vals[cB] += valA_data[eA] * valB_data[eB];
7979
else
80-
tmp_vals[cB]++;
80+
tmp_vals[cB] += 1;
8181
}
8282
}
8383

test/test_spspmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_sparse_tensor_spspmm(dtype, device):
3535
], dtype=dtype, device=device),
3636
)
3737

38-
expected = torch.eye(10, dtype=dtype, device=device)
38+
expected = torch.eye(10, device=device).to(dtype)
3939

4040
out = x @ x.to_dense().t()
4141
assert torch.allclose(out, expected, atol=1e-2)

test/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import torch
2+
import torch_scatter
3+
from packaging import version
24

35
reductions = ['sum', 'add', 'mean', 'min', 'max']
46

57
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
68
grad_dtypes = [torch.half, torch.float, torch.double]
79

10+
if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
11+
dtypes.append(torch.bfloat16)
12+
grad_dtypes.append(torch.bfloat16)
13+
814
devices = [torch.device('cpu')]
915
if torch.cuda.is_available():
1016
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]

0 commit comments

Comments (0)