@@ -178,22 +178,26 @@ void THNN_(VolumetricConvolution_updateOutput)(
178178 long k_ = 1 ;
179179
180180 // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
181- #ifdef THC_REAL_IS_FLOAT
182- THCudaBlas_Sgemm (
183- #elif defined(THC_REAL_IS_HALF)
184- THCudaBlas_Hgemm (
185- #elif defined(THC_REAL_IS_DOUBLE)
186- THCudaBlas_Dgemm (
187- #endif
188- state,
189- ' t' , ' n' ,
190- n_, m_, k_,
191- ScalarConvert<int , real>::to (1 ),
192- THCTensor_ (data)(state, ones), k_,
193- THCTensor_ (data)(state, bias), k_,
194- ScalarConvert<int , real>::to (0 ),
195- THCTensor_ (data)(state, output_n), n_
196- );
181+ if (bias) {
182+ #ifdef THC_REAL_IS_FLOAT
183+ THCudaBlas_Sgemm (
184+ #elif defined(THC_REAL_IS_HALF)
185+ THCudaBlas_Hgemm (
186+ #elif defined(THC_REAL_IS_DOUBLE)
187+ THCudaBlas_Dgemm (
188+ #endif
189+ state,
190+ ' t' , ' n' ,
191+ n_, m_, k_,
192+ ScalarConvert<int , real>::to (1 ),
193+ THCTensor_ (data)(state, ones), k_,
194+ THCTensor_ (data)(state, bias), k_,
195+ ScalarConvert<int , real>::to (0 ),
196+ THCTensor_ (data)(state, output_n), n_
197+ );
198+ } else {
199+ THCTensor_ (zero)(state, output_n);
200+ }
197201
198202 // Extract columns:
199203 im3d2col (
@@ -460,36 +464,38 @@ void THNN_(VolumetricConvolution_accGradParameters)(
460464 long k_ = outputDepth * outputHeight * outputWidth;
461465
462466 // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
463- #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
464- #ifdef THC_REAL_IS_FLOAT
465- THCudaBlas_Sgemv (
466- #elif defined(THC_REAL_IS_DOUBLE)
467- THCudaBlas_Dgemv (
468- #endif
469- state,
470- ' t' ,
471- k_, m_,
472- scale,
473- THCTensor_ (data)(state, gradOutput_n), k_,
474- THCTensor_ (data)(state, ones), 1 ,
475- ScalarConvert<int , real>::to (1 ),
476- THCTensor_ (data)(state, gradBias), 1
477- );
478- #endif
479- #ifdef THC_REAL_IS_HALF
480- THCudaBlas_Hgemm (
481- state,
482- ' t' , ' n' ,
483- m_, 1 , k_,
484- scale,
485- THCTensor_ (data)(state, gradOutput_n), k_,
486- THCTensor_ (data)(state, ones), k_,
487- ScalarConvert<int , real>::to (1 ),
488- THCTensor_ (data)(state, gradBias), m_
489- );
490- #endif
467+ if (gradBias) {
468+ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
469+ #ifdef THC_REAL_IS_FLOAT
470+ THCudaBlas_Sgemv (
471+ #elif defined(THC_REAL_IS_DOUBLE)
472+ THCudaBlas_Dgemv (
473+ #endif
474+ state,
475+ ' t' ,
476+ k_, m_,
477+ scale,
478+ THCTensor_ (data)(state, gradOutput_n), k_,
479+ THCTensor_ (data)(state, ones), 1 ,
480+ ScalarConvert<int , real>::to (1 ),
481+ THCTensor_ (data)(state, gradBias), 1
482+ );
483+ #endif
484+ #ifdef THC_REAL_IS_HALF
485+ THCudaBlas_Hgemm (
486+ state,
487+ ' t' , ' n' ,
488+ m_, 1 , k_,
489+ scale,
490+ THCTensor_ (data)(state, gradOutput_n), k_,
491+ THCTensor_ (data)(state, ones), k_,
492+ ScalarConvert<int , real>::to (1 ),
493+ THCTensor_ (data)(state, gradBias), m_
494+ );
495+ #endif
496+ }
491497 }
492-
498+
493499 // Free
494500 THCTensor_ (free )(state, input_n);
495501 THCTensor_ (free )(state, gradOutput_n);
0 commit comments