Skip to content

Commit 8c6df2a

Browse files
authored
Merge pull request #472 from twitter-forks/indexlinear-fix
Fixing incorrect normalized values in IndexLinear during training
Merge commit 8c6df2a with 2 parents: f2b2286 and 36d1b76

File tree

2 files changed: +31 additions, -50 deletions

lib/THCUNN/IndexLinear.cu

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,11 @@ const long NNZ_PER_BLOCK_MAX = 1024;
1515
#define clamp(a, low, high) max(min((a), (high)), (low))
1616
#endif
1717

18-
#ifndef ATOMIC_REAL_MINMAX
19-
#define ATOMIC_REAL_MINMAX(func) \
20-
__device__ void atomic_##func(double *address, double val) { \
21-
unsigned long long int* address_as_ull = (unsigned long long int*)address; \
22-
unsigned long long int old = *address_as_ull; \
23-
unsigned long long int assumed; \
24-
do { \
25-
assumed = old; \
26-
old = atomicCAS(address_as_ull, assumed, \
27-
__double_as_longlong(func(val, __longlong_as_double(assumed)))); \
28-
} while (assumed != old); \
29-
} \
30-
__device__ void atomic_##func(float *address, float val) { \
31-
int* address_as_int = (int*)address; \
32-
int old = *address_as_int; \
33-
int assumed; \
34-
do { \
35-
assumed = old; \
36-
old = atomicCAS(address_as_int, assumed, \
37-
__float_as_int(func(val, __int_as_float(assumed)))); \
38-
} while (assumed != old); \
39-
} \
40-
41-
ATOMIC_REAL_MINMAX(max)
42-
ATOMIC_REAL_MINMAX(min)
43-
#endif
18+
__device__ double atomicExch(double *address, double val) {
19+
unsigned long long int* address_as_ull = (unsigned long long int*)address;
20+
unsigned long long res = atomicExch(address_as_ull, __double_as_longlong(val));
21+
return __longlong_as_double(res);
22+
}
4423

4524
template<typename Ty, bool train>
4625
__global__ static
@@ -113,14 +92,16 @@ void updateOutput(
11392
Ty *nWeightCurr = nWeight + nWeightOffset;
11493
if (train) {
11594
Ty absVal = fabs(val);
116-
Ty maxVal = nWeight[key * weightStride + 0];
95+
Ty maxVal = nWeightCurr[0];
11796
if (absVal > maxVal) {
11897
// Updating maxVal and invMaxVal. Go hogwild!
119-
atomic_max(nWeightCurr + 0, absVal);
120-
atomic_min(nWeightCurr + 1, 1.0/absVal);
98+
Ty invAbsVal = 1.0 / absVal;
99+
atomicExch(nWeightCurr + 0, absVal);
100+
atomicExch(nWeightCurr + 1, invAbsVal);
121101
}
122-
val = val * nWeightCurr[1] + nWeightCurr[3];
102+
val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
123103
normalizedValues[id + tid] = val;
104+
nWeightCurr[2] = 1;
124105
} else {
125106
val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
126107
}

test.lua

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5939,16 +5939,16 @@ function cunntest.ModuleConversionFunctions()
59395939
end
59405940

59415941
function cunntest.IndexLinear()
5942-
isize = 500E3
5943-
osize = 250
5944-
weightDecay = 0.01
5945-
nnzMin = 1000
5946-
nnzMax = 1500
5947-
idxMin = 1
5948-
idxMax = isize
5949-
batchSize = 128
5950-
lr = 0.01
5951-
ntests = 1
5942+
local isize = 500E3
5943+
local osize = 250
5944+
local weightDecay = 0.01
5945+
local nnzMin = 1000
5946+
local nnzMax = 1500
5947+
local idxMin = 1
5948+
local idxMax = isize
5949+
local batchSize = 128
5950+
local lr = 0.01
5951+
local ntests = 1
59525952

59535953
local errNorm = function(a, b)
59545954
return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max())
@@ -6101,16 +6101,16 @@ function cunntest.IndexLinear()
61016101
end
61026102

61036103
function cunntest.IndexLinearMaxNorm()
6104-
isize = 500E3
6105-
osize = 250
6106-
weightDecay = 0
6107-
nnzMin = 1000
6108-
nnzMax = 1500
6109-
idxMin = 1
6110-
idxMax = isize
6111-
batchSize = 128
6112-
lr = 0.01
6113-
ntests = 1
6104+
local isize = 500E3
6105+
local osize = 250
6106+
local weightDecay = 0
6107+
local nnzMin = 1000
6108+
local nnzMax = 1500
6109+
local idxMin = 1
6110+
local idxMax = isize
6111+
local batchSize = 128
6112+
local lr = 0.01
6113+
local ntests = 1
61146114

61156115
local errNorm = function(a, b)
61166116
return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max())

Comments (0)