Skip to content

Commit e3e63d6

Browse files
authored
Add segment_matmul micro-benchmark (pyg-team#7215)
1 parent 50c29f7 commit e3e63d6

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

test/nn/dense/test_linear.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import copy
2+
import warnings
3+
from typing import List
24

35
import pytest
46
import torch
7+
from torch import Tensor
58
from torch.nn import Linear as PTLinear
69
from torch.nn.parameter import UninitializedParameter
710

811
import torch_geometric.typing
912
from torch_geometric.nn import HeteroDictLinear, HeteroLinear, Linear
13+
from torch_geometric.profile import benchmark
1014
from torch_geometric.testing import withCUDA, withPackage
15+
from torch_geometric.typing import pyg_lib
1116

1217
weight_inits = ['glorot', 'kaiming_uniform', None]
1318
bias_inits = ['zeros', None]
@@ -216,3 +221,77 @@ def test_hetero_linear_sort(type_vec, device):
216221
node_type = int(type_vec[i])
217222
expected = x[i] @ lin.weight[node_type] + lin.bias[node_type]
218223
assert torch.allclose(out[i], expected, atol=1e-3)
224+
225+
226+
if __name__ == '__main__':
227+
import argparse
228+
229+
import dgl
230+
231+
warnings.filterwarnings('ignore', '.*API of nested tensors.*')
232+
warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')
233+
234+
parser = argparse.ArgumentParser()
235+
parser.add_argument('--device', type=str, default='cuda')
236+
parser.add_argument('--backward', action='store_true')
237+
args = parser.parse_args()
238+
239+
torch.manual_seed(12345)
240+
241+
def get_xs(mean: float, std: float, num_types: int,
242+
channels: int) -> List[Tensor]:
243+
num_nodes_list = torch.normal(
244+
mean=torch.tensor([mean] * num_types, dtype=torch.float),
245+
std=torch.tensor([std] * num_types, dtype=torch.float),
246+
).round().to(torch.long).tolist()
247+
248+
return [
249+
torch.randn(num_nodes, channels, device=args.device)
250+
for num_nodes in num_nodes_list
251+
]
252+
253+
def sequential(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:
254+
return [x @ weight for x, weight in zip(xs, weights)]
255+
256+
def nested(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:
257+
x = torch.nested.nested_tensor(xs)
258+
weight = torch.nested.nested_tensor(weights)
259+
return list(torch.matmul(x, weight).unbind(0))
260+
261+
def grouped(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor:
262+
return pyg_lib.ops.segment_matmul(x, ptr, weight)
263+
264+
def padded(x: Tensor, weight: Tensor) -> Tensor:
265+
return torch.matmul(x, weight)
266+
267+
def dgl_mm(x: Tensor, count: Tensor, weight: Tensor) -> Tensor:
268+
return dgl.ops.segment_mm(x, weight, count)
269+
270+
num_nodes, channels = 1_000_000, 64
271+
272+
for num_types in [3, 5, 10, 50, 100, 200, 500, 1000]:
273+
print(f'Number of types: {num_types}')
274+
mean = num_nodes // num_types
275+
std = mean // 4
276+
277+
xs = get_xs(mean, std, num_types, channels)
278+
count = torch.tensor([x.size(0) for x in xs])
279+
ptr = torch.tensor([0] + [x.size(0) for x in xs]).cumsum(0)
280+
x = torch.cat(xs, dim=0)
281+
padded_x = torch.nested.nested_tensor(xs).to_padded_tensor(padding=0.0)
282+
weight = torch.randn(num_types, channels, channels, device=args.device)
283+
weights = list(weight.unbind(0))
284+
285+
benchmark(
286+
funcs=[sequential, grouped, padded, dgl_mm],
287+
func_names=['Sequential', 'Grouped', 'Padded', 'DGL'],
288+
args=[
289+
(xs, weights),
290+
(x, ptr, weight),
291+
(padded_x, weight),
292+
(x, count, weight),
293+
],
294+
num_steps=50 if args.device == 'cpu' else 500,
295+
num_warmups=10 if args.device == 'cpu' else 100,
296+
backward=args.backward,
297+
)

torch_geometric/profile/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ def benchmark(
8383
t_forward += time.perf_counter() - t_start
8484

8585
if backward:
86-
if isinstance(out, dict): # TODO Generalize this logic.
86+
# TODO Generalize this logic. This is also a bit unfair as the
87+
# concatenation leads to incorrectly measured backward speeds.
88+
if isinstance(out, (tuple, list)):
89+
out = torch.cat(out, dim=0)
90+
elif isinstance(out, dict):
8791
out = torch.cat(list(out.values()), dim=0)
8892

8993
out_grad = torch.randn_like(out)

0 commit comments

Comments
 (0)