Skip to content

Commit 1fb5fa4

Browse files
committed
torch-scatter=2.0 support
1 parent 2984f28 commit 1fb5fa4

File tree

4 files changed

+8
-12
lines changed

4 files changed

+8
-12
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
2828

2929
## Installation
3030

31-
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
31+
Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
3232

3333
```
3434
$ python -c "import torch; print(torch.__version__)"
35-
>>> 1.1.0
35+
>>> 1.4.0
3636
3737
$ echo $PATH
3838
>>> /usr/local/cuda/bin:...
@@ -53,7 +53,7 @@ Be sure to import `torch` first before using this package to resolve symbols the
5353
## Coalesce
5454

5555
```
56-
torch_sparse.coalesce(index, value, m, n, op="add", fill_value=0) -> (torch.LongTensor, torch.Tensor)
56+
torch_sparse.coalesce(index, value, m, n, op="add") -> (torch.LongTensor, torch.Tensor)
5757
```
5858

5959
Row-wise sorts `index` and removes duplicate entries.
@@ -67,7 +67,6 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py
6767
* **m** *(int)* - The first dimension of corresponding dense matrix.
6868
* **n** *(int)* - The second dimension of corresponding dense matrix.
6969
* **op** *(string, optional)* - The scatter operation to use. (default: `"add"`)
70-
* **fill_value** *(int, optional)* - The initial fill value of scatter operation. (default: `0`)
7170

7271
### Returns
7372

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
extra_compile_args=extra_compile_args),
4040
]
4141

42-
__version__ = '0.4.3'
42+
__version__ = '0.4.4'
4343
url = 'https://github.com/rusty1s/pytorch_sparse'
4444

4545
install_requires = ['scipy']

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .spmm import spmm
66
from .spspmm import spspmm
77

8-
__version__ = '0.4.3'
8+
__version__ = '0.4.4'
99

1010
__all__ = [
1111
'__version__',

torch_sparse/coalesce.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .utils.unique import unique
55

66

7-
def coalesce(index, value, m, n, op='add', fill_value=0):
7+
def coalesce(index, value, m, n, op='add'):
88
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
99
entries are removed by scattering them together. For scattering, any
1010
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
1717
n (int): The second dimension of corresponding dense matrix.
1818
op (string, optional): The scatter operation to use. (default:
1919
:obj:`"add"`)
20-
fill_value (int, optional): The initial fill value of scatter
21-
operation. (default: :obj:`0`)
2220
2321
:rtype: (:class:`LongTensor`, :class:`Tensor`)
2422
"""
@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
3735
index = torch.stack([row[perm], col[perm]], dim=0)
3836

3937
op = getattr(torch_scatter, 'scatter_{}'.format(op))
40-
value = op(value, inv, 0, None, perm.size(0), fill_value)
41-
if isinstance(value, tuple):
42-
value = value[0]
38+
value = op(value, inv, 0, None, perm.size(0))
39+
value = value[0] if isinstance(value, tuple) else value
4340

4441
return index, value

0 commit comments

Comments
 (0)