
Commit e66e00c

t-vi authored and facebook-github-bot committed
Fix native ctc_loss gradient indexing bug for large target sizes (pytorch#27460)
Summary: Fixes pytorch#27442. Thank you Mohamed Yousef (ASDen) for the report with a minimal reproducing example and detailed analysis!

Pull Request resolved: pytorch#27460
Differential Revision: D17789378
Pulled By: soumith
fbshipit-source-id: dc01a31b998cced4462e933d4b32e09b331f7e41
1 parent 17a54e1 commit e66e00c
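For context, the bug only shows up when the target sequence is long enough that the backward kernel needs more than one CUDA block along the target dimension; the native CUDA backward then writes gradients at the wrong target positions. Below is a minimal sketch of how the mismatch can be observed on an unpatched build, modelled on the regression test added in this commit (the sizes are illustrative, not taken from the original issue report):

import torch

# A large target_length forces more than one CUDA block along the target axis.
input_length, batch_size, vocab_size, target_length = 4000, 4, 3, 1200

log_probs = torch.randn(input_length, batch_size, vocab_size).log_softmax(2).requires_grad_()
targets = torch.randint(1, vocab_size, (batch_size, target_length), dtype=torch.long)
input_lengths = [input_length] * batch_size
target_lengths = [target_length] * batch_size

# Reference loss and gradient from the CPU implementation.
loss_cpu = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                                        reduction='sum', zero_infinity=True)
grad_cpu, = torch.autograd.grad(loss_cpu, log_probs, torch.ones_like(loss_cpu))

# Disable cuDNN so the native CUDA kernel in LossCTC.cu is exercised.
with torch.backends.cudnn.flags(enabled=False):
    loss_gpu = torch.nn.functional.ctc_loss(log_probs.cuda(), targets.cuda(),
                                            input_lengths, target_lengths,
                                            reduction='sum', zero_infinity=True)
    grad_gpu, = torch.autograd.grad(loss_gpu, log_probs, torch.ones_like(loss_gpu))

# Before this fix the two gradients diverge; with it they agree to ~1e-4.
print((grad_cpu - grad_gpu).abs().max())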

2 files changed: +25 −1 lines changed


aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 1 addition & 1 deletion
@@ -428,7 +428,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                               const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
                                               int64_t batch_size, int64_t num_labels, int64_t BLANK, bool zero_infinity) {
   int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
-  int64_t s = threadIdx.x + blockIdx.x * blockDim.y; // note, this directly indexes into targets, no targets prime!
+  int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!
 
   if (b >= batch_size)
     return;
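In this kernel, threadIdx.x/blockIdx.x enumerate target positions s while threadIdx.y/blockIdx.y enumerate batch elements b, so the stride between consecutive blocks along x must be blockDim.x. With the old blockDim.y stride the formula is still correct while blockIdx.x == 0, which is why only large target sizes (targets spanning more than one block along x) produced wrong gradients: later blocks then skip some target positions and revisit others. A small sketch of the index arithmetic, using a purely hypothetical block shape rather than the kernel's actual launch configuration:

# Hypothetical 2D launch: x covers target positions, y covers batch elements.
block_dim_x, block_dim_y = 32, 16            # assumed shape, for illustration only
num_target_positions = 100                   # long enough to need several blocks along x
grid_dim_x = (num_target_positions + block_dim_x - 1) // block_dim_x   # -> 4 blocks

buggy, fixed = set(), set()
for block_idx_x in range(grid_dim_x):
    for thread_idx_x in range(block_dim_x):
        buggy.add(thread_idx_x + block_idx_x * block_dim_y)   # old formula
        fixed.add(thread_idx_x + block_idx_x * block_dim_x)   # fixed formula

needed = set(range(num_target_positions))
print(sorted(needed - fixed))   # [] -- every target position is covered
print(sorted(needed - buggy))   # [80, ..., 99] -- these positions never get a thread,
                                # while e.g. 16..31 are each visited by two different threads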

test/test_nn.py

Lines changed: 24 additions & 0 deletions
@@ -4190,6 +4190,30 @@ def test_CTCLoss_lengthchecks_cpu(self):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
 
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_CTCLoss_long_targets(self):
+        input_length = 4000
+        vocab_size = 3
+        batch_size = 4
+        target_length = 1200
+
+        log_probs = torch.randn(input_length, batch_size, vocab_size).log_softmax(2).requires_grad_()
+        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length), dtype=torch.long)
+        input_lengths = batch_size * [input_length]
+        target_lengths = batch_size * [target_length]
+
+        res_cpu = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
+                                               reduction='sum', zero_infinity=True)
+        grad_out = torch.randn_like(res_cpu)
+        grad_cpu, = torch.autograd.grad(res_cpu, log_probs, grad_out)
+
+        with torch.backends.cudnn.flags(enabled=False):
+            res_gpu = torch.nn.functional.ctc_loss(log_probs.cuda(), targets.cuda(), input_lengths, target_lengths,
+                                                   reduction='sum', zero_infinity=True)
+            grad_gpu, = torch.autograd.grad(res_gpu, log_probs, grad_out.cuda())
+        self.assertAlmostEqual(res_cpu, res_gpu, delta=1e-4)
+        self.assertAlmostEqual(grad_cpu, grad_gpu, delta=1e-4)
+
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_CTCLoss_zero_infinity(self):
         target_lengths = [60, 25, 20]
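Note that the new test wraps the CUDA calls in torch.backends.cudnn.flags(enabled=False), so the comparison exercises the native kernel patched in LossCTC.cu above rather than cuDNN's CTC implementation, and it checks both the loss value and the gradient against the CPU reference.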
