|
9 | 9 | from multiprocessing import Process, Queue |
10 | 10 | from caffe2.proto import caffe2_pb2 |
11 | 11 | from caffe2.python import core, cnn, data_parallel_model, dyndep, optimizer, \ |
12 | | - rnn_cell, workspace |
| 12 | + rnn_cell, workspace, model_helper, brew |
13 | 13 | from caffe2.python.test_util import TestCase |
14 | 14 | from future.utils import viewkeys |
15 | 15 |
|
@@ -208,6 +208,59 @@ def add_optimizer(model): |
208 | 208 | self.assertFalse(core.BlobReference("cpu_1/data") in checkpoint_params) |
209 | 209 | self.assertTrue(core.BlobReference("optimizer_iteration") in checkpoint_params) |
210 | 210 |
|
| 211 | + def test_net_conversion_and_append_net(self): |
| 212 | + other = model_helper.ModelHelper() |
| 213 | + fc1 = brew.fc(other, "data", "other_fc1", dim_in=3*227*227, dim_out=10) |
| 214 | + fc2 = brew.fc(other, fc1, "other_fc2", dim_in=10, dim_out=10) |
| 215 | + brew.fc(other, fc2, "other_fc3", dim_in=10, dim_out=10) |
| 216 | + |
| 217 | + def add_input_ops(model): |
| 218 | + model.net.UniformFill([], ["data"], shape=[4, 227, 227, 3]) |
| 219 | + model.net.UniformFill([], ["label"], shape=[4]) |
| 220 | + |
| 221 | + def add_model_ops(model, loss_scale): |
| 222 | + model.NHWC2NCHW("data", "data_nchw") |
| 223 | + model.Conv("data_nchw", 'conv1', 3, 64, |
| 224 | + weight_init=("MSRAFill", {}), kernel=7, |
| 225 | + stride=2, pad=3, no_bias=0) |
| 226 | + model.SpatialBN('conv1', 'conv1_spatbn_relu', 64, epsilon=1e-3) |
| 227 | + model.Relu('conv1_spatbn_relu', 'conv1_spatbn_relu') |
| 228 | + model.MaxPool('conv1_spatbn_relu', 'pool1', kernel=3, stride=2) |
| 229 | + model.FC('pool1', 'fc', dim_in=(64 * 56 * 56), dim_out=10) |
| 230 | + |
| 231 | + # Append the net and param_init_net of the other model |
| 232 | + appendnet = data_parallel_model.ConvertNetForDevice(other.net) |
| 233 | + model.net.AppendNet(appendnet) |
| 234 | + |
| 235 | + model.param_init_net.AppendNet( |
| 236 | + data_parallel_model.ConvertNetForDevice(other.param_init_net)) |
| 237 | + |
| 238 | + model.Sigmoid('fc', 'fc_sigm') |
| 239 | + model.Softmax('fc_sigm', 'softmax') |
| 240 | + loss = model.AveragedLoss('softmax', 'loss') |
| 241 | + return [loss] |
| 242 | + |
| 243 | + def add_optimizer(model): |
| 244 | + optimizer.build_sgd(model, 0.1, policy="fixed", momentum=0.9) |
| 245 | + |
| 246 | + model = cnn.CNNModelHelper( |
| 247 | + order="NCHW", |
| 248 | + name="test", |
| 249 | + ) |
| 250 | + data_parallel_model.Parallelize_CPU( |
| 251 | + model, |
| 252 | + input_builder_fun=add_input_ops, |
| 253 | + forward_pass_builder_fun=add_model_ops, |
| 254 | + optimizer_builder_fun=add_optimizer, |
| 255 | + devices=range(4) |
| 256 | + ) |
| 257 | + |
| 258 | + # Just create and run net and confirm no exception is thrown |
| 259 | + workspace.RunNetOnce(model.param_init_net) |
| 260 | + workspace.CreateNet(model.net) |
| 261 | + workspace.RunNet(model.net) |
| 262 | + |
| 263 | + |
211 | 264 | def test_synchronization_barrier(self): |
212 | 265 |
|
213 | 266 | def run(comm_rank, comm_size, tmpdir): |
|
0 commit comments