@@ -76,7 +76,8 @@ def spmm_max(src: SparseTensor,
76
76
return torch .ops .torch_sparse .spmm_max (rowptr , col , value , other )
77
77
78
78
79
- def spmm (src : SparseTensor , other : torch .Tensor ,
79
+ def spmm (src : SparseTensor ,
80
+ other : torch .Tensor ,
80
81
reduce : str = "sum" ) -> torch .Tensor :
81
82
if reduce == 'sum' or reduce == 'add' :
82
83
return spmm_sum (src , other )
@@ -97,7 +98,7 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
97
98
edge_index = C ._indices ()
98
99
row , col = edge_index [0 ], edge_index [1 ]
99
100
value : Optional [Tensor ] = None
100
- if src .has_value () and other .has_value ():
101
+ if src .has_value () or other .has_value ():
101
102
value = C ._values ()
102
103
103
104
return SparseTensor (
@@ -114,7 +115,8 @@ def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
114
115
return spspmm_sum (src , other )
115
116
116
117
117
- def spspmm (src : SparseTensor , other : SparseTensor ,
118
+ def spspmm (src : SparseTensor ,
119
+ other : SparseTensor ,
118
120
reduce : str = "sum" ) -> SparseTensor :
119
121
if reduce == 'sum' or reduce == 'add' :
120
122
return spspmm_sum (src , other )
0 commit comments