
Commit 136b547

Fix BatchNormalization warpSum for pre-Kepler cards

Fixes #298

1 parent 86d9f56 · commit 136b547

2 files changed: +50, -87 lines

lib/THCUNN/BatchNormalization.cu

Lines changed: 17 additions & 15 deletions
```diff
@@ -7,6 +7,20 @@ const int WARP_SIZE = 32;
 typedef THCDeviceTensor<float, 3> DeviceTensor3;
 typedef THCDeviceTensor<float, 1> DeviceTensor1;
 
+// The maximum number of threads in a block
+const int MAX_BLOCK_SIZE = 512;
+
+// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
+static int getNumThreads(int nElem) {
+  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
+  for (int i = 0; i != 5; ++i) {
+    if (nElem <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return MAX_BLOCK_SIZE;
+}
+
 // Returns the index of the most significant 1 bit in `val`.
 __device__ __forceinline__ int getMSB(int val) {
   return 31 - __clz(val);
```
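The helper previously lived below the kernels (see the last hunk of this file) with the 512-thread cap hard-coded; hoisting it above warpSum, with the cap named MAX_BLOCK_SIZE, lets the pre-Kepler fallback size its shared buffer from the same constant the launch code uses. A quick illustration of what it returns; the main() harness is hypothetical, only getNumThreads and MAX_BLOCK_SIZE come from the diff:

```cuda
#include <cstdio>

const int MAX_BLOCK_SIZE = 512;

// Copied from the diff: smallest size in {32, 64, 128, 256, 512}
// that covers nElem, clamped to MAX_BLOCK_SIZE.
static int getNumThreads(int nElem) {
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

int main() {
  printf("%d\n", getNumThreads(20));    // 32
  printf("%d\n", getNumThreads(100));   // 128
  printf("%d\n", getNumThreads(1024));  // 512, clamped to MAX_BLOCK_SIZE
  return 0;
}
```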
```diff
@@ -55,23 +69,20 @@ struct GradOp {
   const DeviceTensor3 gradOutput;
 };
 
-// Sum across NumThreads threads within a warp
+// Sum across all threads within a warp
 static __device__ __forceinline__ float warpSum(float val) {
 #if __CUDA_ARCH__ >= 300
   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
     val += __shfl_xor(val, 1 << i, WARP_SIZE);
   }
 #else
-  const int MAX_BLOCK_SIZE = 256; // maximum block size this module uses
   __shared__ float values[MAX_BLOCK_SIZE];
-  __syncthreads();
   values[threadIdx.x] = val;
-  __syncthreads();
+  __threadfence_block();
   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
   for (int i = 1; i < WARP_SIZE; i++) {
     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
   }
-  __syncthreads();
 #endif
   return val;
 }
```
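This hunk is the actual fix. The pre-Kepler fallback declared its own MAX_BLOCK_SIZE of 256, but getNumThreads hands out blocks of up to 512 threads, so values[threadIdx.x] could write past the end of the shared array; the buffer is now sized by the module-wide constant. The __syncthreads() calls become __threadfence_block() because warpSum is also reached from divergent code (only the first warp re-invokes it inside reduce, see the next hunk), where a block-wide barrier is illegal; within a single warp the threads run in lockstep on these pre-Volta parts, so making the shared-memory writes visible is enough. The sm_30+ shuffle path is untouched; below is a standalone sketch of how it behaves. The test harness is hypothetical, and __shfl_xor is the pre-CUDA-9 intrinsic used in the diff (CUDA 9 renamed it __shfl_xor_sync):

```cuda
#include <cstdio>

const int WARP_SIZE = 32;

// Butterfly reduction from the diff: each step exchanges partial sums
// between lanes whose IDs differ in one bit, so after log2(32) = 5
// steps every lane holds the sum over all 32 lanes.
static __device__ __forceinline__ float warpSum(float val) {
  for (int i = 0; i < 5; ++i) {               // getMSB(WARP_SIZE) == 5
    val += __shfl_xor(val, 1 << i, WARP_SIZE);
  }
  return val;
}

__global__ void warpSumTest(float *out) {
  float sum = warpSum((float) threadIdx.x);   // lanes contribute 0..31
  if (threadIdx.x == 0) {
    *out = sum;                               // 0 + 1 + ... + 31 = 496
  }
}

int main() {
  float *d_out, h_out;
  cudaMalloc(&d_out, sizeof(float));
  warpSumTest<<<1, WARP_SIZE>>>(d_out);       // exactly one warp
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f\n", h_out);                      // expect 496.000000
  cudaFree(d_out);
  return 0;
}
```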
```diff
@@ -97,6 +108,7 @@ __device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
 
   // 'transpose', and reduce within warp again
   __shared__ T shared[32];
+  __syncthreads();
   if (threadIdx.x % WARP_SIZE == 0) {
     shared[threadIdx.x / WARP_SIZE] = sum;
   }
```
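With the barriers gone from warpSum, reduce() needs one of its own before the warp leaders overwrite the staging buffer: if reduce() runs more than once in a kernel, a lagging warp could still be reading shared[] from the previous round. A skeleton of the two-stage pattern, with shapes inferred from this hunk (the real reduce() also loads tensor elements via op and broadcasts the result, both elided here):

```cuda
const int WARP_SIZE = 32;

__device__ __forceinline__ float warpSum(float val) {  // sm_30+ path, as above
  for (int i = 0; i < 5; ++i) val += __shfl_xor(val, 1 << i, WARP_SIZE);
  return val;
}

__device__ float blockReduceSketch(float val) {
  __shared__ float shared[32];       // one slot per possible warp
  val = warpSum(val);                // stage 1: reduce within each warp
  __syncthreads();                   // the added barrier: a previous call
                                     // may still be reading shared[]
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = val;  // 'transpose': one value per warp
  }
  __syncthreads();
  // Stage 2: only the first warp reduces the per-warp partials. This
  // branch is divergent, which is exactly why warpSum itself must not
  // contain __syncthreads().
  if (threadIdx.x < WARP_SIZE) {
    val = warpSum(threadIdx.x < blockDim.x / WARP_SIZE
                      ? shared[threadIdx.x] : 0.f);
  }
  return val;                        // meaningful in the first warp
}
```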
```diff
@@ -214,16 +226,6 @@ __global__ void BatchNormalizationUpdateOutput_kernel(
   }
 }
 
-static int getNumThreads(int nElem) {
-  int threadSizes[5] = { 32, 64, 128, 256, 512 };
-  for (int i = 0; i != 5; ++i) {
-    if (nElem <= threadSizes[i]) {
-      return threadSizes[i];
-    }
-  }
-  return 512;
-}
-
 void THNN_CudaBatchNormalization_updateOutput(
   THCState *state, THCudaTensor *input_, THCudaTensor *output_,
   THCudaTensor *weight_, THCudaTensor *bias_, THCudaTensor *runningMean_,
```

test.lua

Lines changed: 33 additions & 72 deletions
```diff
@@ -645,13 +645,8 @@ function cunntest.SparseLinear_backward()
    gslin:zeroGradParameters()
 end
 
-local function BatchNormalization_forward(moduleName, dim, k)
-   local planes = torch.random(1,k)
-   local inputSize = { torch.random(2,24), planes }
-   for i=1,dim do
-      table.insert(inputSize, torch.random(1,k))
-   end
-
+local function BatchNormalization_forward(moduleName, inputSize)
+   local planes = inputSize[2]
    local tm = {}
    local title = moduleName .. '.forward ' .. table.concat(inputSize, 'x')
    times[title] = tm
@@ -686,13 +681,8 @@ local function BatchNormalization_forward(moduleName, dim, k)
                      precision_forward, 'error on running_var (forward)')
 end
 
-local function BatchNormalization_forward_inference(moduleName, dim, k)
-   local planes = torch.random(1,k)
-   local inputSize = { torch.random(2,32), planes }
-   for i=1,dim do
-      table.insert(inputSize, torch.random(1,k))
-   end
-
+local function BatchNormalization_forward_inference(moduleName, inputSize)
+   local planes = inputSize[2]
    local tm = {}
    local title = moduleName .. '.forward (evaluate) ' .. table.concat(inputSize, 'x')
    times[title] = tm
@@ -728,15 +718,10 @@ local function BatchNormalization_forward_inference(moduleName, dim, k)
    mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward evaluate)')
 end
 
-local function BatchNormalization_backward(moduleName, mode, dim, k, backwardFn)
+local function BatchNormalization_backward(moduleName, mode, inputSize, backwardFn)
    assert(mode == 'training' or mode == 'evaluation', 'invalid mode')
 
-   local planes = torch.random(1,k)
-   local inputSize = { torch.random(2,32), planes }
-   for i=1,dim do
-      table.insert(inputSize, torch.random(1,k))
-   end
-
+   local planes = inputSize[2]
    local tm = {}
    local title = moduleName .. '.backward ' .. table.concat(inputSize, 'x')
    times[title] = tm
@@ -793,67 +778,43 @@ local function BatchNormalization_backward(moduleName, mode, dim, k, backwardFn)
    mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
 end
 
-function cunntest.BatchNormalization()
-   BatchNormalization_forward('BatchNormalization', 0, 128)
-   BatchNormalization_forward_inference('BatchNormalization', 0, 128)
-   BatchNormalization_backward('BatchNormalization', 'training', 0, 128, function(m, input, gradOutput)
-      return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('BatchNormalization', 'evaluation', 0, 128, function(m, input, gradOutput)
+local function testBatchNormalization(name, dim, k)
+   local function inputSize()
+      local inputSize = { torch.random(2,32), torch.random(1, k) }
+      for i=1,dim do
+         table.insert(inputSize, torch.random(1,k))
+      end
+      return inputSize
+   end
+   local function backward1(m, input, gradOutput)
       return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('BatchNormalization', 'training', 0, 128, function(m, input, gradOutput)
-      local gradInput = m:updateGradInput(input, gradOutput)
-      m:accGradParameters(input, gradOutput)
-      return gradInput
-   end)
-   BatchNormalization_backward('BatchNormalization', 'evaluation', 0, 128, function(m, input, gradOutput)
+   end
+   local function backward2(m, input, gradOutput)
       local gradInput = m:updateGradInput(input, gradOutput)
       m:accGradParameters(input, gradOutput)
       return gradInput
-   end)
+   end
+
+   BatchNormalization_forward(name, inputSize())
+   BatchNormalization_forward_inference(name, inputSize())
+   BatchNormalization_backward(name, 'training', inputSize(), backward1)
+   BatchNormalization_backward(name, 'training', inputSize(), backward2)
+   BatchNormalization_backward(name, 'evaluation', inputSize(), backward1)
+   BatchNormalization_backward(name, 'evaluation', inputSize(), backward2)
+end
+
+function cunntest.BatchNormalization()
+   testBatchNormalization('BatchNormalization', 0, 128)
 end
 
 function cunntest.SpatialBatchNormalization()
-   BatchNormalization_forward('SpatialBatchNormalization', 2, 64)
-   BatchNormalization_forward_inference('SpatialBatchNormalization', 2, 64)
-   BatchNormalization_backward('SpatialBatchNormalization', 'training', 2, 64, function(m, input, gradOutput)
-      return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('SpatialBatchNormalization', 'evaluation', 2, 64, function(m, input, gradOutput)
-      return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('SpatialBatchNormalization', 'training', 2, 64, function(m, input, gradOutput)
-      local gradInput = m:updateGradInput(input, gradOutput)
-      m:accGradParameters(input, gradOutput)
-      return gradInput
-   end)
-   BatchNormalization_backward('SpatialBatchNormalization', 'evaluation', 2, 64, function(m, input, gradOutput)
-      local gradInput = m:updateGradInput(input, gradOutput)
-      m:accGradParameters(input, gradOutput)
-      return gradInput
-   end)
+   testBatchNormalization('SpatialBatchNormalization', 2, 64)
+   -- check with large image size (32*32 = 1024)
+   BatchNormalization_forward('SpatialBatchNormalization', {2, 2, 32, 32})
 end
 
 function cunntest.VolumetricBatchNormalization()
-   BatchNormalization_forward('VolumetricBatchNormalization', 3, 16)
-   BatchNormalization_forward_inference('VolumetricBatchNormalization', 3, 16)
-   BatchNormalization_backward('VolumetricBatchNormalization', 'training', 3, 16, function(m, input, gradOutput)
-      return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('VolumetricBatchNormalization', 'evaluation', 3, 16, function(m, input, gradOutput)
-      return m:backward(input, gradOutput)
-   end)
-   BatchNormalization_backward('VolumetricBatchNormalization', 'training', 3, 16, function(m, input, gradOutput)
-      local gradInput = m:updateGradInput(input, gradOutput)
-      m:accGradParameters(input, gradOutput)
-      return gradInput
-   end)
-   BatchNormalization_backward('VolumetricBatchNormalization', 'evaluation', 3, 16, function(m, input, gradOutput)
-      local gradInput = m:updateGradInput(input, gradOutput)
-      m:accGradParameters(input, gradOutput)
-      return gradInput
-   end)
+   testBatchNormalization('VolumetricBatchNormalization', 3, 16)
 end
 
 function cunntest.SpatialConvolutionMM_forward_single()
```
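The Lua side folds the six near-identical call sites per module into testBatchNormalization and, since block sizes now reach 512, adds a regression input whose plane is large enough to trigger that path: 32*32 = 1024 elements makes getNumThreads return the full 512, which is exactly the block size that overflowed the old 256-slot shared array on pre-Kepler cards. The arithmetic, as a hypothetical check (getNumThreads copied from the diff):

```cuda
#include <cassert>

static int getNumThreads(int nElem) {
  int threadSizes[5] = { 32, 64, 128, 256, 512 };
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return 512;
}

int main() {
  int nElem = 32 * 32;                  // the "large image size" in the test
  assert(getNumThreads(nElem) == 512);  // exceeds the old values[256] buffer
  return 0;
}
```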
