Skip to content

Commit bf2591a

Browse files
authored
[ROCm] fixes ambiguous calls to shfl* where there is no explicit type conversion from c10::Half to __half (#360)
[ROCm] fixes ambiguous calls to `shfl*` where there is no explicit type conversion from `c10::Half` to `__half`
1 parent 05b62b1 commit bf2591a

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

csrc/cuda/utils.cuh

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
77
// Asserts the expression `x` through AT_ASSERTM; the failure message is the
// fixed string "Input mismatch".
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
88

9-
__device__ __inline__ at::Half
10-
__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) {
11-
return __shfl_sync(mask, var.operator __half(), srcLane);
9+
// at::Half overload of __shfl_up_sync (device-only, inline).
// Converts `var` to __half explicitly through at::Half's conversion operator,
// forwards `mask` and `delta` unchanged to __shfl_up_sync, and returns the
// result as at::Half.
// The conversion is spelled out explicitly because, per the commit title,
// shfl* calls without an explicit c10::Half -> __half conversion are ambiguous.
// (The short "NN+" / "NNNN" lines interleaved below are diff-gutter line
// numbers from the commit view, not code.)
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
10+
const at::Half var,
11+
const unsigned int delta) {
12+
return __shfl_up_sync(mask, var.operator __half(), delta);
1213
}
1314

1415
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
@@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
1718
return __shfl_down_sync(mask, var.operator __half(), delta);
1819
}
1920

21+
// at::Half overload of __shfl_sync (device-only, inline).
// Converts `var` to __half explicitly through at::Half's conversion operator,
// forwards `mask` and `delta` unchanged to __shfl_sync, and returns the result
// as at::Half.
// The conversion is spelled out explicitly because, per the commit title,
// shfl* calls without an explicit c10::Half -> __half conversion are ambiguous.
// NOTE(review): the version of this overload removed by the same commit named
// the third parameter `srcLane`; `delta` here presumably still denotes a
// source lane rather than an offset -- confirm and consider restoring the name.
// (The short "NN+" lines interleaved below are diff-gutter line numbers from
// the commit view, not code.)
__device__ __inline__ at::Half __shfl_sync(const unsigned mask,
22+
const at::Half var,
23+
const int delta) {
24+
return __shfl_sync(mask, var.operator __half(), delta);
25+
}
26+
27+
// at::Half overload of the mask-less __shfl_up (device-only, inline).
// Converts `var` to __half explicitly through at::Half's conversion operator,
// forwards `delta` unchanged to __shfl_up, and returns the result as at::Half.
// NOTE(review): mask-less __shfl_up is a legacy intrinsic that is removed for
// Volta+ on CUDA; the commit title targets ROCm, yet this overload sits
// outside the USE_ROCM guard that follows -- confirm it is never instantiated
// on CUDA builds, or move it under the guard.
// (The short "NN+" lines interleaved below are diff-gutter line numbers from
// the commit view, not code.)
__device__ __inline__ at::Half __shfl_up(const at::Half var,
28+
const unsigned int delta) {
29+
return __shfl_up(var.operator __half(), delta);
30+
}
31+
32+
// at::Half overload of the mask-less __shfl_down (device-only, inline).
// Converts `var` to __half explicitly through at::Half's conversion operator,
// forwards `delta` unchanged to __shfl_down, and returns the result as
// at::Half.
// NOTE(review): mask-less __shfl_down is a legacy intrinsic that is removed
// for Volta+ on CUDA; the commit title targets ROCm, yet this overload sits
// outside the USE_ROCM guard that follows -- confirm it is never instantiated
// on CUDA builds, or move it under the guard.
// (The short "NN+" lines interleaved below are diff-gutter line numbers from
// the commit view, not code.)
__device__ __inline__ at::Half __shfl_down(const at::Half var,
33+
const unsigned int delta) {
34+
return __shfl_down(var.operator __half(), delta);
35+
}
36+
37+
// at::Half overload of the mask-less __shfl (device-only, inline).
// Converts `var` to __half explicitly through at::Half's conversion operator,
// forwards `delta` unchanged to __shfl, and returns the result as at::Half.
// NOTE(review): mask-less __shfl is a legacy intrinsic that is removed for
// Volta+ on CUDA; the commit title targets ROCm, yet this overload sits
// outside the USE_ROCM guard that follows -- confirm it is never instantiated
// on CUDA builds, or move it under the guard.
// NOTE(review): `delta` is presumably a source-lane index here (the removed
// __shfl_sync overload called the same position `srcLane`) -- confirm.
// (The short "NN+" lines interleaved below are diff-gutter line numbers from
// the commit view, not code.)
__device__ __inline__ at::Half
38+
__shfl(const at::Half var, const int delta) {
39+
return __shfl(var.operator __half(), delta);
40+
}
41+
2042
#ifdef USE_ROCM
2143
__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
2244
return __ldg(reinterpret_cast<const __half*>(ptr));

0 commit comments

Comments
 (0)