Skip to content

Commit 258bc64

Browse files
authored
add save_inference_model in infer.py (PaddlePaddle#2714)
* add save_inference_model in infer.py * format code * add comment * add save_inference_model doc * refine doc
1 parent 7225e14 commit 258bc64

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

PaddleCV/PaddleDetection/docs/GETTING_STARTED.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_dir=demo
7272
The visualization files are saved in `output` by default, to specify a different
7373
path, simply add a `--save_file=` flag.
7474

75+
- Save inference model
76+
77+
```bash
78+
export CUDA_VISIBLE_DEVICES=0
79+
# or run on CPU with:
80+
# export CPU_NUM=1
81+
python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_img=demo/000000570688.jpg \
82+
--save_inference_model
83+
```
84+
85+
Save the inference model by passing the `--save_inference_model` flag.
86+
7587

7688
## FAQ
7789

PaddleCV/PaddleDetection/tools/infer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ def get_test_images(infer_dir, infer_img):
8181
return images
8282

8383

84+
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
    """Serialize the inference program and its parameters to disk.

    The model is written to ``FLAGS.output_dir/<config_name>`` so that each
    config gets its own directory.

    Args:
        FLAGS: parsed command-line arguments; reads ``FLAGS.config`` and
            ``FLAGS.output_dir``.
        exe: fluid Executor holding the trained parameters.
        feed_vars (dict): name -> Variable mapping of model inputs.
        test_fetches (dict): name -> Variable mapping of model outputs.
        infer_prog: the inference Program to save.
    """
    # Name the save directory after the config file stem,
    # e.g. configs/faster_rcnn_r50_1x.yml -> faster_rcnn_r50_1x.
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    feeded_var_names = [var.name for var in feed_vars.values()]
    # im_id is only used for visualization; the inference model does not need it.
    feeded_var_names.remove('im_id')
    # Materialize the dict view as a list: it is iterated once for logging
    # and then handed to save_inference_model.
    target_vars = list(test_fetches.values())
    logger.info("Save inference model to {}, input: {}, output: "
                "{}...".format(save_dir, feeded_var_names,
                               [var.name for var in target_vars]))
    fluid.io.save_inference_model(save_dir,
                                  feeded_var_names=feeded_var_names,
                                  target_vars=target_vars,
                                  executor=exe,
                                  main_program=infer_prog,
                                  # bug fix: was misspelled "__parmas__",
                                  # which saved parameters under a filename
                                  # loaders expecting "__params__" cannot find
                                  params_filename="__params__")
100+
101+
84102
def main():
85103
cfg = load_config(FLAGS.config)
86104

@@ -119,6 +137,9 @@ def main():
119137
if cfg.weights:
120138
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
121139

140+
if FLAGS.save_inference_model:
141+
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
142+
122143
# parse infer fetches
123144
extra_keys = []
124145
if cfg['metric'] == 'COCO':
@@ -196,5 +217,10 @@ def main():
196217
type=float,
197218
default=0.5,
198219
help="Threshold to reserve the result for visualization.")
220+
parser.add_argument(
221+
"--save_inference_model",
222+
action='store_true',
223+
default=False,
224+
help="Save inference model in output_dir if True.")
199225
FLAGS = parser.parse_args()
200226
main()

0 commit comments

Comments
 (0)