@@ -109,25 +109,13 @@ NNVM_REGISTER_OP(softmax)
109109.describe(" Softmax operation" )
110110.set_num_inputs(1 )
111111.include(" nn_module" )
112- .set_attr<FLuaCreateNNModule>(
113- " FLuaCreateNNModule" , R"(
114- function(ishape, kwarg)
115- return nn.SoftMax()
116- end
117- )" )
118112.set_attr<FInferShape>(" FInferShape" , SameShape);
119113
120114
121115NNVM_REGISTER_OP (relu)
122116.describe(" Relu operation" )
123117.set_num_inputs(1 )
124118.include(" nn_module" )
125- .set_attr<FLuaCreateNNModule>(
126- " FLuaCreateNNModule" , R"(
127- function(ishape, kwarg)
128- return nn.ReLU()
129- end
130- )" )
131119.set_attr<FInferShape>(" FInferShape" , SameShape)
132120.set_attr<bool >(" TBackwardNeedOutputs" , true );
133121
@@ -136,12 +124,6 @@ NNVM_REGISTER_OP(tanh)
136124.describe(" Tanh operation" )
137125.set_num_inputs(1 )
138126.include(" nn_module" )
139- .set_attr<FLuaCreateNNModule>(
140- " FLuaCreateNNModule" , R"(
141- function(ishape, kwarg)
142- return nn.Tanh()
143- end
144- )" )
145127.set_attr<FInferShape>(" FInferShape" , SameShape);
146128
147129
@@ -193,17 +175,6 @@ NNVM_REGISTER_OP(linear)
193175 }
194176 })
195177.include(" nn_module" )
196- .set_attr<FLuaCreateNNModule>(
197- " FLuaCreateNNModule" , R"(
198- function(ishape, kwarg)
199- local wshape = ishape[2]
200- local m = nn.Linear(wshape[2], wshape[1])
201- if #ishape == 2 then
202- m = m:noBias()
203- end
204- return m
205- end
206- )" )
207178.set_attr<FInferShape>(" FInferShape" , LinearShape);
208179
209180
@@ -273,37 +244,6 @@ NNVM_REGISTER_OP(conv2d)
273244 })
274245.set_attr_parser(ParamParser<ConvPoolParam>)
275246.include(" nn_module" )
276- .set_attr<FLuaCreateNNModule>(
277- " FLuaCreateNNModule" , R"(
278- function(ishape, kwarg)
279- local dshape = ishape[2]
280- local fshape = ishape[2]
281- local outPlane = fshape[1]
282- local inPlane = fshape[2]
283- local kH = fshape[3]
284- local kW = fshape[4]
285- local inH = dshape[3]
286- local inW = dshape[4]
287- local stride = nn_parse_tuple(kwarg.strides, {1,1,1,1})
288- local dH = stride[2]
289- local dW = stride[3]
290- local padH = 0
291- local padW = 0
292-
293- assert(kwarg.data_format == 'NCHW')
294- if kwarg.padding == 'SAME' then
295- padW = math.floor((kW - 1) / 2)
296- padH = math.floor((kH - 1) / 2)
297- end
298- local m = nn.SpatialConvolution(
299- inPlane, outPlane,
300- kW, kH, dW, dH, padW, padH)
301- if #ishape == 2 then
302- m = m:noBias()
303- end
304- return m
305- end
306- )" )
307247.set_attr<FListInputNames>(" FListInputNames" , [](const NodeAttrs& attrs) {
308248 if (dmlc::get<ConvPoolParam>(attrs.parsed ).no_bias ) {
309249 return std::vector<std::string>{" data" , " weight" };
@@ -320,54 +260,18 @@ NNVM_REGISTER_OP(max_pool)
320260.set_num_inputs(1 )
321261.set_attr_parser(ParamParser<ConvPoolParam>)
322262.include(" nn_module" )
323- .set_attr<FLuaCreateNNModule>(
324- " FLuaCreateNNModule" , R"(
325- function(ishape, kwarg)
326- local ksize = nn_parse_tuple(kwarg.ksize)
327- local stride = nn_parse_tuple(kwarg.strides, {1,1,1,1})
328- local kH = ksize[2]
329- local kW = ksize[3]
330- local dH = stride[2]
331- local dW = stride[3]
332- local padH = 0
333- local padW = 0
334- assert(kwarg.data_format == 'NCHW')
335- if kwarg.padding == 'SAME' then
336- padW = math.floor((kW - 1) / 2)
337- padH = math.floor((kH - 1) / 2)
338- end
339- return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
340- end
341- )" )
342263.set_attr<FInferShape>(" FInferShape" , ConvPoolShape);
343264
344265
345266NNVM_REGISTER_OP (mean_sparse_softmax_cross_entropy_with_logits)
346267.describe(" Softmax cross entropy given logit and label" )
347268.set_num_inputs(2 )
348- .include(" nn_criterion" )
349- .set_attr<FLuaCreateNNModule>(
350- " FLuaCreateNNModule" , R"(
351- function(ishape, kwarg)
352- return nn_zero_index_target_criterion(
353- nn.CrossEntropyCriterion())
354- end
355- )" );
356-
357- const char * LuaReshape = R"(
358- function(x, y, kwarg)
359- if x[1]:storage() == y[1]:storage() then
360- return function() end
361- else
362- return function() y[1]:copy(x[1]:resizeAs(y[1])) end
363- end
364- end
365- )" ;
269+ .include(" nn_criterion" );
270+
366271
367272NNVM_REGISTER_OP (flatten_layer)
368273.describe(" Flatten to 2D" )
369274.set_num_inputs(1 )
370- .set_attr<FLuaCompute>(" FLuaCompute" , LuaReshape)
371275.set_attr<FInplaceOption>(" FInplaceOption" , InplaceIn0Out0)
372276.set_attr<FInferShape>(
373277 " FInferShape" , [](const NodeAttrs& attrs,
@@ -388,7 +292,6 @@ NNVM_REGISTER_OP(flatten_layer)
388292
389293NNVM_REGISTER_OP (_flatten_backward)
390294.set_num_inputs(1 )
391- .set_attr<FLuaCompute>(" FLuaCompute" , LuaReshape)
392295.set_attr<FInplaceOption>(" FInplaceOption" , InplaceIn0Out0)
393296.set_attr<FBackwardOutToInIndex>(
394297 " FBackwardOutToInIndex" , [](const NodeAttrs& attrs) {
0 commit comments