Skip to content

Commit bfdca11

Browse files
committed
fix GPU device handling
1 parent c4c6db4 commit bfdca11

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch_sparse/diag.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch import Tensor
5+
56
from torch_sparse.storage import SparseStorage
67
from torch_sparse.tensor import SparseTensor
78

@@ -97,7 +98,7 @@ def get_diag(src: SparseTensor) -> Tensor:
9798
row, col, value = src.coo()
9899

99100
if value is None:
100-
value = torch.ones(row.size(0))
101+
value = torch.ones(row.size(0), device=row.device)
101102

102103
sizes = list(value.size())
103104
sizes[0] = min(src.size(0), src.size(1))

0 commit comments

Comments
 (0)