@@ -81,6 +81,24 @@ def get_test_images(infer_dir, infer_img):
81
81
return images
82
82
83
83
84
+ def save_infer_model (FLAGS , exe , feed_vars , test_fetches , infer_prog ):
85
+ cfg_name = os .path .basename (FLAGS .config ).split ('.' )[0 ]
86
+ save_dir = os .path .join (FLAGS .output_dir , cfg_name )
87
+ feeded_var_names = [var .name for var in feed_vars .values ()]
88
+ # im_id is only used for visualize, not used in inference model
89
+ feeded_var_names .remove ('im_id' )
90
+ target_vars = test_fetches .values ()
91
+ logger .info ("Save inference model to {}, input: {}, output: "
92
+ "{}..." .format (save_dir , feeded_var_names ,
93
+ [var .name for var in target_vars ]))
94
+ fluid .io .save_inference_model (save_dir ,
95
+ feeded_var_names = feeded_var_names ,
96
+ target_vars = target_vars ,
97
+ executor = exe ,
98
+ main_program = infer_prog ,
99
+ params_filename = "__parmas__" )
100
+
101
+
84
102
def main ():
85
103
cfg = load_config (FLAGS .config )
86
104
@@ -119,6 +137,9 @@ def main():
119
137
if cfg .weights :
120
138
checkpoint .load_checkpoint (exe , infer_prog , cfg .weights )
121
139
140
+ if FLAGS .save_inference_model :
141
+ save_infer_model (FLAGS , exe , feed_vars , test_fetches , infer_prog )
142
+
122
143
# parse infer fetches
123
144
extra_keys = []
124
145
if cfg ['metric' ] == 'COCO' :
@@ -196,5 +217,10 @@ def main():
196
217
type = float ,
197
218
default = 0.5 ,
198
219
help = "Threshold to reserve the result for visualization." )
220
+ parser .add_argument (
221
+ "--save_inference_model" ,
222
+ action = 'store_true' ,
223
+ default = False ,
224
+ help = "Save inference model in output_dir if True." )
199
225
FLAGS = parser .parse_args ()
200
226
main ()
0 commit comments