@@ -15,32 +15,11 @@ const long NNZ_PER_BLOCK_MAX = 1024;
1515#define clamp (a, low, high ) max(min((a), (high)), (low))
1616#endif
1717
// (Lines marked '-' below are removed by this diff.)
// ATOMIC_REAL_MINMAX(func) expands to two __device__ overloads named atomic_<func>:
// one for double* and one for float*. Each overload reinterprets the target address
// as an integer of the same width (unsigned long long / int), then loops:
//   1. remember the last observed bit pattern in `assumed`,
//   2. compute func(val, <current value>) and try to install it with atomicCAS,
//   3. retry while the value returned by atomicCAS differs from `assumed`.
// The macro is instantiated for `max` and `min`, yielding atomic_max and atomic_min.
// NOTE(review): the blank before '(' in "ATOMIC_REAL_MINMAX (func )" looks like an
// extraction artifact of this listing rather than the real source — TODO confirm.
18- #ifndef ATOMIC_REAL_MINMAX
19- #define ATOMIC_REAL_MINMAX (func ) \
20- __device__ void atomic_##func(double *address, double val) { \
21- unsigned long long int * address_as_ull = (unsigned long long int *)address; \
22- unsigned long long int old = *address_as_ull; \
23- unsigned long long int assumed; \
24- do { \
25- assumed = old; \
26- old = atomicCAS (address_as_ull, assumed, \
27- __double_as_longlong (func (val, __longlong_as_double (assumed)))); \
28- } while (assumed != old); \
29- } \
30- __device__ void atomic_##func(float *address, float val) { \
31- int * address_as_int = (int *)address; \
32- int old = *address_as_int; \
33- int assumed; \
34- do { \
35- assumed = old; \
36- old = atomicCAS (address_as_int, assumed, \
37- __float_as_int (func (val, __int_as_float (assumed)))); \
38- } while (assumed != old); \
39- } \
40-
41- ATOMIC_REAL_MINMAX (max)
42- ATOMIC_REAL_MINMAX(min)
43- #endif
// (Lines marked '+' below are added by this diff.)
// Overload of atomicExch for double: atomically stores `val` at `address` and
// returns the value that was there before.
// The address is reinterpreted as unsigned long long so the 64-bit integer
// atomicExch can be used; `val` is passed as its raw bit pattern and the returned
// bits are converted back to double.
// NOTE(review): only a double overload is defined here; presumably a float
// instantiation of the caller relies on the built-in atomicExch(float*, float) — confirm.
18+ __device__ double atomicExch (double *address, double val) {
19+ unsigned long long int * address_as_ull = (unsigned long long int *)address;
20+ unsigned long long res = atomicExch (address_as_ull, __double_as_longlong (val));
21+ return __longlong_as_double (res);
22+ }
4423
4524template <typename Ty, bool train>
4625__global__ static
@@ -113,14 +92,16 @@ void updateOutput(
11392 Ty *nWeightCurr = nWeight + nWeightOffset;
11493 if (train) {
11594 Ty absVal = fabs (val);
116- Ty maxVal = nWeight[key * weightStride + 0 ];
95+ Ty maxVal = nWeightCurr[ 0 ];
11796 if (absVal > maxVal) {
11897 // Updating maxVal and invMaxVal. Go hogwild!
119- atomic_max (nWeightCurr + 0 , absVal);
120- atomic_min (nWeightCurr + 1 , 1.0 /absVal);
98+ Ty invAbsVal = 1.0 / absVal;
99+ atomicExch (nWeightCurr + 0 , absVal);
100+ atomicExch (nWeightCurr + 1 , invAbsVal);
121101 }
122- val = val * nWeightCurr[1 ] + nWeightCurr[3 ];
102+ val = clamp ( val * nWeightCurr[1 ], - 1.0 , 1.0 ) + nWeightCurr[3 ];
123103 normalizedValues[id + tid] = val;
104+ nWeightCurr[2 ] = 1 ;
124105 } else {
125106 val = clamp (val * nWeightCurr[1 ], -1.0 , 1.0 ) + nWeightCurr[3 ];
126107 }
0 commit comments