@@ -24,10 +24,20 @@ namespace native {
 
 namespace {
 
-// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1]) note that no bound-checking is done
-// __restrict__ impact to be measured, https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
-template <typename target_t>
-__device__ static inline int64_t get_target_prime(const target_t* __restrict__ target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
+// this ad-hoc converts from targets (l in [1]) to augmented targets (l' in [1])
+// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
+// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
+// - note that no bound-checking is done
+// - it is important to only call it with idx == 0 if the target length is 0
+// - __restrict__ impact to be measured, see
+//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
+template <typename target_t>
+__device__ static inline int64_t get_target_prime(
+    const target_t* __restrict__ target,
+    int64_t offset,
+    int64_t stride,
+    int64_t idx,
+    int64_t BLANK) {
   if (idx % 2 == 0) {
     return BLANK;
   } else {
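
For readers without [1] at hand, here is a host-side sketch of the lookup the new comment describes: index `idx` of the augmented sequence l' is resolved without ever materializing l'. The helper name `augmented_lookup` and the example values are illustrative only, not part of the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side sketch of the same lookup: index idx into
// l' = BLANK l_0 BLANK l_1 ... BLANK l_(tl-1) BLANK without building l'.
// Even positions are BLANK, odd position 2*k+1 maps to l_k.
static int64_t augmented_lookup(const std::vector<int64_t>& target,
                                int64_t idx, int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  }
  return target[idx / 2];  // no bound checking, mirroring the device helper
}

int main() {
  const int64_t BLANK = 0;
  std::vector<int64_t> l = {3, 5, 5};  // l_0 l_1 l_2
  for (int64_t idx = 0; idx < 2 * (int64_t)l.size() + 1; idx++) {
    std::cout << augmented_lookup(l, idx, BLANK) << ' ';
  }
  std::cout << '\n';  // prints: 0 3 0 5 0 5 0
}
```
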
@@ -80,12 +90,16 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
       la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
       break;
     case 1:
-      if (target_length > 0) {
-        la = log_probs_data[lp_batch_offset + lp_char_stride * get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
-      }
-      else {
-        la = neginf;
-      }
+      la = target_length == 0 ? neginf
+                              : log_probs_data
+                                    [lp_batch_offset +
+                                     lp_char_stride *
+                                         get_target_prime(
+                                             targets_data,
+                                             tg_batch_offset,
+                                             tg_target_stride,
+                                             1,
+                                             BLANK)];
       break;
     default:
       la = neginf;
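
The change above folds the empty-target case into the t == 0 initialization: s == 0 always starts from the BLANK log-probability, s == 1 from the first target symbol, and becomes -inf when there is no target symbol to read. A minimal host-side sketch of that base case, assuming `log_probs_t0` is indexed by symbol id (the function name is hypothetical):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of the t == 0 base case of the alpha recursion (eq (6) in [1]):
// only s == 0 (BLANK) and s == 1 (first target symbol) can be non -inf,
// and s == 1 is only reachable when the target is non-empty.
static std::vector<double> init_log_alpha_t0(
    const std::vector<double>& log_probs_t0,  // log p(symbol) at t == 0
    const std::vector<int64_t>& target,
    int64_t BLANK) {
  const double neginf = -INFINITY;
  std::vector<double> la(2 * target.size() + 1, neginf);
  la[0] = log_probs_t0[BLANK];
  if (!target.empty()) {
    la[1] = log_probs_t0[target[0]];
  }
  return la;  // for an empty target this is just { log p(BLANK) }
}
```
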
@@ -100,16 +114,28 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
     // These two only depend on s, so we can cache them.
     int64_t current_char; // l_s in eq (6)
     bool have_three; // flag which of the two cases in eq (6) we have
-    if (s < 2*target_length+1) {
-      current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_char = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s > 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s - 2,
+                BLANK) != current_char));
     } else {
       current_char = BLANK;
       have_three = false;
     }
     for (int64_t t=1; t < max_input_length; t++) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
-      if ((t < input_length) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length) && (s < 2 * target_length + 1)) {
         // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
         // lamax is the maximum for the logsumexp trick.
         scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
@@ -146,7 +172,11 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
   // compute the loss (eq (8))
   if (threadIdx.x == 0) {
     scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
-    scalar_t l2 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2-1)];
+    scalar_t l2 = target_length > 0
+        ? log_alpha_data
+              [la_batch_offset + la_input_stride * (input_length - 1) +
+               la_target_stride * (target_length * 2 - 1)]
+        : neginf;
     scalar_t m = ((l1 > l2) ? l1 : l2);
     m = ((m == neginf) ? 0 : m);
     scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
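
Eq (8) combines the two final alpha entries with a logsumexp; with the change above, an empty target makes `l2` equal to `neginf`, so the reduction has to stay NaN-free. A sketch of the stable two-term log-add used here, with the same `m == neginf ? 0 : m` guard (the helper name is illustrative):

```cpp
#include <algorithm>
#include <cmath>

// Stable log(exp(l1) + exp(l2)). Shifting by the maximum keeps exp() from
// overflowing; resetting m to 0 when both inputs are -inf avoids computing
// exp(-inf - (-inf)) = exp(nan) and returns -inf instead of NaN.
static double log_add(double l1, double l2) {
  const double neginf = -INFINITY;
  double m = std::max(l1, l2);
  m = (m == neginf) ? 0 : m;
  return std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
}
// For target_length == 0, l2 == -inf, so log_add(l1, l2) == l1 and the
// negative log likelihood reduces to -l1 (the all-BLANK path).
```
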
@@ -236,7 +266,6 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
     threads_target /= 2;
   }
   int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
-
   dim3 block(threads_target, threads_batch);
   dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
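
For orientation, a small sketch of how this launch configuration works out for concrete sizes. The loop condition and the numbers are assumptions for illustration; the real halving condition sits just above this hunk.

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the launch configuration arithmetic (illustrative values only):
// threads_target is halved until its half no longer covers the
// 2*max_target_length+1 s-positions (assumed loop condition), the rest of
// the block is spent on batch items, and the grid covers the remainder.
int main() {
  const int max_threads = 1024;  // e.g. for float
  const int max_target_length = 50;
  const int batch_size = 16;

  int threads_target = max_threads;
  while (threads_target / 2 >= 2 * max_target_length + 1) {
    threads_target /= 2;  // ends at 128 for max_target_length == 50
  }
  int threads_batch = std::min(max_threads / threads_target, batch_size);

  int grid_x = (2 * max_target_length + 1 + threads_target - 1) / threads_target;
  int grid_y = (batch_size + threads_batch - 1) / threads_batch;
  std::printf("block = (%d, %d), grid = (%d, %d)\n",
              threads_target, threads_batch, grid_x, grid_y);
  // block = (128, 8), grid = (1, 2)
}
```
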
@@ -285,8 +314,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     scalar_t lb;
     if (s == 2*target_length) {
       lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
-    } else if ((target_length > 0) && (s == 2*target_length-1)) {
-      int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+    } else if (s == 2 * target_length - 1) { // false for target_length == 0
+      int64_t current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
       lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
     } else {
       lb = neginf;
@@ -301,19 +335,29 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
     int64_t s = threadIdx.x + block_s;
     int64_t current_target_prime;
     bool have_three;
-    if (s < 2*target_length+1) {
-      current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s < 2*target_length-1) &&
-                    (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
-                     current_target_prime));
+    if (s < 2 * target_length + 1 && target_length > 0) {
+      current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
+      have_three =
+          ((s < 2 * target_length - 1) &&
+           (get_target_prime(
+                targets_data,
+                tg_batch_offset,
+                tg_target_stride,
+                s + 2,
+                BLANK) != current_target_prime));
     } else {
       current_target_prime = BLANK;
       have_three = false;
     }
     // now go backward in t. Note that we need to skip the last timestep that we did above.
     for (int64_t t=max_input_length-2; t>=0; t--) {
       __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
-      if ((t < input_length-1) && (target_length > 0) && (s < 2*target_length+1)) {
+      if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
         scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
         scalar_t lbmax = lb1;
         scalar_t lb2, lb3;
@@ -339,8 +383,13 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
              + log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
 
         log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
-      } else if ((s < 2*max_target_length+1) && ((target_length == 0) || (s >= 2*target_length+1) || (t >= input_length))) {
-        log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf;
+      } else if (
+          (s < 2 * max_target_length + 1) &&
+          (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
+           (t >= input_length))) {
+        log_beta_data
+            [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
+                neginf;
       }
     }
   }
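
The new `(target_length == 0) && (s > 0)` term is the subtle part of this hunk: with an empty target the recursion branch above now owns s == 0 (the single BLANK position), so the fill branch must not overwrite that cell with `neginf`. A host-side sketch of the two predicates (names are illustrative; cells at `t == input_length - 1` were already written by the initialization earlier in the kernel):

```cpp
#include <cstdint>

// Which branch handles log_beta[t, s]? "Valid" cells are computed by the
// recursion; the remaining allocated cells below input_length-1 or beyond
// 2*target_length+1 are filled with -inf. For an empty target only s == 0
// stays on the valid side, which is what the (target_length == 0 && s > 0)
// term achieves.
static bool beta_cell_is_valid(int64_t t, int64_t s,
                               int64_t input_length, int64_t target_length) {
  return (t < input_length - 1) && (s < 2 * target_length + 1);
}

static bool beta_cell_gets_neginf(int64_t t, int64_t s,
                                  int64_t input_length, int64_t target_length,
                                  int64_t max_target_length) {
  return (s < 2 * max_target_length + 1) &&
         (((target_length == 0) && (s > 0)) ||
          (s >= 2 * target_length + 1) || (t >= input_length));
}
```
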
@@ -448,8 +497,13 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
 
   // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
   for (int s = 0; s < 2*max_target_length+1; s++) {
-    if ((target_length > 0) && (s < 2*target_length+1)) {
-      int64_t current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
+    if (s < 2 * target_length + 1) { // if target_length == 0, s == 0
+      int64_t current_target_prime = get_target_prime(
+          targets_data,
+          tg_batch_offset,
+          tg_target_stride,
+          s,
+          BLANK);
       scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s]
                                  + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]);
       scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime];
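
The `"log+="` in the comment above is a running accumulation in log space: the gradient cell for a character starts at `neginf`, and every (t, s) that maps to the same character is folded in with the usual max shift. A hedged sketch of that accumulation pattern (the helper name is illustrative and the in-kernel code may order the operations slightly differently):

```cpp
#include <algorithm>
#include <cmath>

// "log +=": lcab <- log(exp(lcab) + exp(log_alpha_beta)), assuming the
// accumulator starts at -inf. While lcab is still -inf the result is just
// log_alpha_beta; otherwise shift by the running maximum for stability.
static void log_accumulate(double& lcab, double log_alpha_beta) {
  const double neginf = -INFINITY;
  if (lcab == neginf) {
    lcab = log_alpha_beta;
    return;
  }
  double m = std::max(lcab, log_alpha_beta);
  lcab = std::log(std::exp(lcab - m) + std::exp(log_alpha_beta - m)) + m;
}
```
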
@@ -569,7 +623,6 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   {
     dim3 block(threads_target, threads_batch);
     dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (log_beta.data<scalar_t>(),
        log_probs.data<scalar_t>(), input_lengths_t.data<int64_t>(), log_probs.size(0),
@@ -612,12 +665,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
     // For the non-blank characters, we use a kernel to compute the subtrahend.
     // Again we might configure block and grid in a better way.
     int threads_target = max_threads;
-    while (threads_target / 2 >= max_target_length) {
+    while (threads_target / 2 >= max_target_length && threads_target > 1) {
       threads_target /= 2;
     }
     int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
     dim3 block(threads_target, threads_batch);
-    dim3 grid((max_target_length + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
+    dim3 grid(
+        std::max<int>(
+            (max_target_length + threads_target - 1) / threads_target, 1),
+        (batch_size + threads_batch - 1) / threads_batch,
+        1);
     ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),
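
Both guards in this hunk are for degenerate sizes, e.g. a batch in which every target is empty: without `threads_target > 1` the halving loop never terminates once `max_target_length == 0`, and without the `std::max<int>(..., 1)` the x grid dimension would be 0, which is an invalid launch configuration. A small check of both, with illustrative values:

```cpp
#include <algorithm>
#include <cassert>

// With an all-empty-targets batch, max_target_length == 0. The extra
// "threads_target > 1" stops the halving at 1 (otherwise 0/2 >= 0 would
// loop forever), and std::max(..., 1) keeps grid.x at least 1.
int main() {
  const int max_threads = 1024;
  const int max_target_length = 0;  // batch of empty targets
  const int batch_size = 4;

  int threads_target = max_threads;
  while (threads_target / 2 >= max_target_length && threads_target > 1) {
    threads_target /= 2;
  }
  assert(threads_target == 1);  // the guard stops the halving at 1

  int threads_batch = std::min(max_threads / threads_target, batch_size);
  int grid_x = std::max<int>(
      (max_target_length + threads_target - 1) / threads_target, 1);
  int grid_y = (batch_size + threads_batch - 1) / threads_batch;
  assert(grid_x == 1 && grid_y == 1);  // valid, non-zero grid dimensions
}
```
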
@@ -635,13 +692,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
   } else { // small problem, use naive algorithm
     // Still no block/grid configuration guru...
     int threads_input = max_threads;
-    while (threads_input / 2 >= log_probs.size(0)) {
+    while (threads_input / 2 >= log_probs.size(0) && threads_input > 1) {
       threads_input /= 2;
     }
     threads_batch = std::min(max_threads / threads_input, (int) batch_size);
     dim3 block(threads_input, threads_batch);
     dim3 grid((log_probs.size(0) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
-
     ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
       (grad.data<scalar_t>(),
        grad_out.data<scalar_t>(), grad_out.stride(0),