@@ -645,13 +645,8 @@ function cunntest.SparseLinear_backward()
645645 gslin :zeroGradParameters ()
646646end
647647
648- local function BatchNormalization_forward (moduleName , dim , k )
649- local planes = torch .random (1 ,k )
650- local inputSize = { torch .random (2 ,24 ), planes }
651- for i = 1 ,dim do
652- table.insert (inputSize , torch .random (1 ,k ))
653- end
654-
648+ local function BatchNormalization_forward (moduleName , inputSize )
649+ local planes = inputSize [2 ]
655650 local tm = {}
656651 local title = moduleName .. ' .forward ' .. table.concat (inputSize , ' x' )
657652 times [title ] = tm
@@ -686,13 +681,8 @@ local function BatchNormalization_forward(moduleName, dim, k)
686681 precision_forward , ' error on running_var (forward)' )
687682end
688683
689- local function BatchNormalization_forward_inference (moduleName , dim , k )
690- local planes = torch .random (1 ,k )
691- local inputSize = { torch .random (2 ,32 ), planes }
692- for i = 1 ,dim do
693- table.insert (inputSize , torch .random (1 ,k ))
694- end
695-
684+ local function BatchNormalization_forward_inference (moduleName , inputSize )
685+ local planes = inputSize [2 ]
696686 local tm = {}
697687 local title = moduleName .. ' .forward (evaluate) ' .. table.concat (inputSize , ' x' )
698688 times [title ] = tm
@@ -728,15 +718,10 @@ local function BatchNormalization_forward_inference(moduleName, dim, k)
728718 mytester :assertlt (error :abs ():max (), precision_forward , ' error on state (forward evaluate)' )
729719end
730720
731- local function BatchNormalization_backward (moduleName , mode , dim , k , backwardFn )
721+ local function BatchNormalization_backward (moduleName , mode , inputSize , backwardFn )
732722 assert (mode == ' training' or mode == ' evaluation' , ' invalid mode' )
733723
734- local planes = torch .random (1 ,k )
735- local inputSize = { torch .random (2 ,32 ), planes }
736- for i = 1 ,dim do
737- table.insert (inputSize , torch .random (1 ,k ))
738- end
739-
724+ local planes = inputSize [2 ]
740725 local tm = {}
741726 local title = moduleName .. ' .backward ' .. table.concat (inputSize , ' x' )
742727 times [title ] = tm
@@ -793,67 +778,43 @@ local function BatchNormalization_backward(moduleName, mode, dim, k, backwardFn)
793778 mytester :assertlt (berror :abs ():max (), precision_backward , ' error on bias (backward) ' )
794779end
795780
796- function cunntest .BatchNormalization ()
797- BatchNormalization_forward (' BatchNormalization' , 0 , 128 )
798- BatchNormalization_forward_inference (' BatchNormalization' , 0 , 128 )
799- BatchNormalization_backward (' BatchNormalization' , ' training' , 0 , 128 , function (m , input , gradOutput )
800- return m :backward (input , gradOutput )
801- end )
802- BatchNormalization_backward (' BatchNormalization' , ' evaluation' , 0 , 128 , function (m , input , gradOutput )
781+ local function testBatchNormalization (name , dim , k )
782+ local function inputSize ()
783+ local inputSize = { torch .random (2 ,32 ), torch .random (1 , k ) }
784+ for i = 1 ,dim do
785+ table.insert (inputSize , torch .random (1 ,k ))
786+ end
787+ return inputSize
788+ end
789+ local function backward1 (m , input , gradOutput )
803790 return m :backward (input , gradOutput )
804- end )
805- BatchNormalization_backward (' BatchNormalization' , ' training' , 0 , 128 , function (m , input , gradOutput )
806- local gradInput = m :updateGradInput (input , gradOutput )
807- m :accGradParameters (input , gradOutput )
808- return gradInput
809- end )
810- BatchNormalization_backward (' BatchNormalization' , ' evaluation' , 0 , 128 , function (m , input , gradOutput )
791+ end
792+ local function backward2 (m , input , gradOutput )
811793 local gradInput = m :updateGradInput (input , gradOutput )
812794 m :accGradParameters (input , gradOutput )
813795 return gradInput
814- end )
796+ end
797+
798+ BatchNormalization_forward (name , inputSize ())
799+ BatchNormalization_forward_inference (name , inputSize ())
800+ BatchNormalization_backward (name , ' training' , inputSize (), backward1 )
801+ BatchNormalization_backward (name , ' training' , inputSize (), backward2 )
802+ BatchNormalization_backward (name , ' evaluation' , inputSize (), backward1 )
803+ BatchNormalization_backward (name , ' evaluation' , inputSize (), backward2 )
804+ end
805+
806+ function cunntest .BatchNormalization ()
807+ testBatchNormalization (' BatchNormalization' , 0 , 128 )
815808end
816809
817810function cunntest .SpatialBatchNormalization ()
818- BatchNormalization_forward (' SpatialBatchNormalization' , 2 , 64 )
819- BatchNormalization_forward_inference (' SpatialBatchNormalization' , 2 , 64 )
820- BatchNormalization_backward (' SpatialBatchNormalization' , ' training' , 2 , 64 , function (m , input , gradOutput )
821- return m :backward (input , gradOutput )
822- end )
823- BatchNormalization_backward (' SpatialBatchNormalization' , ' evaluation' , 2 , 64 , function (m , input , gradOutput )
824- return m :backward (input , gradOutput )
825- end )
826- BatchNormalization_backward (' SpatialBatchNormalization' , ' training' , 2 , 64 , function (m , input , gradOutput )
827- local gradInput = m :updateGradInput (input , gradOutput )
828- m :accGradParameters (input , gradOutput )
829- return gradInput
830- end )
831- BatchNormalization_backward (' SpatialBatchNormalization' , ' evaluation' , 2 , 64 , function (m , input , gradOutput )
832- local gradInput = m :updateGradInput (input , gradOutput )
833- m :accGradParameters (input , gradOutput )
834- return gradInput
835- end )
811+ testBatchNormalization (' SpatialBatchNormalization' , 2 , 64 )
812+ -- check with large image size (32*32 = 1024)
813+ BatchNormalization_forward (' SpatialBatchNormalization' , {2 , 2 , 32 , 32 })
836814end
837815
838816function cunntest .VolumetricBatchNormalization ()
839- BatchNormalization_forward (' VolumetricBatchNormalization' , 3 , 16 )
840- BatchNormalization_forward_inference (' VolumetricBatchNormalization' , 3 , 16 )
841- BatchNormalization_backward (' VolumetricBatchNormalization' , ' training' , 3 , 16 , function (m , input , gradOutput )
842- return m :backward (input , gradOutput )
843- end )
844- BatchNormalization_backward (' VolumetricBatchNormalization' , ' evaluation' , 3 , 16 , function (m , input , gradOutput )
845- return m :backward (input , gradOutput )
846- end )
847- BatchNormalization_backward (' VolumetricBatchNormalization' , ' training' , 3 , 16 , function (m , input , gradOutput )
848- local gradInput = m :updateGradInput (input , gradOutput )
849- m :accGradParameters (input , gradOutput )
850- return gradInput
851- end )
852- BatchNormalization_backward (' VolumetricBatchNormalization' , ' evaluation' , 3 , 16 , function (m , input , gradOutput )
853- local gradInput = m :updateGradInput (input , gradOutput )
854- m :accGradParameters (input , gradOutput )
855- return gradInput
856- end )
817+ testBatchNormalization (' VolumetricBatchNormalization' , 3 , 16 )
857818end
858819
859820function cunntest .SpatialConvolutionMM_forward_single ()
0 commit comments