Skip to content

Commit 85ace25

Browse files
authored
Fix cudaSetDevice for CUDA 12 (#370)
update
1 parent 2d55981 commit 85ace25

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

csrc/cuda/convert_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
2525

2626
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
2727
CHECK_CUDA(ind);
28-
cudaSetDevice(ind.get_device());
28+
c10::cuda::MaybeSetDevice(ind.get_device());
2929

3030
auto out = torch::empty({M + 1}, ind.options());
3131

@@ -55,7 +55,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
5555

5656
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
5757
CHECK_CUDA(ptr);
58-
cudaSetDevice(ptr.get_device());
58+
c10::cuda::MaybeSetDevice(ptr.get_device());
5959

6060
auto out = torch::empty({E}, ptr.options());
6161
auto ptr_data = ptr.data_ptr<int64_t>();

csrc/cuda/diag_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
4343
int64_t M, int64_t N, int64_t k) {
4444
CHECK_CUDA(row);
4545
CHECK_CUDA(col);
46-
cudaSetDevice(row.get_device());
46+
c10::cuda::MaybeSetDevice(row.get_device());
4747

4848
auto E = row.size(0);
4949
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);

csrc/cuda/rw_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
3333
CHECK_CUDA(rowptr);
3434
CHECK_CUDA(col);
3535
CHECK_CUDA(start);
36-
cudaSetDevice(rowptr.get_device());
36+
c10::cuda::MaybeSetDevice(rowptr.get_device());
3737

3838
CHECK_INPUT(rowptr.dim() == 1);
3939
CHECK_INPUT(col.dim() == 1);

csrc/cuda/spmm_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
9999
if (optional_value.has_value())
100100
CHECK_CUDA(optional_value.value());
101101
CHECK_CUDA(mat);
102-
cudaSetDevice(rowptr.get_device());
102+
c10::cuda::MaybeSetDevice(rowptr.get_device());
103103

104104
CHECK_INPUT(rowptr.dim() == 1);
105105
CHECK_INPUT(col.dim() == 1);
@@ -201,7 +201,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
201201
CHECK_CUDA(col);
202202
CHECK_CUDA(mat);
203203
CHECK_CUDA(grad);
204-
cudaSetDevice(row.get_device());
204+
c10::cuda::MaybeSetDevice(row.get_device());
205205

206206
mat = mat.contiguous();
207207
grad = grad.contiguous();

0 commit comments

Comments
 (0)