Skip to content

Commit ed2af8e

Browse files
committed
fix
1 parent 1da33cd commit ed2af8e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torch_sparse/matmul.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def spmm_max(src: SparseTensor,
7676
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
7777

7878

79-
def spmm(src: SparseTensor, other: torch.Tensor,
79+
def spmm(src: SparseTensor,
80+
other: torch.Tensor,
8081
reduce: str = "sum") -> torch.Tensor:
8182
if reduce == 'sum' or reduce == 'add':
8283
return spmm_sum(src, other)
@@ -97,7 +98,7 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
9798
edge_index = C._indices()
9899
row, col = edge_index[0], edge_index[1]
99100
value: Optional[Tensor] = None
100-
if src.has_value() and other.has_value():
101+
if src.has_value() or other.has_value():
101102
value = C._values()
102103

103104
return SparseTensor(
@@ -114,7 +115,8 @@ def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
114115
return spspmm_sum(src, other)
115116

116117

117-
def spspmm(src: SparseTensor, other: SparseTensor,
118+
def spspmm(src: SparseTensor,
119+
other: SparseTensor,
118120
reduce: str = "sum") -> SparseTensor:
119121
if reduce == 'sum' or reduce == 'add':
120122
return spspmm_sum(src, other)

0 commit comments

Comments
 (0)