|
1 | 1 | import copy
|
| 2 | +import warnings |
| 3 | +from typing import List |
2 | 4 |
|
3 | 5 | import pytest
|
4 | 6 | import torch
|
| 7 | +from torch import Tensor |
5 | 8 | from torch.nn import Linear as PTLinear
|
6 | 9 | from torch.nn.parameter import UninitializedParameter
|
7 | 10 |
|
8 | 11 | import torch_geometric.typing
|
9 | 12 | from torch_geometric.nn import HeteroDictLinear, HeteroLinear, Linear
|
| 13 | +from torch_geometric.profile import benchmark |
10 | 14 | from torch_geometric.testing import withCUDA, withPackage
|
| 15 | +from torch_geometric.typing import pyg_lib |
11 | 16 |
|
# Candidate initializer options for weights and biases.
# NOTE(review): presumably consumed by parametrized tests outside this
# view -- confirm against the rest of the file.
weight_inits = [
    'glorot',
    'kaiming_uniform',
    None,
]
bias_inits = [
    'zeros',
    None,
]
|
@@ -216,3 +221,77 @@ def test_hetero_linear_sort(type_vec, device):
|
216 | 221 | node_type = int(type_vec[i])
|
217 | 222 | expected = x[i] @ lin.weight[node_type] + lin.bias[node_type]
|
218 | 223 | assert torch.allclose(out[i], expected, atol=1e-3)
|
| 224 | + |
| 225 | + |
def get_xs(mean: float, std: float, num_types: int, channels: int,
           device: str = 'cpu') -> List[Tensor]:
    """Create one random feature matrix per type.

    The number of rows of each matrix is drawn from a normal distribution
    with the given :obj:`mean` and :obj:`std`, rounded to an integer.

    Args:
        mean: Mean of the per-type row count.
        std: Standard deviation of the per-type row count.
        num_types: Number of matrices to create.
        channels: Number of columns of every matrix.
        device: Device on which the matrices are allocated.

    Returns:
        A list of :obj:`num_types` tensors of shape
        :obj:`[num_nodes_i, channels]`.
    """
    # A normal draw may be negative, which `torch.randn` rejects as a size.
    # Clamp so that every type keeps at least one row.
    num_nodes_list = torch.normal(
        mean=torch.full((num_types, ), float(mean)),
        std=torch.full((num_types, ), float(std)),
    ).round().to(torch.long).clamp(min=1).tolist()

    return [
        torch.randn(num_nodes, channels, device=device)
        for num_nodes in num_nodes_list
    ]


def sequential(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:
    """Multiply each matrix with its own weight, one matmul per type."""
    return [x @ weight for x, weight in zip(xs, weights)]


def nested(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:
    """Multiply per-type matrices via nested tensors.

    NOTE: This variant is not part of the set benchmarked by :func:`main`.
    """
    x = torch.nested.nested_tensor(xs)
    weight = torch.nested.nested_tensor(weights)
    return list(torch.matmul(x, weight).unbind(0))


def grouped(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor:
    """Multiply all types in one :obj:`pyg_lib.ops.segment_matmul` call.

    :func:`main` passes the row-wise concatenation of all per-type matrices
    as :obj:`x` and the cumulative row offsets (with a leading zero) as
    :obj:`ptr`.
    """
    return pyg_lib.ops.segment_matmul(x, ptr, weight)


def padded(x: Tensor, weight: Tensor) -> Tensor:
    """Multiply zero-padded per-type matrices with a single batched matmul.

    :func:`main` passes the per-type matrices padded to a common number of
    rows as :obj:`x` and the stacked per-type weights as :obj:`weight`.
    """
    return torch.matmul(x, weight)


def main() -> None:
    """Benchmark different ways of applying one weight matrix per type.

    Command line arguments:
        --device: Device on which the inputs are allocated (default
            :obj:`'cuda'`).
        --backward: If set, forwarded as :obj:`backward=True` to
            :obj:`torch_geometric.profile.benchmark`.
    """
    import argparse

    # DGL is only one of the compared baselines, so the benchmark should
    # still run without it. `OSError` covers a failing shared-library load.
    try:
        import dgl
    except (ImportError, OSError):
        dgl = None
        print('DGL is not available; skipping the DGL baseline')

    warnings.filterwarnings('ignore', '.*API of nested tensors.*')
    warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--backward', action='store_true')
    args = parser.parse_args()

    torch.manual_seed(12345)

    def dgl_mm(x: Tensor, count: Tensor, weight: Tensor) -> Tensor:
        # `count` holds the number of rows of `x` that belong to each type.
        return dgl.ops.segment_mm(x, weight, count)

    num_nodes, channels = 1_000_000, 64
    on_cpu = args.device == 'cpu'

    for num_types in [3, 5, 10, 50, 100, 200, 500, 1000]:
        print(f'Number of types: {num_types}')
        mean = num_nodes // num_types
        std = mean // 4

        xs = get_xs(mean, std, num_types, channels, device=args.device)
        sizes = [x.size(0) for x in xs]
        count = torch.tensor(sizes)
        ptr = torch.tensor([0] + sizes).cumsum(0)
        x = torch.cat(xs, dim=0)
        padded_x = torch.nested.nested_tensor(xs).to_padded_tensor(
            padding=0.0)
        weight = torch.randn(num_types, channels, channels,
                             device=args.device)
        weights = list(weight.unbind(0))

        funcs = [sequential, grouped, padded]
        func_names = ['Sequential', 'Grouped', 'Padded']
        func_args = [
            (xs, weights),
            (x, ptr, weight),
            (padded_x, weight),
        ]
        if dgl is not None:
            funcs.append(dgl_mm)
            func_names.append('DGL')
            func_args.append((x, count, weight))

        benchmark(
            funcs=funcs,
            func_names=func_names,
            args=func_args,
            # Use fewer iterations on CPU than on other devices.
            num_steps=50 if on_cpu else 500,
            num_warmups=10 if on_cpu else 100,
            backward=args.backward,
        )


if __name__ == '__main__':
    main()
0 commit comments