
Commit 7e62971

Merge commit '71ccedbc6c4e460d38c794737bba780e7673e888'
2 parents: a7d9875 + 71ccedb

7 files changed: +902 / -202 lines

torch/lib/THCUNN/ClassNLLCriterion.cu (+30 / -18)
@@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
     THCIndex_t *target,
     Dtype *weights,
     int size_average,
-    int n_classes) {
+    int n_classes,
+    long ignore_index) {
   assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
 
   // TODO: T4951791 Reuse code between updateOutput_kernel1 and
   // updateOutput_kernel.
 
   int t = (int)*target - TH_INDEX_BASE;
-  assert(t >= 0 && t < n_classes);
-  Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
-  *output = -cur_weight * input[t];
-  *total_weight = cur_weight;
-  if (size_average && *total_weight > 0) {
-    *output /= *total_weight;
+  if (t != ignore_index) {
+    assert(t >= 0 && t < n_classes);
+    Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+    *output = -cur_weight * input[t];
+    *total_weight = cur_weight;
+    if (size_average && *total_weight > 0) {
+      *output /= *total_weight;
+    }
   }
 }
 
@@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
     int size_average,
     int nframe,
     int ndim,
-    int n_classes) {
+    int n_classes,
+    long ignore_index) {
   __shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS];
   int i, t;
   Dtype cur_weight;
@@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
   acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
   for (i = threadIdx.x; i < nframe; i += NTHREADS) {
     t = target[i] - TH_INDEX_BASE;
-    assert(t >= 0 && t < n_classes);
-    cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
-    shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
-    acc_weight[threadIdx.x] += cur_weight;
+    if (t != ignore_index) {
+      assert(t >= 0 && t < n_classes);
+      cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+      shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+      acc_weight[threadIdx.x] += cur_weight;
+    }
   }
   __syncthreads();
 
@@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
   THCIndex_t* target,
   Dtype* total_weight,
   int size_average,
-  int n_classes)
+  int n_classes,
+  long ignore_index)
 {
   if (*total_weight <= 0) {
     return;
   }
   Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
   int t = (int)*target - TH_INDEX_BASE;
-  assert(t >= 0 && t < n_classes);
-  gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+  if (t != ignore_index) {
+    assert(t >= 0 && t < n_classes);
+    gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+  }
 }
 
 template <typename Dtype>
@@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
   int size_average,
   int nframe,
   int ndim,
-  int n_classes)
+  int n_classes,
+  long ignore_index)
 {
   if (*total_weight <= 0) {
     return;
@@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
 
   for (i = threadIdx.x; i < nframe; i += NTHREADS) {
     t = (int)target[i] - TH_INDEX_BASE;
-    assert(t >= 0 && t < n_classes);
-    gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+    if (t != ignore_index) {
+      assert(t >= 0 && t < n_classes);
+      gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+    }
   }
 }
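
For reference, the kernels above give ignore_index a simple semantics: a target equal to ignore_index contributes neither to the loss nor to the weight normalizer. The following plain CPU sketch illustrates that behavior; the helper name nll_loss_reference and its flat-array interface are invented for this note and are not part of the commit.

// Illustrative CPU sketch of the ignore_index semantics (hypothetical helper).
static double nll_loss_reference(const double *input, const long *target,
                                 const double *weights, int nframe, int ndim,
                                 long ignore_index, bool size_average) {
  double loss = 0.0, total_weight = 0.0;
  for (int i = 0; i < nframe; ++i) {
    long t = target[i];                      // already 0-based here
    if (t == ignore_index) continue;         // skipped class: no loss, no weight
    double w = weights ? weights[t] : 1.0;   // optional per-class weight
    loss -= w * input[i * ndim + t];         // negative log-likelihood term
    total_weight += w;
  }
  if (size_average && total_weight > 0) loss /= total_weight;
  return loss;
}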

torch/lib/THCUNN/IndexLinear.cu (+11 / -30)
@@ -15,32 +15,11 @@ const long NNZ_PER_BLOCK_MAX = 1024;
 #define clamp(a, low, high) max(min((a), (high)), (low))
 #endif
 
-#ifndef ATOMIC_REAL_MINMAX
-#define ATOMIC_REAL_MINMAX(func) \
-  __device__ void atomic_##func(double *address, double val) { \
-    unsigned long long int* address_as_ull = (unsigned long long int*)address; \
-    unsigned long long int old = *address_as_ull; \
-    unsigned long long int assumed; \
-    do { \
-      assumed = old; \
-      old = atomicCAS(address_as_ull, assumed, \
-                      __double_as_longlong(func(val, __longlong_as_double(assumed)))); \
-    } while (assumed != old); \
-  } \
-  __device__ void atomic_##func(float *address, float val) { \
-    int* address_as_int = (int*)address; \
-    int old = *address_as_int; \
-    int assumed; \
-    do { \
-      assumed = old; \
-      old = atomicCAS(address_as_int, assumed, \
-                      __float_as_int(func(val, __int_as_float(assumed)))); \
-    } while (assumed != old); \
-  } \
-
-ATOMIC_REAL_MINMAX(max)
-ATOMIC_REAL_MINMAX(min)
-#endif
+__device__ double atomicExch(double *address, double val) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;
+  unsigned long long res = atomicExch(address_as_ull, __double_as_longlong(val));
+  return __longlong_as_double(res);
+}
 
 template<typename Ty, bool train>
 __global__ static
@@ -113,14 +92,16 @@ void updateOutput(
       Ty *nWeightCurr = nWeight + nWeightOffset;
       if (train) {
         Ty absVal = fabs(val);
-        Ty maxVal = nWeight[key * weightStride + 0];
+        Ty maxVal = nWeightCurr[0];
         if (absVal > maxVal) {
           // Updating maxVal and invMaxVal. Go hogwild!
-          atomic_max(nWeightCurr + 0, absVal);
-          atomic_min(nWeightCurr + 1, 1.0/absVal);
+          Ty invAbsVal = 1.0 / absVal;
+          atomicExch(nWeightCurr + 0, absVal);
+          atomicExch(nWeightCurr + 1, invAbsVal);
         }
-        val = val * nWeightCurr[1] + nWeightCurr[3];
+        val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
         normalizedValues[id + tid] = val;
+        nWeightCurr[2] = 1;
       } else {
         val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
       }
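
The double overload of atomicExch added above works because CUDA's built-in atomicExch supports unsigned long long int, and a double can be reinterpreted bit-for-bit with __double_as_longlong / __longlong_as_double, so no CAS retry loop is needed. A minimal usage sketch (the kernel name publish_last is hypothetical, not part of the commit):

// Hypothetical kernel using the new overload: every thread swaps its value in
// unconditionally, so the last writer wins -- the same "go hogwild" policy the
// updateOutput kernel applies to maxVal/invMaxVal.
__global__ void publish_last(double *slot, const double *vals) {
  double previous = atomicExch(slot, vals[threadIdx.x]);
  (void)previous;  // the old value is returned if the caller needs it
}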
New file (+9 / -0)

@@ -0,0 +1,9 @@
+#include "THCUNN.h"
+#include "common.h"
+#include "im2col.h"
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+#include "generic/SpatialDepthWiseConvolution.cu"
+#include "THCGenerateFloatTypes.h"
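
This new file follows the usual THCUNN layout: it only pulls in headers and then includes the generic/ implementation together with THCGenerateFloatTypes.h, which re-includes that implementation once per floating-point type. A self-contained, simplified analogue of this per-type instantiation idea, with invented names (DEFINE_SCALE, scale_Float, ...):

#include <stdio.h>

// One macro expansion per element type stands in for the repeated inclusion of
// the generic/ source that THCGenerateFloatTypes.h performs.
#define DEFINE_SCALE(TYPE, SUFFIX)                        \
  static void scale_##SUFFIX(TYPE *x, int n, TYPE a) {    \
    for (int i = 0; i < n; ++i) x[i] *= a;                \
  }

DEFINE_SCALE(float, Float)    // emits scale_Float
DEFINE_SCALE(double, Double)  // emits scale_Double

int main(void) {
  float x[2] = {1.f, 2.f};
  scale_Float(x, 2, 3.f);
  printf("%g %g\n", x[0], x[1]);  // prints: 3 6
  return 0;
}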

torch/lib/THCUNN/generic/ClassNLLCriterion.cu (+14 / -6)
@@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
            THCTensor *output,
            bool sizeAverage,
            THCTensor *weights,
-           THCTensor *total_weight) {
+           THCTensor *total_weight,
+           long ignore_index) {
   THCUNN_check_dim_size(state, output, 1, 0, 1);
   THCUNN_check_dim_size(state, total_weight, 1, 0, 1);
+  ignore_index -= TH_INDEX_BASE;
 
   if (THCIndexTensor_(nDimension)(state, target) > 1) {
     THError("multi-target not supported");
@@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
         target_data,
         weights_data,
         sizeAverage,
-        n_classes
+        n_classes,
+        ignore_index
     );
 
   } else if (THCTensor_(nDimension)(state, input) == 2) {
@@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
         sizeAverage,
         THCTensor_(size)(state, input, 0),
         THCTensor_(size)(state, input, 1),
-        n_classes
+        n_classes,
+        ignore_index
     );
   }
   THCudaCheck(cudaGetLastError());
@@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
            THCTensor *gradInput,
            bool sizeAverage,
            THCTensor *weights,
-           THCTensor *total_weight) {
+           THCTensor *total_weight,
+           long ignore_index) {
   if (THCIndexTensor_(nDimension)(state, target) > 1) {
     THError("multi-target not supported");
   }
+  ignore_index -= TH_INDEX_BASE;
 
   int n_dims = THCTensor_(nDimension)(state, input);
   int n_classes = THCTensor_(size)(state, input, n_dims - 1);
@@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
         target_data,
         total_weight_data,
         sizeAverage,
-        n_classes
+        n_classes,
+        ignore_index
     );
   } else {
     cunn_ClassNLLCriterion_updateGradInput_kernel<real>
@@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
         sizeAverage,
         THCTensor_(size)(state, input, 0),
         THCTensor_(size)(state, input, 1),
-        n_classes
+        n_classes,
+        ignore_index
     );
   }
   THCudaCheck(cudaGetLastError());
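
Note that these host wrappers shift ignore_index by TH_INDEX_BASE before launching the kernels, mirroring the shift the kernels already apply to each target. The self-contained, host-side-only sketch below illustrates that convention; the values and the standalone definition of TH_INDEX_BASE are examples only, not part of the commit.

#include <cstdio>

#ifndef TH_INDEX_BASE
#define TH_INDEX_BASE 1   // 1 for the Lua Torch front end, 0 for zero-based ones
#endif

int main() {
  long ignore_index = 3;                // user-facing (1-based) class to ignore
  long target = 3;                      // user-facing (1-based) target label
  ignore_index -= TH_INDEX_BASE;        // what updateOutput/updateGradInput do
  int t = (int)target - TH_INDEX_BASE;  // what the kernels do per sample
  std::printf("skipped: %s\n", t == ignore_index ? "yes" : "no");  // prints: yes
  return 0;
}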
