Commit d518ad7

Merge pull request tqchen#12 from ZihengJiang/resnet
support resnet on cifar
2 parents 29b464c + 44dd3e2 commit d518ad7

File tree

9 files changed: +335 −19 lines

example/cifar_resnet.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import tinyflow as tf
from tinyflow.datasets import get_cifar10
import numpy as np

num_epoch = 10
num_batch = 600
batch_size = 100


def conv_factory(x, filter_size, in_filters, out_filters):
    x = tf.nn.conv2d(x, num_filter=out_filters,
                     ksize=[1, filter_size, filter_size, 1], padding='SAME')
    x = tf.nn.batch_normalization(x)
    x = tf.nn.relu(x)
    return x


def residual_factory(x, in_filters, out_filters):
    if in_filters == out_filters:
        # identity shortcut: add the input directly to the residual branch
        orig_x = x
        conv1 = conv_factory(x, 3, in_filters, out_filters)
        conv2 = conv_factory(conv1, 3, out_filters, out_filters)
        new = orig_x + conv2
        return tf.nn.relu(new)
    else:
        # projection shortcut: a 1x1 convolution matches the channel count
        conv1 = conv_factory(x, 3, in_filters, out_filters)
        conv2 = conv_factory(conv1, 3, out_filters, out_filters)
        project_x = conv_factory(x, 1, in_filters, out_filters)
        new = project_x + conv2
        return tf.nn.relu(new)


def resnet(x, n, in_filters, out_filters):
    # three stages of n residual units each, widening 16 -> 32 -> 64 channels
    # (out_filters is unused here; the stage widths are hardcoded below)
    for i in range(n):
        if i == 0:
            x = residual_factory(x, in_filters, 16)
        else:
            x = residual_factory(x, 16, 16)
    for i in range(n):
        if i == 0:
            x = residual_factory(x, 16, 32)
        else:
            x = residual_factory(x, 32, 32)
    for i in range(n):
        if i == 0:
            x = residual_factory(x, 32, 64)
        else:
            x = residual_factory(x, 64, 64)
    return x


x = tf.placeholder(tf.float32)
conv1 = tf.nn.conv2d(x, num_filter=16, ksize=[1, 5, 5, 1], padding='SAME')
tanh1 = tf.tanh(conv1)
res = resnet(tanh1, 1, 16, 64)
pool1 = tf.nn.avg_pool(res, ksize=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME', data_format='NCHW')
conv2 = tf.nn.conv2d(pool1, num_filter=16, ksize=[1, 5, 5, 1])
flatten = tf.nn.flatten_layer(conv2)
fc1 = tf.nn.linear(flatten, num_hidden=10, name="fc1")

# define loss
label = tf.placeholder(tf.float32)
cross_entropy = tf.nn.mean_sparse_softmax_cross_entropy_with_logits(fc1, label)
train_step = tf.train.AdamOptimizer(0.0005).minimize(cross_entropy)

sess = tf.Session(config='gpu')

# Automatic variable shape inference API: infer the shapes and initialize the weights.
known_shape = {x: [batch_size, 3, 32, 32], label: [batch_size]}
stdev = 0.01
init_step = []
for v, name, shape in tf.infer_variable_shapes(
        cross_entropy, feed_dict=known_shape):
    init_step.append(tf.assign(v, tf.normal(shape, stdev)))
    print("shape[%s]=%s" % (name, str(shape)))
sess.run(init_step)
sess.run(tf.initialize_all_variables())

# get the cifar dataset
cifar = get_cifar10()

for epoch in range(num_epoch):
    sum_loss = 0.0
    for i in range(num_batch):
        batch_xs, batch_ys = cifar.train.next_batch(batch_size)
        loss, _ = sess.run([cross_entropy, train_step], feed_dict={x: batch_xs, label: batch_ys})
        sum_loss += loss
    print("epoch[%d] cross_entropy=%g" % (epoch, sum_loss / num_batch))

correct_prediction = tf.equal(tf.argmax(fc1, 1), label)
accuracy = tf.reduce_mean(correct_prediction)
print(sess.run(accuracy, feed_dict={x: cifar.test.images, label: cifar.test.labels}))
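
The elementwise add in residual_factory only works because, with padding='SAME' and stride 1, the 3x3 convolutions preserve height and width; the two branches can therefore differ only in channel count, which is exactly what the 1x1 projection fixes. A shape-only sketch of that bookkeeping (conv_out_shape and residual_out_shape are hypothetical helpers written here for illustration, not tinyflow API):

def conv_out_shape(shape, out_filters):
    # 'SAME' padding, stride 1: height/width preserved, channels replaced
    n, c, h, w = shape
    return (n, out_filters, h, w)

def residual_out_shape(shape, in_filters, out_filters):
    branch = conv_out_shape(conv_out_shape(shape, out_filters), out_filters)
    if in_filters == out_filters:
        shortcut = shape                               # identity shortcut
    else:
        shortcut = conv_out_shape(shape, out_filters)  # 1x1 projection
    assert branch == shortcut  # elementwise add needs equal shapes
    return branch

print(residual_out_shape((100, 16, 32, 32), 16, 32))  # (100, 32, 32, 32)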

python/tinyflow/_base.py

Lines changed: 17 additions & 14 deletions
@@ -35,32 +35,35 @@
 # global list of all variable initializers
 _all_variable_inits = []
 
-def placeholder(dtype, shape=None, name=None):
-    v = symbol.placeholder(name=name, dtype=dtype)
-    return v
-
 
-def Variable(init, name=None):
-    if not isinstance(init, symbol.Symbol):
-        raise TypeError("Expect initialization expression to be Symbol")
+def Variable(init=None, name=None):
     name = NameManager.current.get(name, 'variable')
     v = symbol.Variable(name)
-    _all_variable_inits.append(symbol.assign(v, init))
+    if init is not None:
+        if not isinstance(init, symbol.Symbol):
+            raise TypeError("Expect initialization expression to be Symbol")
+        _all_variable_inits.append(symbol.assign(v, init))
     return v
 
 
-def group(*inputs):
-    x = _symbol_internal._nop()
-    x._add_control_deps(symbol.Group(inputs))
-    return x
-
-
 def initialize_all_variables():
     global _all_variable_inits
     init_op = group(*_all_variable_inits)
     _all_variable_inits = []
     return init_op
 
+
+def placeholder(dtype, name=None):
+    v = symbol.placeholder(name=name, dtype=dtype)
+    return v
+
+
+def group(*inputs):
+    x = _symbol_internal._nop()
+    x._add_control_deps(symbol.Group(inputs))
+    return x
+
+
 def gradients(ys, xs, grad_ys=None):
     if isinstance(ys, list):
         ys = symbol.Group(ys)
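
With this change, Variable no longer requires an initializer at construction: when init is omitted, no assign is queued on _all_variable_inits, so the weight can be initialized later, for example by the shape-inference loop in example/cifar_resnet.py. A minimal sketch of both call patterns, assuming Variable, normal, and initialize_all_variables are re-exported at the package top level as the example suggests:

import tinyflow as tf

# eager init: assign(v, init) is queued and later run by initialize_all_variables()
w = tf.Variable(tf.normal([10, 10], 0.01), name="w")

# deferred init: legal after this commit; nothing is queued for b here,
# so its initializer can be attached once its shape is known
b = tf.Variable(name="b")

init_op = tf.initialize_all_variables()  # groups only w's queued assign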

python/tinyflow/datasets.py

Lines changed: 63 additions & 0 deletions
@@ -2,6 +2,11 @@
 import numpy as np
 from collections import namedtuple
 from sklearn.datasets import fetch_mldata
+import sys
+import os
+from subprocess import call
+if sys.version_info < (3,):
+    import cPickle
+else:
+    import pickle as cPickle  # cPickle was removed in Python 3
+
 
 class ArrayPacker(object):
     """Dataset packer for iterator"""
@@ -41,3 +46,61 @@ def get_mnist(flatten=False, onehot=False):
     Y_test = Y[60000:]
     return MNISTData(train=ArrayPacker(X_train, Y_train),
                      test=ArrayPacker(X_test, Y_test))
+
+
+CIFAR10Data = namedtuple("CIFAR10Data", ["train", "test"])
+
+
+def load_batch(fpath, label_key='labels'):
+    f = open(fpath, 'rb')
+    if sys.version_info < (3,):
+        d = cPickle.load(f)
+    else:
+        d = cPickle.load(f, encoding="bytes")
+        # decode utf8 keys; iterate over a copy so the dict is not
+        # mutated while being traversed
+        for k, v in list(d.items()):
+            del d[k]
+            d[k.decode("utf8")] = v
+    f.close()
+    data = d["data"]
+    labels = d[label_key]
+
+    data = data.reshape(data.shape[0], 3, 32, 32).astype(np.float32)
+    labels = np.array(labels, dtype="float32")
+    return data, labels
+
+
+def get_cifar10(swap_axes=False):
+    path = "cifar-10-batches-py"
+    if not os.path.exists(path):
+        tar_file = "cifar-10-python.tar.gz"
+        origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+        # download the archive only if it is not already present, then unpack
+        if not os.path.exists(tar_file):
+            call(["wget", origin])
+        call(["tar", "-xvf", tar_file])
+
+    nb_train_samples = 50000
+
+    X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="float32")
+    y_train = np.zeros((nb_train_samples,), dtype="float32")
+
+    # five training batches of 10000 images each
+    for i in range(1, 6):
+        fpath = os.path.join(path, 'data_batch_' + str(i))
+        data, labels = load_batch(fpath)
+        X_train[(i - 1) * 10000: i * 10000, :, :, :] = data
+        y_train[(i - 1) * 10000: i * 10000] = labels
+
+    fpath = os.path.join(path, 'test_batch')
+    X_test, y_test = load_batch(fpath)
+
+    if swap_axes:
+        # move the channel axis last: (N, 3, 32, 32) -> (N, 32, 32, 3)
+        X_train = np.swapaxes(X_train, 1, 3)
+        X_test = np.swapaxes(X_test, 1, 3)
+
+    return CIFAR10Data(train=ArrayPacker(X_train, y_train),
+                       test=ArrayPacker(X_test, y_test))
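
A quick usage sketch (assumes the CIFAR-10 download and unpacking succeed; next_batch comes from ArrayPacker): the default layout is channel-first, and swap_axes=True moves channels last.

from tinyflow.datasets import get_cifar10

cifar = get_cifar10()                    # images as (N, 3, 32, 32) float32
batch_xs, batch_ys = cifar.train.next_batch(100)
print(batch_xs.shape, batch_ys.shape)    # expected: (100, 3, 32, 32) (100,)

cifar_hwc = get_cifar10(swap_axes=True)  # images as (N, 32, 32, 3)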

src/op_nn.cc

Lines changed: 72 additions & 0 deletions
@@ -178,6 +178,38 @@ NNVM_REGISTER_OP(linear)
 .set_attr<FInferShape>("FInferShape", LinearShape);
 
 
+struct PadParam : public dmlc::Parameter<PadParam> {
+  uint32_t dim;
+  int pad;
+
+  DMLC_DECLARE_PARAMETER(PadParam) {
+    DMLC_DECLARE_FIELD(dim).set_default(0);
+    DMLC_DECLARE_FIELD(pad).set_default(0);
+  }
+};
+DMLC_REGISTER_PARAMETER(PadParam);
+
+inline bool PadShape(const NodeAttrs& attrs,
+                     std::vector<TShape> *ishape,
+                     std::vector<TShape> *oshape) {
+  const auto& param = dmlc::get<PadParam>(attrs.parsed);
+  if (ishape->at(0).ndim() == 0) {
+    return false;
+  }
+  TShape out = ishape->at(0);
+  out[param.dim] += abs(param.pad);
+  oshape->at(0) = out;
+  return true;
+}
+
+NNVM_REGISTER_OP(pad)
+.describe("pads a tensor")
+.set_num_inputs(1)
+.include("nn_module")
+.set_attr_parser(ParamParser<PadParam>)
+.set_attr<FInferShape>("FInferShape", PadShape);
+
+
 struct ConvPoolParam : public dmlc::Parameter<ConvPoolParam> {
   TShape ksize;
   TShape strides;
@@ -263,6 +295,46 @@ NNVM_REGISTER_OP(max_pool)
 .set_attr<FInferShape>("FInferShape", ConvPoolShape);
 
 
+NNVM_REGISTER_OP(avg_pool)
+.describe("Avg pooling")
+.set_num_inputs(1)
+.set_attr_parser(ParamParser<ConvPoolParam>)
+.include("nn_module")
+.set_attr<FInferShape>("FInferShape", ConvPoolShape);
+
+
+struct BatchNormalizationParam : public dmlc::Parameter<BatchNormalizationParam> {
+  std::string name;
+  DMLC_DECLARE_PARAMETER(BatchNormalizationParam) {
+    DMLC_DECLARE_FIELD(name).set_default("batch_normalization");
+  }
+};
+DMLC_REGISTER_PARAMETER(BatchNormalizationParam);
+
+inline bool BatchNormalizationShape(const NodeAttrs& attrs,
+                                    std::vector<TShape> *ishape,
+                                    std::vector<TShape> *oshape) {
+  if (ishape->at(0).ndim() == 0) return false;
+  const TShape& in = ishape->at(0);
+  CHECK_EQ(in.ndim(), 4);
+  TShape mean = TShape{in[1]};
+  SHAPE_ASSIGN(ishape->at(1), mean);
+  SHAPE_ASSIGN(ishape->at(2), mean);
+  oshape->at(0) = in;
+  return true;
+}
+
+NNVM_REGISTER_OP(batch_normalization)
+.describe("batch normalization")
+.set_num_inputs(3)
+.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "gamma", "beta"};
+})
+.set_attr_parser(ParamParser<BatchNormalizationParam>)
+.include("nn_module")
+.set_attr<FInferShape>("FInferShape", BatchNormalizationShape);
+
+
 NNVM_REGISTER_OP(mean_sparse_softmax_cross_entropy_with_logits)
 .describe("Softmax cross entropy given logit and label")
 .set_num_inputs(2)
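
In plain terms, PadShape grows the padded dimension by abs(pad), and BatchNormalizationShape pins gamma and beta to one value per channel of a 4-D NCHW input while the output keeps the input shape. A rough Python rendering of both rules (a sketch for intuition, not the nnvm API; shapes are plain tuples here):

def pad_shape(in_shape, dim=0, pad=0):
    # mirror of PadShape: the chosen dimension grows by |pad|
    out = list(in_shape)
    out[dim] += abs(pad)
    return tuple(out)

def batch_normalization_shape(data_shape):
    # mirror of BatchNormalizationShape: data must be 4-D (NCHW);
    # gamma and beta are inferred as one value per channel (in[1])
    assert len(data_shape) == 4
    per_channel = (data_shape[1],)
    return data_shape, per_channel, per_channel  # output, gamma, beta

print(pad_shape((100, 16, 32, 32), dim=2, pad=-2))   # (100, 16, 34, 32)
print(batch_normalization_shape((100, 16, 32, 32)))  # ((100, 16, 32, 32), (16,), (16,))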

src/op_tensor.cc

Lines changed: 12 additions & 0 deletions
@@ -89,6 +89,18 @@ NNVM_REGISTER_OP(equal)
 .set_attr<FInferShape>("FInferShape", SameShape);
 
 
+NNVM_REGISTER_OP(__ewise_sum__)
+.describe("ewise sum")
+.set_num_inputs(nnvm::kVarg)
+.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
+.set_attr<FInferShape>("FInferShape", SameShape)
+.set_attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds) {
+      return std::vector<NodeEntry>(n->num_inputs(), ograds[0]);
+    });
+
+
 NNVM_REGISTER_OP(__add_symbol__)
 .describe("add two data together")
 .set_num_inputs(2)
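
The registered gradient simply fans the output gradient out to every input: since d(x1 + ... + xk)/dxi = 1, each input receives ograds[0] unchanged. A hedged NumPy check of that identity (a standalone sketch, independent of nnvm):

import numpy as np

def ewise_sum_grad(ograd, num_inputs):
    # mirror of the FGradient above: every input gets ograd unchanged
    return [ograd] * num_inputs

xs = [np.random.randn(2, 3) for _ in range(4)]
ograd = np.ones((2, 3))                 # gradient flowing into the sum's output
grads = ewise_sum_grad(ograd, len(xs))
assert all((g == ograd).all() for g in grads)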
