Skip to content

Commit a7d9875

Browse files
committed
Merge commit '4e49aed5eaa5a4abaf0a51bb87a49b44394ea3c3'
2 parents c3cda26 + 4e49aed commit a7d9875

File tree

5 files changed

+609
-34
lines changed

5 files changed

+609
-34
lines changed

torch/lib/THNN/generic/ClassNLLCriterion.c

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ void THNN_(ClassNLLCriterion_updateOutput)(
99
THTensor *output,
1010
bool sizeAverage,
1111
THTensor *weights,
12-
THTensor *total_weight)
12+
THTensor *total_weight,
13+
long ignore_index)
1314
{
1415
THNN_CHECK_DIM_SIZE(output, 1, 0, 1);
1516
THNN_CHECK_DIM_SIZE(total_weight, 1, 0, 1);
1617
int n_dims = THTensor_(nDimension)(input);
1718
int n_classes = THTensor_(size)(input, n_dims - 1);
19+
ignore_index -= TH_INDEX_BASE;
1820

1921
if (THIndexTensor_(nDimension)(target) > 1) {
2022
THError("multi-target not supported");
@@ -42,9 +44,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
4244

4345
if (THTensor_(nDimension)(input) == 1) {
4446
int cur_target = target_data[0] - TH_INDEX_BASE;
45-
THAssert(cur_target >= 0 && cur_target < n_classes);
46-
total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
47-
output_data[0] = -input_data[cur_target] * total_weight_data[0];
47+
if (cur_target != ignore_index) {
48+
THAssert(cur_target >= 0 && cur_target < n_classes);
49+
total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
50+
output_data[0] = -input_data[cur_target] * total_weight_data[0];
51+
}
4852
} else if (THTensor_(nDimension)(input) == 2) {
4953
int batch_size = THTensor_(size)(input, 0);
5054
THAssert(THIndexTensor_(size)(target, 0) == batch_size);
@@ -54,11 +58,13 @@ void THNN_(ClassNLLCriterion_updateOutput)(
5458
int i;
5559
for (i = 0; i < batch_size; i++) {
5660
int cur_target = target_data[i] - TH_INDEX_BASE;
57-
THAssert(cur_target >= 0 && cur_target < n_classes);
61+
if (cur_target != ignore_index) {
62+
THAssert(cur_target >= 0 && cur_target < n_classes);
5863

59-
real cur_weight = weights ? weights_data[cur_target] : 1.0f;
60-
total_weight_data[0] += cur_weight;
61-
output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
64+
real cur_weight = weights ? weights_data[cur_target] : 1.0f;
65+
total_weight_data[0] += cur_weight;
66+
output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
67+
}
6268
}
6369
}
6470

@@ -80,10 +86,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
8086
THTensor *gradInput,
8187
bool sizeAverage,
8288
THTensor *weights,
83-
THTensor *total_weight)
89+
THTensor *total_weight,
90+
long ignore_index)
8491
{
8592
int n_dims = THTensor_(nDimension)(input);
8693
int n_classes = THTensor_(size)(input, n_dims - 1);
94+
ignore_index -= TH_INDEX_BASE;
8795

8896
if (!THTensor_(isContiguous)(gradInput)) {
8997
THError("gradInput must be contiguous");
@@ -102,7 +110,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
102110
if (THTensor_(nDimension)(input) > 2) {
103111
THError("input tensor should be 1D or 2D");
104112
}
105-
113+
106114
if (weights && THTensor_(nElement)(weights) != n_classes) {
107115
THError("weight tensor should be defined either for all or no classes");
108116
}
@@ -116,10 +124,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
116124

117125
if (THTensor_(nDimension)(input) == 1) {
118126
int cur_target = target_data[0] - TH_INDEX_BASE;
119-
THAssert(cur_target >= 0 && cur_target < n_classes);
127+
if (cur_target != ignore_index) {
128+
THAssert(cur_target >= 0 && cur_target < n_classes);
120129

121-
gradInput_data[cur_target] =
122-
(!sizeAverage && weights) ? -weights_data[cur_target] : -1;
130+
gradInput_data[cur_target] =
131+
(!sizeAverage && weights) ? -weights_data[cur_target] : -1;
132+
}
123133

124134
} else if (THTensor_(nDimension)(input) == 2) {
125135
int batch_size = THTensor_(size)(input, 0);
@@ -131,13 +141,15 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
131141
for (i = 0; i < batch_size; i++){
132142
int cur_target = target_data[i] - TH_INDEX_BASE;
133143

134-
THAssert(cur_target >= 0 && cur_target < n_classes);
144+
if (cur_target != ignore_index) {
145+
THAssert(cur_target >= 0 && cur_target < n_classes);
135146

136-
gradInput_data[i * n_target + cur_target] =
137-
-(weights ? weights_data[cur_target] : 1.0f);
147+
gradInput_data[i * n_target + cur_target] =
148+
-(weights ? weights_data[cur_target] : 1.0f);
138149

139-
if (sizeAverage && *total_weight_data) {
140-
gradInput_data[i * n_target + cur_target] /= *total_weight_data;
150+
if (sizeAverage && *total_weight_data) {
151+
gradInput_data[i * n_target + cur_target] /= *total_weight_data;
152+
}
141153
}
142154
}
143155
}

torch/lib/THNN/generic/FusedRNNKernel.c

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@ void THNN_(GRUFused_updateOutput)(
99
THTensor *bias1,
1010
THTensor *bias2,
1111
THTensor *hx,
12-
THTensor *hy)
12+
THTensor *hy,
13+
THTensor *storage)
1314
{
1415
THAssertMsg(false, "Not implemented for CPU");
1516
}
1617

1718
void THNN_(GRUFused_updateGradInput)(
1819
THNNState *state,
19-
THTensor *input,
20-
THTensor *hidden,
20+
THTensor *gradInInput,
21+
THTensor *gradInHidden,
2122
THTensor *gradOutput,
22-
THTensor *gradInput)
23+
THTensor *gradInputHx,
24+
THTensor *storage)
2325
{
2426
THAssertMsg(false, "Not implemented for CPU");
2527
}
@@ -39,13 +41,13 @@ void THNN_(LSTMFused_updateOutput)(
3941

4042
void THNN_(LSTMFused_updateGradInput)(
4143
THNNState *state,
42-
THTensor *input,
43-
THTensor *hidden,
44+
THTensor *storage,
45+
THTensor *gradInGates,
4446
THTensor *prevC,
4547
THTensor *cy,
4648
THTensor *gradOutput,
4749
THTensor *gradOutputCell,
48-
THTensor *gradInput)
50+
THTensor *gradInputCx)
4951
{
5052
THAssertMsg(false, "Not implemented for CPU");
5153
}

0 commit comments

Comments
 (0)