1
1
import warnings
2
- from typing import Optional , List , Tuple
2
+ from typing import List , Optional , Tuple
3
3
4
4
import torch
5
- from torch_scatter import segment_csr , scatter_add
6
- from torch_sparse .utils import Final
5
+ from torch_scatter import scatter_add , segment_csr
6
+
7
+ from torch_sparse .utils import Final , index_sort
7
8
8
9
layouts : Final [List [str ]] = ['coo' , 'csr' , 'csc' ]
9
10
@@ -151,7 +152,8 @@ def __init__(
151
152
idx [1 :] *= self ._sparse_sizes [1 ]
152
153
idx [1 :] += self ._col
153
154
if (idx [1 :] < idx [:- 1 ]).any ():
154
- perm = idx [1 :].argsort ()
155
+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
156
+ _ , perm = index_sort (idx [1 :], max_value )
155
157
self ._row = self .row ()[perm ]
156
158
self ._col = self ._col [perm ]
157
159
if value is not None :
@@ -163,10 +165,20 @@ def __init__(
163
165
def empty (self ):
164
166
row = torch .tensor ([], dtype = torch .long )
165
167
col = torch .tensor ([], dtype = torch .long )
166
- return SparseStorage (row = row , rowptr = None , col = col , value = None ,
167
- sparse_sizes = (0 , 0 ), rowcount = None , colptr = None ,
168
- colcount = None , csr2csc = None , csc2csr = None ,
169
- is_sorted = True , trust_data = True )
168
+ return SparseStorage (
169
+ row = row ,
170
+ rowptr = None ,
171
+ col = col ,
172
+ value = None ,
173
+ sparse_sizes = (0 , 0 ),
174
+ rowcount = None ,
175
+ colptr = None ,
176
+ colcount = None ,
177
+ csr2csc = None ,
178
+ csc2csr = None ,
179
+ is_sorted = True ,
180
+ trust_data = True ,
181
+ )
170
182
171
183
def has_row (self ) -> bool :
172
184
return self ._row is not None
@@ -209,8 +221,11 @@ def has_value(self) -> bool:
209
221
def value (self ) -> Optional [torch .Tensor ]:
210
222
return self ._value
211
223
212
- def set_value_ (self , value : Optional [torch .Tensor ],
213
- layout : Optional [str ] = None ):
224
+ def set_value_ (
225
+ self ,
226
+ value : Optional [torch .Tensor ],
227
+ layout : Optional [str ] = None ,
228
+ ):
214
229
if value is not None :
215
230
if get_layout (layout ) == 'csc' :
216
231
value = value [self .csc2csr ()]
@@ -221,8 +236,11 @@ def set_value_(self, value: Optional[torch.Tensor],
221
236
self ._value = value
222
237
return self
223
238
224
- def set_value (self , value : Optional [torch .Tensor ],
225
- layout : Optional [str ] = None ):
239
+ def set_value (
240
+ self ,
241
+ value : Optional [torch .Tensor ],
242
+ layout : Optional [str ] = None ,
243
+ ):
226
244
if value is not None :
227
245
if get_layout (layout ) == 'csc' :
228
246
value = value [self .csc2csr ()]
@@ -375,8 +393,11 @@ def colcount(self) -> torch.Tensor:
375
393
if colptr is not None :
376
394
colcount = colptr [1 :] - colptr [:- 1 ]
377
395
else :
378
- colcount = scatter_add (torch .ones_like (self ._col ), self ._col ,
379
- dim_size = self ._sparse_sizes [1 ])
396
+ colcount = scatter_add (
397
+ torch .ones_like (self ._col ),
398
+ self ._col ,
399
+ dim_size = self ._sparse_sizes [1 ],
400
+ )
380
401
self ._colcount = colcount
381
402
return colcount
382
403
@@ -389,7 +410,8 @@ def csr2csc(self) -> torch.Tensor:
389
410
return csr2csc
390
411
391
412
idx = self ._sparse_sizes [0 ] * self ._col + self .row ()
392
- csr2csc = idx .argsort ()
413
+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
414
+ _ , csr2csc = index_sort (idx , max_value )
393
415
self ._csr2csc = csr2csc
394
416
return csr2csc
395
417
@@ -401,7 +423,8 @@ def csc2csr(self) -> torch.Tensor:
401
423
if csc2csr is not None :
402
424
return csc2csr
403
425
404
- csc2csr = self .csr2csc ().argsort ()
426
+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
427
+ _ , csc2csr = index_sort (self .csr2csc (), max_value )
405
428
self ._csc2csr = csc2csr
406
429
return csc2csr
407
430
@@ -543,7 +566,8 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
543
566
else :
544
567
return self .set_value (
545
568
value .to (dtype = dtype , non_blocking = non_blocking ),
546
- layout = 'coo' )
569
+ layout = 'coo' ,
570
+ )
547
571
else :
548
572
return self
549
573
0 commit comments