
Commit bfb571c

Merge pull request #40 from rusty1s/metis
[WIP] Partition
2 parents e78637e + eee47ee commit bfb571c

29 files changed (+292, -218 lines)

.travis.yml

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ jobs:
 install:
 - source script/cuda.sh
 - source script/conda.sh
+- source script/metis.sh
 - conda create --yes -n test python="${PYTHON_VERSION}"
 - source activate test
 - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes

README.md

Lines changed: 5 additions & 2 deletions
@@ -59,17 +59,20 @@ $ echo $CPATH
 >>> /usr/local/cuda/include:...
 ```
 
+If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partitioning, please download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
+Afterwards, set the environment variable `WITH_METIS=1`.
+
 Then run:
 
 ```
 pip install torch-scatter torch-sparse
 ```
 
-When running in a docker container without nvidia driver, PyTorch needs to evaluate the compute capabilities and may fail.
+When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
 In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
 
 ```
-export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX"
+export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
 ```
 
 ## Functions
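Once the package is built with METIS enabled, partitioning is exposed on `SparseTensor`. The following is a minimal usage sketch, modelled on the new `test/test_metis.py` added in this commit; the random 6x6 matrix is purely illustrative:

```
import torch
from torch_sparse.tensor import SparseTensor

mat = SparseTensor.from_dense(torch.randn(6, 6))

# Partition into two clusters via METIS (k-way partitioning when recursive=False).
mat, partptr, perm = mat.partition(num_parts=2, recursive=False)

print(partptr.numel())  # num_parts + 1 boundary offsets, i.e. 3
print(perm.numel())     # one entry per node, i.e. 6
```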

csrc/cpu/metis_cpu.cpp

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+#include "metis_cpu.h"
+
+#ifdef WITH_METIS
+#include <metis.h>
+#endif
+
+#include "utils.h"
+
+torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
+                            int64_t num_parts, bool recursive) {
+#ifdef WITH_METIS
+  CHECK_CPU(rowptr);
+  CHECK_CPU(col);
+
+  int64_t nvtxs = rowptr.numel() - 1;
+  auto part = torch::empty(nvtxs, rowptr.options());
+
+  auto *xadj = rowptr.data_ptr<int64_t>();
+  auto *adjncy = col.data_ptr<int64_t>();
+  int64_t ncon = 1;
+  int64_t objval = -1;
+  auto part_data = part.data_ptr<int64_t>();
+
+  if (recursive) {
+    METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
+                             &num_parts, NULL, NULL, NULL, &objval, part_data);
+  } else {
+    METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
+                        &num_parts, NULL, NULL, NULL, &objval, part_data);
+  }
+
+  return part;
+#else
+  AT_ERROR("Not compiled with METIS support");
+#endif
+}

csrc/cpu/metis_cpu.h

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/extension.h>
+
+torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
+                            int64_t num_parts, bool recursive);

csrc/metis.cpp

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#include <Python.h>
+#include <torch/script.h>
+
+#include "cpu/metis_cpu.h"
+
+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
+#endif
+
+torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
+                        int64_t num_parts, bool recursive) {
+  if (rowptr.device().is_cuda()) {
+#ifdef WITH_CUDA
+    AT_ERROR("No CUDA version supported");
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return partition_cpu(rowptr, col, num_parts, recursive);
+  }
+}
+
+static auto registry =
+    torch::RegisterOperators().op("torch_sparse::partition", &partition);
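The registered operator can also be called directly once the extension has been built with METIS. A minimal sketch (not part of the commit); the tiny path graph is made up for illustration, and both CSR tensors must be `int64` to match the `data_ptr<int64_t>()` accesses in `metis_cpu.cpp` and the `IDXTYPEWIDTH 64` patch in `script/metis.sh`:

```
import torch
import torch_sparse  # noqa: F401  (importing loads the _metis extension and registers the op)

# Path graph 0 - 1 - 2 in CSR form; torch.tensor(...) defaults to int64 here.
rowptr = torch.tensor([0, 1, 3, 4])  # num_nodes + 1 offsets
col = torch.tensor([1, 0, 2, 1])     # neighbors of nodes 0, 1, 1, 2

# recursive=False dispatches to METIS_PartGraphKway; one part id per node is returned.
part = torch.ops.torch_sparse.partition(rowptr, col, 2, False)
print(part)  # e.g. tensor([0, 0, 1]); the exact assignment is up to METIS
```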

script/metis.sh

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+METIS=metis-5.1.0
+export WITH_METIS=1
+
+wget -nv http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/${METIS}.tar.gz
+tar -xvzf ${METIS}.tar.gz
+cd ${METIS} || exit
+sed -i.bak -e 's/IDXTYPEWIDTH 32/IDXTYPEWIDTH 64/g' include/metis.h
+
+if [ "${TRAVIS_OS_NAME}" != "windows" ]; then
+  make config
+  make
+  sudo make install
+else
+  # Fix GKlib on Windows: https://github.com/jlblancoc/suitesparse-metis-for-windows/issues/6
+  sed -i.bak -e '61,69d' GKlib/gk_arch.h
+
+  cd build || exit
+
+  cmake .. -A x64 # Ensure we are building with x64
+  cmake --build . --config "Release" --target ALL_BUILD
+  cp libmetis/Release/metis.lib /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/lib/x64
+  cp ../include/metis.h /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/include
+
+  cd ..
+fi
+
+cd ..

setup.py

Lines changed: 12 additions & 3 deletions
@@ -16,10 +16,18 @@
 
 BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
 
+WITH_METIS = False
+if os.getenv('WITH_METIS', '0') == '1':
+    WITH_METIS = True
+
 
 def get_extensions():
     Extension = CppExtension
     define_macros = []
+    libraries = []
+    if WITH_METIS:
+        define_macros += [('WITH_METIS', None)]
+        libraries += ['metis']
     extra_compile_args = {'cxx': []}
     extra_link_args = []
 
@@ -32,9 +40,9 @@ def get_extensions():
         extra_compile_args['nvcc'] = nvcc_flags
 
         if sys.platform == 'win32':
-            extra_link_args = ['cusparse.lib']
+            extra_link_args += ['cusparse.lib']
         else:
-            extra_link_args = ['-lcusparse', '-l', 'cusparse']
+            extra_link_args += ['-lcusparse', '-l', 'cusparse']
 
     extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
@@ -59,6 +67,7 @@ def get_extensions():
             define_macros=define_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
+            libraries=libraries,
         )
         extensions += [extension]
 
@@ -71,7 +80,7 @@ def get_extensions():
 
 setup(
     name='torch_sparse',
-    version='0.5.1',
+    version='0.6.0',
     author='Matthias Fey',
     author_email='[email protected]',
     url='https://github.com/rusty1s/pytorch_sparse',

test/test_metis.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import pytest
+import torch
+from torch_sparse.tensor import SparseTensor
+
+from .utils import devices
+
+
+@pytest.mark.parametrize('device', devices)
+def test_metis(device):
+    mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
+    mat, partptr, perm = mat.partition(num_parts=2, recursive=False)
+    assert partptr.numel() == 3
+    assert perm.numel() == 6
+
+    mat, partptr, perm = mat.partition(num_parts=2, recursive=True)
+    assert partptr.numel() == 3
+    assert perm.numel() == 6

test/test_permute.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import pytest
+import torch
+from torch_sparse.tensor import SparseTensor
+
+from .utils import devices, tensor
+
+
+@pytest.mark.parametrize('device', devices)
+def test_permute(device):
+    row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
+    value = tensor([1, 2, 3, 4, 5], torch.float, device)
+    adj = SparseTensor(row=row, col=col, value=value)
+
+    row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
+    assert row.tolist() == [0, 1, 1, 2, 2]
+    assert col.tolist() == [1, 0, 1, 0, 2]
+    assert value.tolist() == [3, 2, 1, 4, 5]

test/test_storage.py

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@ def test_utility(dtype, device):
     storage = storage.set_value(value, layout='coo')
     assert storage.value().tolist() == [1, 2, 3, 4]
 
-    storage = storage.sparse_resize([3, 3])
-    assert storage.sparse_sizes() == [3, 3]
+    storage = storage.sparse_resize((3, 3))
+    assert storage.sparse_sizes() == (3, 3)
 
     new_storage = storage.copy()
     assert new_storage != storage

torch_sparse/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -3,11 +3,13 @@
 
 import torch
 
-__version__ = '0.5.1'
+__version__ = '0.6.0'
 expected_torch_version = (1, 4)
 
 try:
-    for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']:
+    for library in [
+            '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
+    ]:
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)
 except OSError as e:
@@ -45,12 +47,14 @@
 from .select import select  # noqa
 from .index_select import index_select, index_select_nnz  # noqa
 from .masked_select import masked_select, masked_select_nnz  # noqa
+from .permute import permute  # noqa
 from .diag import remove_diag, set_diag, fill_diag  # noqa
 from .add import add, add_, add_nnz, add_nnz_  # noqa
 from .mul import mul, mul_, mul_nnz, mul_nnz_  # noqa
 from .reduce import sum, mean, min, max  # noqa
 from .matmul import matmul  # noqa
 from .cat import cat, cat_diag  # noqa
+from .metis import partition  # noqa
 
 from .convert import to_torch_sparse, from_torch_sparse  # noqa
 from .convert import to_scipy, from_scipy  # noqa
@@ -71,6 +75,7 @@
     'index_select_nnz',
     'masked_select',
     'masked_select_nnz',
+    'permute',
     'remove_diag',
     'set_diag',
     'fill_diag',
@@ -89,6 +94,7 @@
     'matmul',
     'cat',
     'cat_diag',
+    'partition',
     'to_torch_sparse',
     'from_torch_sparse',
     'to_scipy',

torch_sparse/add.py

Lines changed: 0 additions & 4 deletions
@@ -5,7 +5,6 @@
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     rowptr, col, value = src.csr()
     if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value(value, layout='coo')
 
 
-@torch.jit.script
 def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     rowptr, col, value = src.csr()
     if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value_(value, layout='coo')
 
 
-@torch.jit.script
 def add_nnz(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
     value = src.storage.value()
@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
     return src.set_value(value, layout=layout)
 
 
-@torch.jit.script
 def add_nnz_(src: SparseTensor, other: torch.Tensor,
              layout: Optional[str] = None) -> SparseTensor:
     value = src.storage.value()

torch_sparse/cat.py

Lines changed: 0 additions & 2 deletions
@@ -5,7 +5,6 @@
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
     assert len(tensors) > 0
     if dim < 0:
@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
         '[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')
 
 
-@torch.jit.script
 def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
     assert len(tensors) > 0
 

torch_sparse/coalesce.py

Lines changed: 1 addition & 1 deletion
@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"):
     """
 
     storage = SparseStorage(row=index[0], col=index[1], value=value,
-                            sparse_sizes=torch.Size([m, n]), is_sorted=False)
+                            sparse_sizes=(m, n), is_sorted=False)
     storage = storage.coalesce(reduce=op)
     return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
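For context (not part of the diff), a minimal sketch of typical `coalesce` usage; the toy `index`/`value` tensors are invented for the example, and the two entries at position (0, 1) are summed by the default `op="add"`:

```
import torch
from torch_sparse import coalesce

# Duplicate entry at (0, 1): values 1 and 2 are reduced with op="add".
index = torch.tensor([[0, 0, 1],
                      [1, 1, 0]])
value = torch.tensor([1.0, 2.0, 3.0])

index, value = coalesce(index, value, m=2, n=2)
print(index)  # tensor([[0, 1], [1, 0]])
print(value)  # tensor([3., 3.])
```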

torch_sparse/convert.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 
 def to_torch_sparse(index, value, m, n):
-    return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
+    return torch.sparse_coo_tensor(index.detach(), value, (m, n))
 
 
 def from_torch_sparse(A):
