@@ -64,46 +64,101 @@ class BaseConvolutionLayer : public Layer<Dtype> {
   // Compute height_out_ and width_out_ from other parameters.
   virtual void compute_output_shape() = 0;
 
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+  /// @brief The spatial dimensions of the convolution input.
+  Blob<int> conv_input_shape_;
+  /// @brief The spatial dimensions of the input.
+  Blob<int> input_shape_;
+  /// @brief The spatial dimensions of the col_buffer.
+  vector<int> col_buffer_shape_;
+  /// @brief The spatial dimensions of the output.
+  vector<int> output_shape_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
   int num_;
   int channels_;
-  int pad_h_, pad_w_;
-  int height_, width_;
   int group_;
+  int out_spatial_dim_;
+  int weight_offset_;
   int num_output_;
-  int height_out_, width_out_;
   bool bias_term_;
   bool is_1x1_;
+  bool force_nd_im2col_;
 
  private:
  // wrap im2col/col2im so we don't have to remember the (long) argument lists
  inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) {
-    im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_cpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
-    col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_cpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), data);
+    }
   }
 #ifndef CPU_ONLY
   inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) {
-    im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_gpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(),
+          stride_.gpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
-    col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_gpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+          data);
+    }
   }
 #endif
 
+  int num_kernels_im2col_;
+  int num_kernels_col2im_;
   int conv_out_channels_;
   int conv_in_channels_;
   int conv_out_spatial_dim_;
-  int conv_in_height_;
-  int conv_in_width_;
   int kernel_dim_;
-  int weight_offset_;
   int col_offset_;
   int output_offset_;
 
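Note on the hunk above: the fixed height/width members become shape blobs, and the im2col/col2im wrappers dispatch between the original 2-D kernels and the new N-D kernels based on force_nd_im2col_ and num_spatial_axes_. As a rough sketch of the per-axis bookkeeping compute_output_shape() is expected to fill in (not code from this commit; the free function and its name are invented for illustration, using the usual convolution arithmetic):

    // Sketch only: per-axis output extent of a convolution,
    // output = (input + 2 * pad - kernel) / stride + 1.
    #include <cstddef>
    #include <vector>

    std::vector<int> infer_output_shape(const std::vector<int>& input_shape,
                                        const std::vector<int>& kernel_shape,
                                        const std::vector<int>& pad,
                                        const std::vector<int>& stride) {
      std::vector<int> output_shape(input_shape.size());
      for (std::size_t i = 0; i < input_shape.size(); ++i) {
        output_shape[i] =
            (input_shape[i] + 2 * pad[i] - kernel_shape[i]) / stride[i] + 1;
      }
      return output_shape;
    }
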
@@ -250,7 +305,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
   cudnnTensorDescriptor_t bias_desc_;
   cudnnFilterDescriptor_t filter_desc_;
   vector<cudnnConvolutionDescriptor_t> conv_descs_;
-  int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+  int bottom_offset_, top_offset_, bias_offset_;
   size_t workspaceSizeInBytes;
   void *workspace;
 };
@@ -287,11 +342,22 @@ class Im2colLayer : public Layer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
+  int num_;
   int channels_;
-  int height_, width_;
-  int pad_h_, pad_w_;
+
+  bool force_nd_im2col_;
 };
 
 // Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
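Im2colLayer picks up the same N-D shape members as the convolution base class above. For orientation, the column buffer those shapes describe has channels times the product of the kernel dimensions as its rows, and the product of the output dimensions as its columns, per image; a small sketch of that size computation (again not part of this commit, helper name invented):

    // Sketch only: element count of one image's im2col buffer,
    // (channels * prod(kernel dims)) rows by prod(output dims) columns.
    #include <cstddef>
    #include <vector>

    std::size_t col_buffer_count(int channels,
                                 const std::vector<int>& kernel_shape,
                                 const std::vector<int>& output_shape) {
      std::size_t count = channels;
      for (std::size_t i = 0; i < kernel_shape.size(); ++i) {
        count *= kernel_shape[i];  // rows: one per input value under the kernel
      }
      for (std::size_t i = 0; i < output_shape.size(); ++i) {
        count *= output_shape[i];  // columns: one per output location
      }
      return count;
    }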