Skip to content

Commit a980efd

Browse files
Enable index_sort (#306)
* Enable `index_sort` * update * update --------- Co-authored-by: rusty1s <[email protected]>
1 parent e55e833 commit a980efd

File tree

3 files changed

+65
-18
lines changed

3 files changed

+65
-18
lines changed

torch_sparse/storage.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import warnings
2-
from typing import Optional, List, Tuple
2+
from typing import List, Optional, Tuple
33

44
import torch
5-
from torch_scatter import segment_csr, scatter_add
6-
from torch_sparse.utils import Final
5+
from torch_scatter import scatter_add, segment_csr
6+
7+
from torch_sparse.utils import Final, index_sort
78

89
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
910

@@ -151,7 +152,8 @@ def __init__(
151152
idx[1:] *= self._sparse_sizes[1]
152153
idx[1:] += self._col
153154
if (idx[1:] < idx[:-1]).any():
154-
perm = idx[1:].argsort()
155+
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
156+
_, perm = index_sort(idx[1:], max_value)
155157
self._row = self.row()[perm]
156158
self._col = self._col[perm]
157159
if value is not None:
@@ -163,10 +165,20 @@ def __init__(
163165
def empty(self):
164166
row = torch.tensor([], dtype=torch.long)
165167
col = torch.tensor([], dtype=torch.long)
166-
return SparseStorage(row=row, rowptr=None, col=col, value=None,
167-
sparse_sizes=(0, 0), rowcount=None, colptr=None,
168-
colcount=None, csr2csc=None, csc2csr=None,
169-
is_sorted=True, trust_data=True)
168+
return SparseStorage(
169+
row=row,
170+
rowptr=None,
171+
col=col,
172+
value=None,
173+
sparse_sizes=(0, 0),
174+
rowcount=None,
175+
colptr=None,
176+
colcount=None,
177+
csr2csc=None,
178+
csc2csr=None,
179+
is_sorted=True,
180+
trust_data=True,
181+
)
170182

171183
def has_row(self) -> bool:
172184
return self._row is not None
@@ -209,8 +221,11 @@ def has_value(self) -> bool:
209221
def value(self) -> Optional[torch.Tensor]:
210222
return self._value
211223

212-
def set_value_(self, value: Optional[torch.Tensor],
213-
layout: Optional[str] = None):
224+
def set_value_(
225+
self,
226+
value: Optional[torch.Tensor],
227+
layout: Optional[str] = None,
228+
):
214229
if value is not None:
215230
if get_layout(layout) == 'csc':
216231
value = value[self.csc2csr()]
@@ -221,8 +236,11 @@ def set_value_(self, value: Optional[torch.Tensor],
221236
self._value = value
222237
return self
223238

224-
def set_value(self, value: Optional[torch.Tensor],
225-
layout: Optional[str] = None):
239+
def set_value(
240+
self,
241+
value: Optional[torch.Tensor],
242+
layout: Optional[str] = None,
243+
):
226244
if value is not None:
227245
if get_layout(layout) == 'csc':
228246
value = value[self.csc2csr()]
@@ -375,8 +393,11 @@ def colcount(self) -> torch.Tensor:
375393
if colptr is not None:
376394
colcount = colptr[1:] - colptr[:-1]
377395
else:
378-
colcount = scatter_add(torch.ones_like(self._col), self._col,
379-
dim_size=self._sparse_sizes[1])
396+
colcount = scatter_add(
397+
torch.ones_like(self._col),
398+
self._col,
399+
dim_size=self._sparse_sizes[1],
400+
)
380401
self._colcount = colcount
381402
return colcount
382403

@@ -389,7 +410,8 @@ def csr2csc(self) -> torch.Tensor:
389410
return csr2csc
390411

391412
idx = self._sparse_sizes[0] * self._col + self.row()
392-
csr2csc = idx.argsort()
413+
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
414+
_, csr2csc = index_sort(idx, max_value)
393415
self._csr2csc = csr2csc
394416
return csr2csc
395417

@@ -401,7 +423,8 @@ def csc2csr(self) -> torch.Tensor:
401423
if csc2csr is not None:
402424
return csc2csr
403425

404-
csc2csr = self.csr2csc().argsort()
426+
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
427+
_, csc2csr = index_sort(self.csr2csc(), max_value)
405428
self._csc2csr = csc2csr
406429
return csc2csr
407430

@@ -543,7 +566,8 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
543566
else:
544567
return self.set_value(
545568
value.to(dtype=dtype, non_blocking=non_blocking),
546-
layout='coo')
569+
layout='coo',
570+
)
547571
else:
548572
return self
549573

torch_sparse/typing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
try:
2+
import pyg_lib # noqa
3+
WITH_PYG_LIB = True
4+
WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
5+
except ImportError:
6+
pyg_lib = object
7+
WITH_PYG_LIB = False
8+
WITH_INDEX_SORT = False

torch_sparse/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
from typing import Any
1+
from typing import Any, Optional, Tuple
2+
3+
import torch
4+
5+
import torch_sparse.typing
6+
from torch_sparse.typing import pyg_lib
27

38
try:
49
from typing_extensions import Final # noqa
510
except ImportError:
611
from torch.jit import Final # noqa
712

813

14+
def index_sort(
15+
inputs: torch.Tensor,
16+
max_value: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
17+
r"""See pyg-lib documentation for more details:
18+
https://pyg-lib.readthedocs.io/en/latest/modules/ops.html"""
19+
if not torch_sparse.typing.WITH_INDEX_SORT: # pragma: no cover
20+
return inputs.sort()
21+
return pyg_lib.ops.index_sort(inputs, max_value)
22+
23+
924
def is_scalar(other: Any) -> bool:
1025
return isinstance(other, int) or isinstance(other, float)

0 commit comments

Comments
 (0)