Skip to content

Commit 8b37ad5

Browse files
puririshi98pre-commit-ci[bot]rusty1s
authored
Added support for torch.sparse.Tensor in DataLoader (pyg-team#7252)
This implementation isn't working yet: it currently fails with a shape mismatch on a Linear layer, but passes the collate part. Example repro: `cd /opt/pyg; pip uninstall -y torch-geometric torch-sparse; rm -rf pytorch_geometric; git clone -b collate_fix https://github.com/pyg-team/pytorch_geometric.git; cd /opt/pyg/pytorch_geometric; pip install .; python3 examples/gcn2_ppi.py` ``` e_idxs_to_cat.size()= [torch.Size([2, 48146]), torch.Size([2, 88335])] value.size()= torch.Size([4693, 2815]) Traceback (most recent call last): File "examples/gcn2_ppi.py", line 93, in <module> loss = train() File "examples/gcn2_ppi.py", line 70, in train loss = criterion(model(data.x, data.adj_t), data.y) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1533, in _call_impl return forward_call(*args, **kwargs) File "examples/gcn2_ppi.py", line 46, in forward h = conv(h, x_0, adj_t) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1533, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/gcn2_conv.py", line 138, in forward x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None) File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/message_passing.py", line 437, in propagate out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs) File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/gcn2_conv.py", line 159, in message_and_aggregate return spmm(adj_t, x, reduce=self.aggr) File "/usr/local/lib/python3.8/dist-packages/torch_geometric/utils/spmm.py", line 80, in spmm return torch.sparse.mm(src, other) ``` (To reproduce, just remove the check that triggers `This example requires 'torch-sparse'`.) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <[email protected]>
1 parent cc6256e commit 8b37ad5

File tree

8 files changed

+124
-15
lines changed

8 files changed

+124
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10+
- Added support for `torch.sparse.Tensor` in `DataLoader` ([#7252](https://github.com/pyg-team/pytorch_geometric/pull/7252))
1011
- Added `save` and `load` methods to `InMemoryDataset` ([#7250](https://github.com/pyg-team/pytorch_geometric/pull/7250))
1112
- Added an example for heterogeneous GNN explanation via `CaptumExplainer` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096))
1213
- Added `visualize_feature_importance` functionality to `HeteroExplanation` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096))

examples/gcn2_ppi.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
from torch_geometric.datasets import PPI
1010
from torch_geometric.loader import DataLoader
1111
from torch_geometric.nn import GCN2Conv
12-
from torch_geometric.typing import WITH_TORCH_SPARSE
13-
14-
if not WITH_TORCH_SPARSE:
15-
quit("This example requires 'torch-sparse'")
1612

1713
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GCN2_PPI')
1814
pre_transform = T.Compose([T.GCNNorm(), T.ToSparseTensor()])

test/data/test_batch.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os.path as osp
22

33
import numpy as np
4+
import pytest
45
import torch
56

67
import torch_geometric
78
from torch_geometric.data import Batch, Data, HeteroData
89
from torch_geometric.testing import get_random_edge_index, withPackage
910
from torch_geometric.typing import SparseTensor
11+
from torch_geometric.utils import to_edge_index, to_torch_sparse_tensor
1012

1113

1214
def test_batch_basic():
@@ -466,12 +468,10 @@ def tr(n, m):
466468
d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)},
467469
x=tr(8, 5))
468470

469-
# Dataset
470471
data_list = [d1, d2, d3, d4]
471472

472473
batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])
473474

474-
# assert shapes
475475
assert batch.xs[0].shape == (19, 3)
476476
assert batch.xs[1].shape == (56, 4)
477477
assert batch.xs[2].shape == (7, 2)
@@ -480,7 +480,6 @@ def tr(n, m):
480480
assert len(batch.xs_batch) == 3
481481
assert len(batch.a_batch) == 1
482482

483-
# assert _batch
484483
assert batch.xs_batch[0].tolist() == \
485484
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
486485
assert batch.xs_batch[1].tolist() == \
@@ -490,3 +489,31 @@ def tr(n, m):
490489

491490
assert batch.a_batch['aa'].tolist() == \
492491
[0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
492+
493+
494+
@withPackage('torch>=2.0.0')
@pytest.mark.parametrize('layout', [
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
])
def test_torch_sparse_batch(layout):
    """Check that `Batch.from_data_list` handles `torch.sparse` attributes.

    The same `Data` object (sparse features `x` plus a sparse adjacency
    `adj`, both in the parametrized `layout`) is batched twice. The test
    expects `x` to be stacked along dimension 0 and `adj` to be stacked
    diagonally, with the input layout preserved in both cases.
    """
    num_nodes = 3

    # Sparse node features built from a dense reference matrix.
    dense_feat = torch.randn(num_nodes, 4)
    sparse_feat = dense_feat.to_sparse(layout=layout)

    # Sparse weighted adjacency matrix built from an edge list.
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_attr = torch.rand(4)
    sparse_adj = to_torch_sparse_tensor(edge_index, edge_attr, layout=layout)

    sample = Data(x=sparse_feat, adj=sparse_adj)
    batch = Batch.from_data_list([sample, sample])

    # Features: rows of both samples are concatenated.
    expected_feat = torch.cat([dense_feat, dense_feat], 0)
    assert batch.x.size() == (6, 4)
    assert batch.x.layout == layout
    assert torch.equal(batch.x.to_dense(), expected_feat)

    # Adjacency: the second sample's indices are shifted by `num_nodes`
    # in both dimensions, i.e. the matrices are stacked diagonally.
    expected_index = torch.cat([edge_index, edge_index + num_nodes], 1)
    expected_attr = torch.cat([edge_attr, edge_attr], 0)
    assert batch.adj.size() == (6, 6)
    assert batch.adj.layout == layout
    batched_csr = batch.adj.to_sparse(layout=torch.sparse_csr)
    out = to_edge_index(batched_csr)
    assert torch.equal(out[0], expected_index)
    assert torch.equal(out[1], expected_attr)

torch_geometric/data/collate.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from torch_geometric.data.data import BaseData
1010
from torch_geometric.data.storage import BaseStorage, NodeStorage
1111
from torch_geometric.typing import SparseTensor, torch_sparse
12+
from torch_geometric.utils import is_sparse, is_torch_sparse_tensor
13+
from torch_geometric.utils.sparse import cat
1214

1315

1416
def collate(
@@ -122,7 +124,7 @@ def _collate(
122124

123125
elem = values[0]
124126

125-
if isinstance(elem, Tensor):
127+
if isinstance(elem, Tensor) and not is_sparse(elem):
126128
# Concatenate a list of `torch.Tensor` along the `cat_dim`.
127129
# NOTE: We need to take care of incrementing elements appropriately.
128130
key = str(key)
@@ -160,15 +162,18 @@ def _collate(
160162
value = torch.cat(values, dim=cat_dim or 0, out=out)
161163
return value, slices, incs
162164

163-
elif isinstance(elem, SparseTensor) and increment:
165+
elif is_sparse(elem) and increment:
164166
# Concatenate a list of `SparseTensor` along the `cat_dim`.
165167
# NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.
166168
key = str(key)
167169
cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
168170
cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim
169171
repeats = [[value.size(dim) for dim in cat_dims] for value in values]
170172
slices = cumsum(repeats)
171-
value = torch_sparse.cat(values, dim=cat_dim)
173+
if is_torch_sparse_tensor(elem):
174+
value = cat(values, dim=cat_dim)
175+
else:
176+
value = torch_sparse.cat(values, dim=cat_dim)
172177
return value, slices, None
173178

174179
elif isinstance(elem, (int, float)):

torch_geometric/data/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
OptTensor,
3838
SparseTensor,
3939
)
40-
from torch_geometric.utils import select, subgraph
40+
from torch_geometric.utils import is_sparse, select, subgraph
4141

4242

4343
class BaseData:
@@ -518,7 +518,7 @@ def update(self, data: Union['Data', Dict[str, Any]]) -> 'Data':
518518
return self
519519

520520
def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
521-
if isinstance(value, SparseTensor) and 'adj' in key:
521+
if is_sparse(value) and 'adj' in key:
522522
return (0, 1)
523523
elif 'index' in key or key == 'face':
524524
return -1

torch_geometric/data/separate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def _separate(
6565
start, end = int(slices[idx]), int(slices[idx + 1])
6666
value = narrow(value, cat_dim or 0, start, end - start)
6767
value = value.squeeze(0) if cat_dim is None else value
68-
if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):
68+
if (decrement and incs is not None
69+
and (incs.dim() > 1 or int(incs[idx]) != 0)):
6970
value = value - incs[idx].to(value.device)
7071
return value
7172

torch_geometric/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from .nested import to_nested_tensor, from_nested_tensor
2828
from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,
2929
to_torch_coo_tensor, to_torch_csr_tensor,
30-
to_torch_csc_tensor, to_edge_index)
30+
to_torch_csc_tensor, to_torch_sparse_tensor,
31+
to_edge_index)
3132
from .spmm import spmm
3233
from .unbatch import unbatch, unbatch_edge_index
3334
from .one_hot import one_hot
@@ -99,6 +100,7 @@
99100
'to_torch_coo_tensor',
100101
'to_torch_csr_tensor',
101102
'to_torch_csc_tensor',
103+
'to_torch_sparse_tensor',
102104
'to_edge_index',
103105
'spmm',
104106
'unbatch',

torch_geometric/utils/sparse.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Tuple, Union
1+
from typing import Any, List, Optional, Tuple, Union
22

33
import torch
44
from torch import Tensor
@@ -239,6 +239,44 @@ def to_torch_csc_tensor(
239239
return adj
240240

241241

242+
def to_torch_sparse_tensor(
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    is_coalesced: bool = False,
    layout: torch.layout = torch.sparse_coo,
):
    r"""Builds a :class:`torch.sparse.Tensor` of the requested :obj:`layout`
    from a sparse adjacency matrix given as edge indices and edge attributes.
    This is a thin dispatcher over
    :meth:`~torch_geometric.utils.to_torch_coo_tensor`,
    :meth:`~torch_geometric.utils.to_torch_csr_tensor` and
    :meth:`~torch_geometric.utils.to_torch_csc_tensor`.
    See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): The edge attributes.
            (default: :obj:`None`)
        size (int or (int, int), optional): The size of the sparse matrix.
            If given as an integer, will create a quadratic sparse matrix.
            If set to :obj:`None`, will infer a quadratic sparse matrix based
            on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
        is_coalesced (bool): If set to :obj:`True`, will assume that
            :obj:`edge_index` is already coalesced and thus avoids expensive
            computation. (default: :obj:`False`)
        layout (torch.layout, optional): The layout of the output sparse tensor
            (:obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`,
            :obj:`torch.sparse_csc`). (default: :obj:`torch.sparse_coo`)

    :rtype: :class:`torch.sparse.Tensor`

    Raises:
        ValueError: If :obj:`layout` is none of the three supported layouts.
    """
    # Reject unsupported layouts up-front so the dispatch below is total.
    supported_layouts = (torch.sparse_coo, torch.sparse_csr, torch.sparse_csc)
    if not any(layout == candidate for candidate in supported_layouts):
        raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")

    if layout == torch.sparse_coo:
        convert = to_torch_coo_tensor
    elif layout == torch.sparse_csr:
        convert = to_torch_csr_tensor
    else:
        convert = to_torch_csc_tensor

    return convert(edge_index, edge_attr, size, is_coalesced)
278+
279+
242280
def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
243281
r"""Converts a :class:`torch.sparse.Tensor` or a
244282
:class:`torch_sparse.SparseTensor` to edge indices and edge attributes.
@@ -341,3 +379,42 @@ def ptr2index(ptr: Tensor) -> Tensor:
341379
def index2ptr(index: Tensor, size: int) -> Tensor:
342380
return torch._convert_indices_from_coo_to_csr(
343381
index, size, out_int32=index.dtype == torch.int32)
382+
383+
384+
def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
    r"""Concatenates a list of :class:`torch.sparse.Tensor` objects.

    Every input is converted to its edge-index representation via
    :meth:`to_edge_index`, the indices are shifted by the sizes accumulated so
    far, and the result is rebuilt with :meth:`to_torch_sparse_tensor`.

    Args:
        tensors (List[Tensor]): The sparse tensors to concatenate.
        dim (int or (int, int)): The dimension to concatenate along
            (:obj:`0` or :obj:`1`), or :obj:`(0, 1)` to stack the tensors
            diagonally. Along a concatenated dimension the output size is the
            sum of the input sizes; along the other dimension it is their
            maximum.

    :rtype: :class:`torch.sparse.Tensor` using the layout of
        :obj:`tensors[0]`
    """
    # TODO (matthias) We can make this more efficient by directly operating on
    # the individual sparse tensor layouts.
    assert dim in {0, 1, (0, 1)}

    # Rows are stacked for `dim=0` and `dim=(0, 1)`, columns for `dim=1` and
    # `dim=(0, 1)`.
    stack_rows = dim != 1
    stack_cols = dim != 0

    num_rows, num_cols = 0, 0
    index_parts = []
    attr_parts = []
    for tensor in tensors:
        assert is_torch_sparse_tensor(tensor)
        index, attr = to_edge_index(tensor)
        # Work on a copy so the in-place offsets below cannot touch the
        # tensor returned by `to_edge_index`.
        index = index.clone()

        if stack_rows:
            index[0] += num_rows
            num_rows += tensor.size(0)
        else:
            num_rows = max(num_rows, tensor.size(0))

        if stack_cols:
            index[1] += num_cols
            num_cols += tensor.size(1)
        else:
            num_cols = max(num_cols, tensor.size(1))

        index_parts.append(index)
        attr_parts.append(attr)

    merged_index = torch.cat(index_parts, dim=1)
    merged_attr = torch.cat(attr_parts, dim=0)

    return to_torch_sparse_tensor(
        edge_index=merged_index,
        edge_attr=merged_attr,
        size=[num_rows, num_cols],
        # Only the diagonally stacked result is passed on as coalesced.
        is_coalesced=dim == (0, 1),
        layout=tensors[0].layout,
    )

0 commit comments

Comments
 (0)