Commit 2e40857

t-vi authored and facebook-github-bot committed
Fix CTC loss for zero-length targets on GPU (pytorch#23298)
Summary: Fixes: pytorch#18215 at last! Also sprinkle tests...

Pull Request resolved: pytorch#23298
Differential Revision: D16582145
Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
1 parent 08f7f27 commit 2e40857

File tree: 4 files changed, +133 -50 lines changed

aten/src/ATen/native/LossCTC.cpp

Lines changed: 2 additions & 1 deletion

@@ -374,7 +374,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu
     }
   }
   if (reduction == Reduction::Mean) {
-    auto target_lengths_t = at::tensor(target_lengths, res.options());
+    auto target_lengths_t =
+        at::tensor(target_lengths, res.options()).clamp_min(1);
     return (res / target_lengths_t).mean();
   } else if (reduction == Reduction::Sum) {
     return res.sum();
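
Note: the `clamp_min(1)` matters because with `reduction='mean'` each per-sample loss is divided by its target length before averaging, so a zero-length target used to divide by zero and poison the mean with NaN/inf. A minimal Python sketch of the behavior this line establishes, using the public `torch.nn.functional.ctc_loss` API rather than the ATen internals (shapes are illustrative):

import torch

log_probs = torch.randn(50, 3, 15, dtype=torch.double).log_softmax(2)  # (T, N, C)
targets = torch.randint(1, 15, (9,), dtype=torch.long)  # 9 labels, all for sample 1
input_lengths = [50, 50, 50]
target_lengths = [0, 9, 0]  # samples 0 and 2 have empty targets

per_sample = torch.nn.functional.ctc_loss(
    log_probs, targets, input_lengths, target_lengths, reduction='none')
# reduction='mean' divides each loss by its (clamped) target length, then averages:
divisor = torch.tensor(target_lengths, dtype=per_sample.dtype).clamp_min(1)
mean_loss = (per_sample / divisor).mean()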

aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 89 additions & 33 deletions
@@ -24,10 +24,20 @@ namespace native {
 
 namespace {
 
-// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
-// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
-template<typename target_t>
-__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
+// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
+// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
+// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
+// - note that no bound-checking is done
+// - it is important to only call it with idx == 0 if the target length is 0
+// - __restrict__ impact to be measured, see
+//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
+template <typename target_t>
+__device__ static inline int64_t get_target_prime(
+    const target_t* __restrict__ target,
+    int64_t offset,
+    int64_t stride,
+    int64_t idx,
+    int64_t BLANK) {
   if (idx % 2 == 0) {
     return BLANK;
   } else {
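
Note: as a plain-Python model of the same lookup (an illustration, not the kernel code): position `idx` of the augmented sequence l' is BLANK at even indices and the `idx // 2`-th label at odd indices, so l' is never materialized.

def get_target_prime(target, offset, stride, idx, blank):
    # l' = BLANK l_0 BLANK l_1 ... BLANK l_(tl-1) BLANK, indexed without
    # building it: even positions are BLANK, odd idx maps to l_(idx // 2)
    if idx % 2 == 0:
        return blank
    return target[offset + stride * (idx // 2)]

# target l = [3, 1, 4] gives l' = [0, 3, 0, 1, 0, 4, 0] with blank = 0
assert [get_target_prime([3, 1, 4], 0, 1, i, 0) for i in range(7)] == [0, 3, 0, 1, 0, 4, 0]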
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
       la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
       break;
     case 1:
-      if (target_length > 0) {
-        la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
-      }
-      else {
-        la = neginf;
-      }
+      la = target_length == 0 ? neginf
+                              : log_probs_data
+                                    [lp_batch_offset +
+                                     lp_char_stride *
+                                         get_target_prime(
+                                             targets_data,
+                                             tg_batch_offset,
+                                             tg_target_stride,
+                                             1,
+                                             BLANK)];
       break;
     default:
       la = neginf;
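
Note: the rewritten `case 1` folds the zero-length check into a conditional expression: the alpha recursion may start on the first label only if a first label exists. A hypothetical Python rendering of the `t == 0` initialization (names chosen here for illustration, not taken from the kernel):

def log_alpha_t0(s, lp_t0, lprime, target_length, blank=0):
    # lp_t0[c]: log-prob of class c at t == 0; lprime: augmented targets l'
    neginf = float('-inf')
    if s == 0:                      # start on the leading blank
        return lp_t0[blank]
    if s == 1:                      # start on the first label, if any
        return neginf if target_length == 0 else lp_t0[lprime[1]]
    return neginf                   # all other starts are impossible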
@@ -100,16 +114,28 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
   // These two only depend on s, so we can cache them.
   int64_t current_char; // l_s in eq (6)
   bool have_three; // flag which of the two cases in eq (6) we have
-  if (s < 2*target_length+1) {
-    current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-    have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
+  if (s < 2 * target_length + 1 && target_length > 0) {
+    current_char = get_target_prime(
+        targets_data,
+        tg_batch_offset,
+        tg_target_stride,
+        s,
+        BLANK);
+    have_three =
+        ((s > 1) &&
+         (get_target_prime(
+              targets_data,
+              tg_batch_offset,
+              tg_target_stride,
+              s - 2,
+              BLANK) != current_char));
   } else {
     current_char = BLANK;
     have_three = false;
   }
   for (int64_t t=1; t < max_input_length; t++) {
     __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
-    if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
+    if ((t < input_length) && (s < 2 * target_length + 1)) {
       // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
       // lamax is the maximum for the logsumexp trick.
       scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
   // compute the loss (eq (8))
   if (threadIdx.x == 0) {
     scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
-    scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
+    scalar_t l2 = target_length > 0
+        ? log_alpha_data
+              [la_batch_offset + la_input_stride * (input_length - 1) +
+               la_target_stride * (target_length * 2 - 1)]
+        : neginf;
     scalar_t m = ((l1 > l2) ? l1 : l2);
     m = ((m == neginf) ? 0 : m);
     scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
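
Note: the loss collection stays numerically safe because `l2` now degrades to `neginf` instead of reading `log_alpha` at index -1 when `target_length == 0`. A small Python sketch of this two-way logsumexp, mirroring the kernel's guard (`math.log` raises on 0, so the sketch substitutes IEEE's log(0) = -inf explicitly):

import math

def log_likelihood(l1, l2):
    m = max(l1, l2)
    m = 0.0 if m == float('-inf') else m   # avoid -inf - -inf below
    s = math.exp(l1 - m) + math.exp(l2 - m)
    return (math.log(s) + m) if s > 0 else float('-inf')

# zero-length target: only the all-blank path (l1) contributes
print(log_likelihood(-3.5, float('-inf')))  # -3.5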
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
     threads_target /= 2;
   }
   int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
-
   dim3 block(threads_target, threads_batch);
   dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
   scalar_t lb;
   if (s == 2*target_length) {
     lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
-  } else if ((target_length > 0) && (s == 2*target_length-1)) {
-    int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+  } else if (s == 2 * target_length - 1) { // false for target_length == 0
+    int64_t current_target_prime = get_target_prime(
+        targets_data,
+        tg_batch_offset,
+        tg_target_stride,
+        s,
+        BLANK);
     lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
   } else {
     lb = neginf;
@@ -301,19 +335,29 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     int64_t s = threadIdx.x + block_s;
     int64_t current_target_prime;
     bool have_three;
-    if (s < 2*target_length+1) {
-      current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s < 2*target_length-1) &&
-                    (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
-                     current_target_prime));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s < 2 * target_length - 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s + 2,
+                BLANK) != current_target_prime));
     } else {
       current_target_prime = BLANK;
       have_three = false;
     }
     // now go backward in t. Note that we need to skip the last timestep that we did above.
     for (int64_t t=max_input_length-2; t>=0; t--) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
-      if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
         scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
         scalar_t lbmax = lb1;
         scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
             + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
 
         log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-      } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
-        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
+      } else if (
+          (s < 2 * max_target_length + 1) &&
+          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
+           (t >= input_length))) {
+        log_beta_data
+            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
+                neginf;
       }
     }
   }
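
Note: the new `(s > 0)` exception is the heart of the backward fix: when `target_length == 0`, the `s == 0` (all-blank) column is still a valid path and must not be overwritten with `neginf`. For an empty target the beta recursion degenerates to that single column; a tiny Python model of the surviving path (assuming one batch item and contiguous data, purely illustrative):

def log_beta_blank_only(log_probs_blank):
    # log_probs_blank[t]: log-prob of BLANK at step t; returns log_beta[:, s=0]
    T = len(log_probs_blank)
    log_beta = [float('-inf')] * T
    log_beta[T - 1] = log_probs_blank[T - 1]
    for t in range(T - 2, -1, -1):
        # only one predecessor (staying on the blank), so no logsumexp needed
        log_beta[t] = log_beta[t + 1] + log_probs_blank[t]
    return log_beta  # log_beta[0] == sum of all blank log-probs

Together with the matching alpha column, this is what makes the empty-target loss come out as the negative sum of the blank log-probabilities, which the new test_nn.py assertions check below.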
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
 
   // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
   for (int s = 0; s < 2*max_target_length+1; s++) {
-    if ((target_length > 0) && (s < 2*target_length+1)) {
-      int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+    if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
+      int64_t current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
       scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
           + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
       scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   {
     dim3 block(threads_target, threads_batch);
     dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (log_beta.data<scalar_t>(),
        log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
     // For the non-blank characters, we use a kernel to compute the subtrahend.
     // Again we might configure block and grid in a better way.
     int threads_target = max_threads;
-    while (threads_target / 2 >= max_target_length) {
+    while (threads_target / 2 >= max_target_length && threads_target > 1) {
       threads_target /= 2;
     }
     int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
     dim3 block(threads_target, threads_batch);
-    dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
+    dim3 grid(
+        std::max<int>(
+            (max_target_length + threads_target - 1) / threads_target, 1),
+        (batch_size + threads_batch - 1) / threads_batch,
+        1);
     ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
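
Note: both guards here protect the launch configuration when `max_target_length == 0`: without `threads_target > 1` the halving loop would drive `threads_target` to 0, and without the `std::max` the grid's x-dimension would be 0; either is an invalid CUDA launch. A plain-Python sketch of the sizing logic (an illustrative helper, not the host code):

def launch_config(max_threads, max_target_length, batch_size):
    threads_target = max_threads
    while threads_target // 2 >= max_target_length and threads_target > 1:
        threads_target //= 2
    threads_batch = min(max_threads // threads_target, batch_size)
    grid_x = max((max_target_length + threads_target - 1) // threads_target, 1)
    grid_y = (batch_size + threads_batch - 1) // threads_batch
    return (threads_target, threads_batch), (grid_x, grid_y)

print(launch_config(1024, 0, 3))  # ((1, 3), (1, 1)) -- still launchable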
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   } else { // small problem, use naive algorithm
     // Still no block/grid configuration guru...
     int threads_input = max_threads;
-    while (threads_input / 2 >= log_probs.size(0)) {
+    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
       threads_input /= 2;
     }
     threads_batch = std::min(max_threads / threads_input, (int) batch_size);
     dim3 block(threads_input, threads_batch);
     dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),

test/test_autograd.py

Lines changed: 28 additions & 11 deletions

@@ -1610,25 +1610,42 @@ def test_ctc_loss(self):
         target_length = 15
         gradcheck_input_size = 10
 
-        # device, input_length
-        tests = [('cpu', 150, False),
-                 ('cpu', 150, True)]
+        ZERO_NONE = 0
+        ZERO_SOME = 1
+        ZERO_ALL = 2
+
+        # device, input_length, vary_lengths, zero_lengths
+        tests = [('cpu', 150, False, ZERO_NONE),
+                 ('cpu', 150, True, ZERO_NONE),
+                 ('cpu', 50, True, ZERO_SOME),
+                 ('cpu', 50, True, ZERO_ALL)]
         if torch.cuda.is_available():
-            tests += [('cuda', 50, False),
-                      ('cuda', 150, False),
-                      ('cuda', 50, True),
-                      ('cuda', 150, True)]
-
-        for device, input_length, vary_lengths in tests:
+            tests += [('cuda', 50, False, ZERO_NONE),
+                      ('cuda', 150, False, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_NONE),
+                      ('cuda', 150, True, ZERO_NONE),
+                      ('cuda', 50, True, ZERO_SOME),
+                      ('cuda', 150, True, ZERO_SOME),
+                      ('cuda', 50, True, ZERO_ALL),
+                      ('cuda', 150, True, ZERO_ALL)]
+
+        for device, input_length, vary_lengths, zero_mode in tests:
             targets = torch.randint(1, num_labels, (batch_size, target_length),
                                     device=device, dtype=torch.long)
             x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
             tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                        device=device)
             input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                               if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
-            target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
-                               if vary_lengths else target_length) for i in range(batch_size)]
+            if zero_mode == ZERO_ALL:
+                target_lengths = [0 for _ in range(batch_size)]
+            else:
+                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
+                                   if vary_lengths else target_length) for _ in range(batch_size)]
+                if zero_mode == ZERO_SOME:
+                    idxes = torch.randint(0, batch_size, (10,))
+                    for i in idxes:
+                        target_lengths[i] = 0
 
         def ctc_after_softmax(x):
             x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
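
Note: what the extended test matrix exercises, condensed into a hedged standalone sketch (shapes and lengths are illustrative, not the test's values): gradients must stay finite even when some or all target lengths are zero.

import torch

T, N, C, S = 50, 4, 15, 10
x = torch.randn(T, N, C, dtype=torch.double, requires_grad=True)
log_probs = x.log_softmax(2)
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = [T] * N
target_lengths = [10, 0, 7, 0]  # a ZERO_SOME-style mix

loss = torch.nn.functional.ctc_loss(
    log_probs, targets, input_lengths, target_lengths, reduction='sum')
loss.backward()
assert torch.isfinite(x.grad).all()  # zero-length rows no longer produce NaNs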

test/test_nn.py

Lines changed: 14 additions & 5 deletions

@@ -5608,20 +5608,29 @@ def test_CTCLoss_lengthchecks_cpu(self):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
 
-    def test_CTCLoss_empty_target_cpu(self):
+    def _test_CTCLoss_empty_target(self, device):
         target_lengths = [0, 0, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (0,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[:, 0], loss)
 
         target_lengths = [0, 9, 0]
         input_lengths = [50, 50, 50]
-        targets = torch.randint(1, 15, (9,), dtype=torch.int)
-        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
+        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
+        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
         loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
         self.assertTrue((loss >= 0).all().item())
+        self.assertAlmostEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])
+
+    def test_CTCLoss_empty_target_cpu(self):
+        self._test_CTCLoss_empty_target('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_CTCLoss_empty_target_cuda(self):
+        self._test_CTCLoss_empty_target('cuda')
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_CTCLoss_zero_infinity(self):
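
Note: the new `assertAlmostEqual` lines encode the closed form for an empty target: the only valid alignment is all blanks, so the loss equals the negative sum of the blank log-probabilities over the input steps. As a standalone check (blank index 0 is the `ctc_loss` default):

import torch

log_probs = torch.randn(50, 3, 15, dtype=torch.double).log_softmax(2)
targets = torch.randint(1, 15, (0,), dtype=torch.long)  # empty target tensor
loss = torch.nn.functional.ctc_loss(
    log_probs, targets, [50, 50, 50], [0, 0, 0], reduction='none')
expected = -log_probs.sum(0)[:, 0]  # negative total blank log-prob per sample
assert torch.allclose(loss, expected)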
