Skip to content

Commit c9fef62

Browse files
berkekisin, berke.kisin, toenshoff, pre-commit-ci[bot], and wsad1
authored
Fix HGTConv edge_type_vec construction (pyg-team#7194)
This PR fixes the utility function https://github.com/pyg-team/pytorch_geometric/blob/3d4836bc24dbb1b180f29cbbbdbcd18b94116dd7/torch_geometric/nn/conv/hgt_conv.py#L123, which constructs the `type_vec` of edges incorrectly and also crashes if some edge types are not present in the current `edge_index_dict`. Consider the following scenario:

```python
# N=2, D=2, H=2 (2 nodes, head_dim 2, 2 heads)
k = [[0, 0, 1, 1], [2, 2, 3, 3]]
```

After calling this line: https://github.com/pyg-team/pytorch_geometric/blob/3d4836bc24dbb1b180f29cbbbdbcd18b94116dd7/torch_geometric/nn/conv/hgt_conv.py#L141 the matrix `k` looks like this:

```python
k = [[0, 0], [1, 1], [2, 2], [3, 3]]
# the type vec should look like this
type_vec = [0, 1, 0, 1]
# but with the current implementation it would look like this
type_vec = [0, 0, 1, 1]
```

After the reshape the attention heads are interleaved, but the type vector that is currently constructed is sorted. We fixed this issue by constructing an interleaved type vec. Alternatively, we can transpose `k` before the reshape to ensure that we can use a sorted type vec. This will also allow us to set `is_sorted=True` for the `HeteroLinear` `k_rel`, which would be more efficient. Also note that we added a test case for a missing edge type in `edge_index_dict`.

---------

Co-authored-by: berke.kisin <[email protected]>
Co-authored-by: toensoff <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinu Sunil <[email protected]>
1 parent edbf8fc commit c9fef62

File tree

4 files changed

+79
-33
lines changed

4 files changed

+79
-33
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929

3030
### Changed
3131

32+
- Fixed `HGTConv` utility function `_construct_src_node_feat` ([#7194](https://github.com/pyg-team/pytorch_geometric/pull/7194))
3233
- Extend dataset summary to create stats for each node/edge type ([#7203](https://github.com/pyg-team/pytorch_geometric/pull/7203))
3334
- Added an optional `batch_size` argument to `avg_pool_x` and `max_pool_x` ([#7216](https://github.com/pyg-team/pytorch_geometric/pull/7216))
3435
- Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187))

test/nn/conv/test_hgt_conv.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,45 @@ def test_hgt_conv_missing_dst_node_type():
193193
out_dict = conv(data.x_dict, data.edge_index_dict)
194194
assert out_dict['author'].size() == (4, 64)
195195
assert out_dict['paper'].size() == (6, 64)
196-
assert out_dict['university'] is None
196+
assert 'university' not in out_dict
197+
198+
199+
def test_hgt_conv_missing_input_node_type():
200+
data = HeteroData()
201+
data['author'].x = torch.randn(4, 16)
202+
data['paper'].x = torch.randn(6, 32)
203+
data['author', 'writes',
204+
'paper'].edge_index = get_random_edge_index(4, 6, 20)
205+
206+
# Some nodes from metadata are missing in data.
207+
# This might happen while using NeighborLoader.
208+
metadata = (['author', 'paper',
209+
'university'], [('author', 'writes', 'paper')])
210+
conv = HGTConv(-1, 64, metadata, heads=1)
211+
212+
out_dict = conv(data.x_dict, data.edge_index_dict)
213+
assert out_dict['paper'].size() == (6, 64)
214+
assert 'university' not in out_dict
215+
216+
217+
def test_hgt_conv_missing_edge_type():
218+
data = HeteroData()
219+
data['author'].x = torch.randn(4, 16)
220+
data['paper'].x = torch.randn(6, 32)
221+
data['university'].x = torch.randn(10, 32)
222+
223+
data['author', 'writes',
224+
'paper'].edge_index = get_random_edge_index(4, 6, 20)
225+
226+
metadata = (['author', 'paper',
227+
'university'], [('author', 'writes', 'paper'),
228+
('university', 'employs', 'author')])
229+
conv = HGTConv(-1, 64, metadata, heads=1)
230+
231+
out_dict = conv(data.x_dict, data.edge_index_dict)
232+
assert out_dict['author'].size() == (4, 64)
233+
assert out_dict['paper'].size() == (6, 64)
234+
assert 'university' not in out_dict
197235

198236

199237
if __name__ == '__main__':

torch_geometric/nn/conv/hgt_conv.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def __init__(
7171
self.heads = heads
7272
self.node_types = metadata[0]
7373
self.edge_types = metadata[1]
74+
self.edge_types_map = {
75+
edge_type: i
76+
for i, edge_type in enumerate(metadata[1])
77+
}
7478

7579
self.dst_node_types = set([key[-1] for key in self.edge_types])
7680

@@ -83,10 +87,10 @@ def __init__(
8387
dim = out_channels // heads
8488
num_types = heads * len(self.edge_types)
8589

86-
self.k_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
87-
bias=False)
88-
self.v_rel = HeteroLinear(dim, dim, num_types, is_sorted=True,
89-
bias=False)
90+
self.k_rel = HeteroLinear(dim, dim, num_types, bias=False,
91+
is_sorted=True)
92+
self.v_rel = HeteroLinear(dim, dim, num_types, bias=False,
93+
is_sorted=True)
9094

9195
self.skip = ParameterDict({
9296
node_type: Parameter(torch.Tensor(1))
@@ -121,36 +125,40 @@ def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:
121125
return torch.cat(outs, dim=0), offset
122126

123127
def _construct_src_node_feat(
124-
self,
125-
k_dict: Dict[str, Tensor],
126-
v_dict: Dict[str, Tensor],
128+
self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor],
129+
edge_index_dict: Dict[EdgeType, Adj]
127130
) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:
128131
"""Constructs the source node representations."""
129-
count = 0
130132
cumsum = 0
133+
num_edge_types = len(self.edge_types)
131134
H, D = self.heads, self.out_channels // self.heads
132135

133136
# Flatten into a single tensor with shape [num_edge_types * heads, D]:
134137
ks: List[Tensor] = []
135138
vs: List[Tensor] = []
136-
type_list: List[int] = []
139+
type_list: List[Tensor] = []
137140
offset: Dict[EdgeType] = {}
138-
for edge_type in self.edge_types:
139-
src, _, _ = edge_type
140-
141-
ks.append(k_dict[src].reshape(-1, D))
142-
vs.append(v_dict[src].reshape(-1, D))
143-
141+
for edge_type in edge_index_dict.keys():
142+
src = edge_type[0]
144143
N = k_dict[src].size(0)
145-
for _ in range(H):
146-
type_list.append(torch.full((N, ), count, dtype=torch.long))
147-
count += 1
148144
offset[edge_type] = cumsum
149145
cumsum += N
150146

151-
type_vec = torch.cat(type_list, dim=0)
152-
k = self.k_rel(torch.cat(ks, dim=0), type_vec).view(-1, H, D)
153-
v = self.v_rel(torch.cat(vs, dim=0), type_vec).view(-1, H, D)
147+
# Construct type_vec for the current edge_type with shape [H, N]:
148+
edge_type_offset = self.edge_types_map[edge_type]
149+
type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat(
150+
1, N) * num_edge_types + edge_type_offset
151+
152+
type_list.append(type_vec)
153+
ks.append(k_dict[src])
154+
vs.append(v_dict[src])
155+
156+
ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D)
157+
vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D)
158+
type_vec = torch.cat(type_list, dim=1).flatten()
159+
160+
k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1)
161+
v = self.v_rel(vs, type_vec).view(H, -1, D).transpose(0, 1)
154162

155163
return k, v, offset
156164

@@ -184,12 +192,14 @@ def forward(
184192
# Compute K, Q, V over node types:
185193
kqv_dict = self.kqv_lin(x_dict)
186194
for key, val in kqv_dict.items():
187-
k_dict[key] = val[:, :F].view(-1, H, D)
188-
q_dict[key] = val[:, F:2 * F].view(-1, H, D)
189-
v_dict[key] = val[:, 2 * F:].view(-1, H, D)
195+
k, q, v = torch.tensor_split(val, 3, dim=1)
196+
k_dict[key] = k.view(-1, H, D)
197+
q_dict[key] = q.view(-1, H, D)
198+
v_dict[key] = v.view(-1, H, D)
190199

191200
q, dst_offset = self._cat(q_dict)
192-
k, v, src_offset = self._construct_src_node_feat(k_dict, v_dict)
201+
k, v, src_offset = self._construct_src_node_feat(
202+
k_dict, v_dict, edge_index_dict)
193203

194204
edge_index, edge_attr = construct_bipartite_edge_index(
195205
edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel)
@@ -200,7 +210,8 @@ def forward(
200210
# Reconstruct output node embeddings dict:
201211
for node_type, start_offset in dst_offset.items():
202212
end_offset = start_offset + q_dict[node_type].size(0)
203-
out_dict[node_type] = out[start_offset:end_offset]
213+
if node_type in self.dst_node_types:
214+
out_dict[node_type] = out[start_offset:end_offset]
204215

205216
# Transform output node embeddings:
206217
a_dict = self.out_lin({
@@ -210,11 +221,7 @@ def forward(
210221

211222
# Iterate over node types:
212223
for node_type, out in out_dict.items():
213-
if node_type not in self.dst_node_types:
214-
out_dict[node_type] = None
215-
continue
216-
else:
217-
out = a_dict[node_type]
224+
out = a_dict[node_type]
218225

219226
if out.size(-1) == x_dict[node_type].size(-1):
220227
alpha = self.skip[node_type].sigmoid()

torch_geometric/nn/dense/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def forward(
387387
biases.append(lin.bias)
388388
biases = None if biases[0] is None else biases
389389
outs = pyg_lib.ops.grouped_matmul(xs, weights, biases)
390-
for key, out in zip(self.lins.keys(), outs):
390+
for key, out in zip(x_dict.keys(), outs):
391391
if key in x_dict:
392392
out_dict[key] = out
393393
else:

0 commit comments

Comments
 (0)