
Commit 2e1c1cb

Merge pull request BVLC#2049 from jeffdonahue/nd-convolution
ND convolution with im2col
2 parents: 3d12b5d + 9d8206e

21 files changed: +1594 −321 lines

include/caffe/blob.hpp

Lines changed: 2 additions & 0 deletions
@@ -219,6 +219,7 @@ class Blob {

   const Dtype* cpu_data() const;
   void set_cpu_data(Dtype* data);
+  const int* gpu_shape() const;
   const Dtype* gpu_data() const;
   const Dtype* cpu_diff() const;
   const Dtype* gpu_diff() const;
@@ -268,6 +269,7 @@ class Blob {
  protected:
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
+  shared_ptr<SyncedMemory> shape_data_;
   vector<int> shape_;
   int count_;
   int capacity_;
include/caffe/util/im2col.hpp

Lines changed: 24 additions & 0 deletions
@@ -3,24 +3,48 @@

 namespace caffe {

+template <typename Dtype>
+void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col);
+
 template <typename Dtype>
 void im2col_cpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_col);

+template <typename Dtype>
+void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
+    const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im);
+
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_im);

+template <typename Dtype>
+void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
+    const int col_size, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_col);
+
 template <typename Dtype>
 void im2col_gpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
     const int stride_w, Dtype* data_col);

+template <typename Dtype>
+void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
+    const int im_size, const int* im_shape, const int* col_shape,
+    const int* kernel_shape, const int* pad, const int* stride,
+    Dtype* data_im);
+
 template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int patch_h, const int patch_w,
include/caffe/vision_layers.hpp

Lines changed: 87 additions & 21 deletions
@@ -64,46 +64,101 @@ class BaseConvolutionLayer : public Layer<Dtype> {
   // Compute height_out_ and width_out_ from other parameters.
   virtual void compute_output_shape() = 0;

-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+  /// @brief The spatial dimensions of the convolution input.
+  Blob<int> conv_input_shape_;
+  /// @brief The spatial dimensions of the input.
+  Blob<int> input_shape_;
+  /// @brief The spatial dimensions of the col_buffer.
+  vector<int> col_buffer_shape_;
+  /// @brief The spatial dimensions of the output.
+  vector<int> output_shape_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
   int num_;
   int channels_;
-  int pad_h_, pad_w_;
-  int height_, width_;
   int group_;
+  int out_spatial_dim_;
+  int weight_offset_;
   int num_output_;
-  int height_out_, width_out_;
   bool bias_term_;
   bool is_1x1_;
+  bool force_nd_im2col_;

  private:
   // wrap im2col/col2im so we don't have to remember the (long) argument lists
   inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) {
-    im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_cpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
-    col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_cpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
+          col_buffer_shape_.data(), kernel_shape_.cpu_data(),
+          pad_.cpu_data(), stride_.cpu_data(), data);
+    }
   }
 #ifndef CPU_ONLY
   inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) {
-    im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      im2col_gpu(data, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+    } else {
+      im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(),
+          stride_.gpu_data(), col_buff);
+    }
   }
   inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
-    col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
-        kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
+    if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+      col2im_gpu(col_buff, conv_in_channels_,
+          conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
+          kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+          pad_.cpu_data()[0], pad_.cpu_data()[1],
+          stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+    } else {
+      col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
+          conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
+          kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+          data);
+    }
   }
 #endif

+  int num_kernels_im2col_;
+  int num_kernels_col2im_;
   int conv_out_channels_;
   int conv_in_channels_;
   int conv_out_spatial_dim_;
-  int conv_in_height_;
-  int conv_in_width_;
   int kernel_dim_;
-  int weight_offset_;
   int col_offset_;
   int output_offset_;

@@ -250,7 +305,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
   cudnnTensorDescriptor_t bias_desc_;
   cudnnFilterDescriptor_t filter_desc_;
   vector<cudnnConvolutionDescriptor_t> conv_descs_;
-  int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+  int bottom_offset_, top_offset_, bias_offset_;
   size_t workspaceSizeInBytes;
   void *workspace;
 };

@@ -287,11 +342,22 @@ class Im2colLayer : public Layer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
+  /// @brief The spatial dimensions of a filter kernel.
+  Blob<int> kernel_shape_;
+  /// @brief The spatial dimensions of the stride.
+  Blob<int> stride_;
+  /// @brief The spatial dimensions of the padding.
+  Blob<int> pad_;
+
+  int num_spatial_axes_;
+  int bottom_dim_;
+  int top_dim_;
+
+  int channel_axis_;
+  int num_;
   int channels_;
-  int height_, width_;
-  int pad_h_, pad_w_;
+
+  bool force_nd_im2col_;
 };

 // Forward declare PoolingLayer and SplitLayer for use in LRNLayer.
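
The wrappers keep the specialized 2-D routines as a fast path and fall back to the generic ND routines otherwise; force_nd_im2col_ lets callers exercise the ND path even for 2-D inputs. The shape bookkeeping behind col_buffer_shape_ reduces to the usual convolution output formula, out = (in + 2*pad - kernel) / stride + 1, applied per spatial axis. A sketch of that derivation (infer_col_shape is my name for an illustrative helper, not the PR's LayerSetUp code):

#include <vector>

// Illustration only: derive the col buffer shape for N spatial axes.
// im_shape is [channels, d1..dN]; the first col axis folds the channels
// together with every kernel tap.
std::vector<int> infer_col_shape(const std::vector<int>& im_shape,
    const std::vector<int>& kernel, const std::vector<int>& pad,
    const std::vector<int>& stride) {
  const int num_spatial_axes = kernel.size();
  std::vector<int> col_shape(num_spatial_axes + 1);
  col_shape[0] = im_shape[0];
  for (int i = 0; i < num_spatial_axes; ++i) {
    col_shape[0] *= kernel[i];  // channels * prod(kernel extents)
    // Standard output-extent formula for zero-padded, strided convolution.
    col_shape[i + 1] =
        (im_shape[i + 1] + 2 * pad[i] - kernel[i]) / stride[i] + 1;
  }
  return col_shape;
}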

src/caffe/blob.cpp

Lines changed: 11 additions & 0 deletions
@@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
   CHECK_LE(shape.size(), kMaxBlobAxes);
   count_ = 1;
   shape_.resize(shape.size());
+  if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+    shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+  }
+  int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
   for (int i = 0; i < shape.size(); ++i) {
     CHECK_GE(shape[i], 0);
     CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
     count_ *= shape[i];
     shape_[i] = shape[i];
+    shape_data[i] = shape[i];
   }
   if (count_ > capacity_) {
     capacity_ = count_;

@@ -67,6 +72,12 @@ Blob<Dtype>::Blob(const vector<int>& shape)
   Reshape(shape);
 }

+template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+  CHECK(shape_data_);
+  return (const int*)shape_data_->gpu_data();
+}
+
 template <typename Dtype>
 const Dtype* Blob<Dtype>::cpu_data() const {
   CHECK(data_);