Skip to content

Commit 797544c

Browse files
Eli Stevenssoumith
authored andcommitted
implementation of bias=False for VolConv.cu
1 parent e9b05c7 commit 797544c

File tree

4 files changed

+178
-170
lines changed

4 files changed

+178
-170
lines changed

generic/SpatialFullConvolution.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ void THNN_(SpatialFullConvolution_updateOutput)(
182182
THCTensor_(data)(state, output_n), n_
183183
);
184184
}
185-
186185
}
187186

188187
// Free

generic/THCUNN.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ TH_API void THNN_(VolumetricConvolution_updateOutput)(
10231023
THCTensor *input,
10241024
THCTensor *output,
10251025
THCTensor *weight,
1026-
THCTensor *bias,
1026+
THCTensor *bias, // [OPTIONAL]
10271027
THCTensor *finput,
10281028
THCTensor *fgradInput,
10291029
int dT, int dW, int dH,
@@ -1044,7 +1044,7 @@ TH_API void THNN_(VolumetricConvolution_accGradParameters)(
10441044
THCTensor *input,
10451045
THCTensor *gradOutput,
10461046
THCTensor *gradWeight,
1047-
THCTensor *gradBias,
1047+
THCTensor *gradBias, // [OPTIONAL]
10481048
THCTensor *finput,
10491049
THCTensor *fgradInput,
10501050
int dT, int dW, int dH,
@@ -1056,7 +1056,7 @@ TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
10561056
THCTensor *input,
10571057
THCTensor *output,
10581058
THCTensor *weight,
1059-
THCTensor *bias,
1059+
THCTensor *bias, // [OPTIONAL]
10601060
THCTensor *columns,
10611061
THCTensor *ones,
10621062
int kT, int kW, int kH,
@@ -1081,7 +1081,7 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
10811081
THCTensor *input,
10821082
THCTensor *gradOutput,
10831083
THCTensor *gradWeight,
1084-
THCTensor *gradBias,
1084+
THCTensor *gradBias, // [OPTIONAL]
10851085
THCTensor *columns,
10861086
THCTensor *ones,
10871087
int kT, int kW, int kH,
@@ -1118,7 +1118,7 @@ TH_API void THNN_(VolumetricFullConvolution_updateOutput)(
11181118
THCTensor *input,
11191119
THCTensor *output,
11201120
THCTensor *weight,
1121-
THCTensor *bias,
1121+
THCTensor *bias, // [OPTIONAL]
11221122
THCTensor *finput,
11231123
THCTensor *fgradInput,
11241124
int dT, int dW, int dH,
@@ -1142,7 +1142,7 @@ TH_API void THNN_(VolumetricFullConvolution_accGradParameters)(
11421142
THCTensor *input,
11431143
THCTensor *gradOutput,
11441144
THCTensor *gradWeight,
1145-
THCTensor *gradBias,
1145+
THCTensor *gradBias, // [OPTIONAL]
11461146
THCTensor *finput,
11471147
THCTensor *fgradInput,
11481148
int dT, int dW, int dH,

generic/VolumetricConvolution.cu

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,26 @@ void THNN_(VolumetricConvolution_updateOutput)(
178178
long k_ = 1;
179179

180180
// Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
181-
#ifdef THC_REAL_IS_FLOAT
182-
THCudaBlas_Sgemm(
183-
#elif defined(THC_REAL_IS_HALF)
184-
THCudaBlas_Hgemm(
185-
#elif defined(THC_REAL_IS_DOUBLE)
186-
THCudaBlas_Dgemm(
187-
#endif
188-
state,
189-
't', 'n',
190-
n_, m_, k_,
191-
ScalarConvert<int, real>::to(1),
192-
THCTensor_(data)(state, ones), k_,
193-
THCTensor_(data)(state, bias), k_,
194-
ScalarConvert<int, real>::to(0),
195-
THCTensor_(data)(state, output_n), n_
196-
);
181+
if (bias) {
182+
#ifdef THC_REAL_IS_FLOAT
183+
THCudaBlas_Sgemm(
184+
#elif defined(THC_REAL_IS_HALF)
185+
THCudaBlas_Hgemm(
186+
#elif defined(THC_REAL_IS_DOUBLE)
187+
THCudaBlas_Dgemm(
188+
#endif
189+
state,
190+
't', 'n',
191+
n_, m_, k_,
192+
ScalarConvert<int, real>::to(1),
193+
THCTensor_(data)(state, ones), k_,
194+
THCTensor_(data)(state, bias), k_,
195+
ScalarConvert<int, real>::to(0),
196+
THCTensor_(data)(state, output_n), n_
197+
);
198+
} else {
199+
THCTensor_(zero)(state, output_n);
200+
}
197201

198202
// Extract columns:
199203
im3d2col(
@@ -460,36 +464,38 @@ void THNN_(VolumetricConvolution_accGradParameters)(
460464
long k_ = outputDepth * outputHeight * outputWidth;
461465

462466
// Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
463-
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
464-
#ifdef THC_REAL_IS_FLOAT
465-
THCudaBlas_Sgemv(
466-
#elif defined(THC_REAL_IS_DOUBLE)
467-
THCudaBlas_Dgemv(
468-
#endif
469-
state,
470-
't',
471-
k_, m_,
472-
scale,
473-
THCTensor_(data)(state, gradOutput_n), k_,
474-
THCTensor_(data)(state, ones), 1,
475-
ScalarConvert<int, real>::to(1),
476-
THCTensor_(data)(state, gradBias), 1
477-
);
478-
#endif
479-
#ifdef THC_REAL_IS_HALF
480-
THCudaBlas_Hgemm(
481-
state,
482-
't', 'n',
483-
m_, 1, k_,
484-
scale,
485-
THCTensor_(data)(state, gradOutput_n), k_,
486-
THCTensor_(data)(state, ones), k_,
487-
ScalarConvert<int, real>::to(1),
488-
THCTensor_(data)(state, gradBias), m_
489-
);
490-
#endif
467+
if (gradBias) {
468+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
469+
#ifdef THC_REAL_IS_FLOAT
470+
THCudaBlas_Sgemv(
471+
#elif defined(THC_REAL_IS_DOUBLE)
472+
THCudaBlas_Dgemv(
473+
#endif
474+
state,
475+
't',
476+
k_, m_,
477+
scale,
478+
THCTensor_(data)(state, gradOutput_n), k_,
479+
THCTensor_(data)(state, ones), 1,
480+
ScalarConvert<int, real>::to(1),
481+
THCTensor_(data)(state, gradBias), 1
482+
);
483+
#endif
484+
#ifdef THC_REAL_IS_HALF
485+
THCudaBlas_Hgemm(
486+
state,
487+
't', 'n',
488+
m_, 1, k_,
489+
scale,
490+
THCTensor_(data)(state, gradOutput_n), k_,
491+
THCTensor_(data)(state, ones), k_,
492+
ScalarConvert<int, real>::to(1),
493+
THCTensor_(data)(state, gradBias), m_
494+
);
495+
#endif
496+
}
491497
}
492-
498+
493499
// Free
494500
THCTensor_(free)(state, input_n);
495501
THCTensor_(free)(state, gradOutput_n);

0 commit comments

Comments
 (0)