You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fix HGTConvedge_type_vec construction (pyg-team#7194)
This pr fixes the utility function
https://github.com/pyg-team/pytorch_geometric/blob/3d4836bc24dbb1b180f29cbbbdbcd18b94116dd7/torch_geometric/nn/conv/hgt_conv.py#L123,
which constructs the type_vec of edges wrong and also crashes if some
edge_types are not present in the current edge_index_dict.
Consider the following scenario:
```python
# N =2, D=2, H=2 (2 nodes, head_dim 2, 2 heads)
k = [
[0,0,1,1],
[2,2,3,3]
]
```
after calling this line:
https://github.com/pyg-team/pytorch_geometric/blob/3d4836bc24dbb1b180f29cbbbdbcd18b94116dd7/torch_geometric/nn/conv/hgt_conv.py#L141
the matrix k looks like this:
```python
k= [
[0,0],
[1,1],
[2,2],
[3,3]]
# the type vec should look like this
type_vec = [0,1,0,1]
# but at current implementation it would look like this
type_vec = [0,0,1,1]
```
After the reshape the attention heads are interleaved but the type
vector that is currently constructed is sorted.
We fixed this issue by constructing interleaved type vec. Alternatively
we can transpose the k before the reshape to ensure that we can use
sorted type vec. This will also allow us to set `is_sorted=True` for the
heterolinear `k_rel` which would be more efficient.
Also note that we added a test case for missing edge type in
edge_index_dict.
---------
Co-authored-by: berke.kisin <[email protected]>
Co-authored-by: toensoff <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinu Sunil <[email protected]>
0 commit comments