Skip to content

Commit 708bfa9

Browse files
committed
Add checks for convolution parameters
1 parent ef68938 commit 708bfa9

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

SpatialConvolutionMM.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ void THNN_CudaSpatialConvolutionMM_updateOutput(THCState *state, THCudaTensor *i
99
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
1010
THArgCheck(weight->nDimension == 2, 4, "weight tensor must be 2D (nOutputPlane,nInputPlane*kH*kW)");
1111
THArgCheck(weight->size[0] == bias->size[0], 4, "nOutputPlane mismatch in weight and bias");
12+
THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
13+
THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
1214

1315
// Params:
1416
int nInputPlane = weight->size[1]/(kH*kW);
@@ -125,6 +127,8 @@ void THNN_CudaSpatialConvolutionMM_updateGradInput(THCState *state, THCudaTensor
125127
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
126128
THArgCheck(weight->nDimension == 2, 4, "weight tensor must be 2D (nOutputPlane,nInputPlane*kH*kW)");
127129
THArgCheck(weight->size[0] == bias->size[0], 4, "nOutputPlane mismatch in weight and bias");
130+
THArgCheck(kW > 0 && kH > 0, 9, "kernel size should be greater than zero");
131+
THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero");
128132

129133
// Params
130134
int nInputPlane = weight->size[1]/(kW*kH);
@@ -208,6 +212,8 @@ void THNN_CudaSpatialConvolutionMM_accGradParameters(THCState *state, THCudaTens
208212
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
209213
THArgCheck(gradWeight->nDimension == 2, 4, "gradWeight tensor must be 2D (nOutputPlane,nInputPlane*kH*kW)");
210214
THArgCheck(gradWeight->size[0] == gradBias->size[0], 4, "nOutputPlane mismatch in gradWeight and gradBias");
215+
THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
216+
THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
211217

212218
// Params
213219
int nInputPlane = gradWeight->size[1]/(kW*kH);

0 commit comments

Comments
 (0)