4
4
from .utils .unique import unique
5
5
6
6
7
- def coalesce (index , value , m , n , op = 'add' , fill_value = 0 ):
7
+ def coalesce (index , value , m , n , op = 'add' ):
8
8
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
9
9
entries are removed by scattering them together. For scattering, any
10
10
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
17
17
n (int): The second dimension of corresponding dense matrix.
18
18
op (string, optional): The scatter operation to use. (default:
19
19
:obj:`"add"`)
20
- fill_value (int, optional): The initial fill value of scatter
21
- operation. (default: :obj:`0`)
22
20
23
21
:rtype: (:class:`LongTensor`, :class:`Tensor`)
24
22
"""
@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
37
35
index = torch .stack ([row [perm ], col [perm ]], dim = 0 )
38
36
39
37
op = getattr (torch_scatter , 'scatter_{}' .format (op ))
40
- value = op (value , inv , 0 , None , perm .size (0 ), fill_value )
41
- if isinstance (value , tuple ):
42
- value = value [0 ]
38
+ value = op (value , inv , 0 , None , perm .size (0 ))
39
+ value = value [0 ] if isinstance (value , tuple ) else value
43
40
44
41
return index , value
0 commit comments