6
6
AT_ASSERTM (x.device().is_cuda(), #x " must be CUDA tensor" )
7
7
#define CHECK_INPUT (x ) AT_ASSERTM(x, " Input mismatch" )
8
8
9
- __device__ __inline__ at::Half
10
- __shfl_sync (const unsigned mask, const at::Half var, const int srcLane) {
11
- return __shfl_sync (mask, var.operator __half (), srcLane);
9
+ __device__ __inline__ at::Half __shfl_up_sync (const unsigned mask,
10
+ const at::Half var,
11
+ const unsigned int delta) {
12
+ return __shfl_up_sync (mask, var.operator __half (), delta);
12
13
}
13
14
14
15
__device__ __inline__ at::Half __shfl_down_sync (const unsigned mask,
@@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
17
18
return __shfl_down_sync (mask, var.operator __half (), delta);
18
19
}
19
20
21
+ __device__ __inline__ at::Half __shfl_sync (const unsigned mask,
22
+ const at::Half var,
23
+ const int delta) {
24
+ return __shfl_sync (mask, var.operator __half (), delta);
25
+ }
26
+
27
+ __device__ __inline__ at::Half __shfl_up (const at::Half var,
28
+ const unsigned int delta) {
29
+ return __shfl_up (var.operator __half (), delta);
30
+ }
31
+
32
+ __device__ __inline__ at::Half __shfl_down (const at::Half var,
33
+ const unsigned int delta) {
34
+ return __shfl_down (var.operator __half (), delta);
35
+ }
36
+
37
+ __device__ __inline__ at::Half
38
+ __shfl (const at::Half var, const int delta) {
39
+ return __shfl (var.operator __half (), delta);
40
+ }
41
+
20
42
#ifdef USE_ROCM
21
43
__device__ __inline__ at::Half __ldg (const at::Half* ptr) {
22
44
return __ldg (reinterpret_cast <const __half*>(ptr));
0 commit comments