Commit b04c904
update hpi config (#3835)
1 parent: d459390

File tree: 1 file changed (+22, -14)

paddleseg/core/export.py

Lines changed: 22 additions & 14 deletions
@@ -28,6 +28,7 @@ def export(args, model=None, save_dir=None, use_ema=False):
     assert args.config is not None, \
         'No configuration file specified, please set --config'
     cfg = Config(args.config)
+    use_fd_inference = True
     if not model:
         # save model
         builder = SegBuilder(cfg)
@@ -38,6 +39,7 @@ def export(args, model=None, save_dir=None, use_ema=False):
             logger.info('Loaded trained params successfully.')
         if args.output_op != 'none':
             model = WrappedModel(model, args.output_op)
+            use_fd_inference = False
         utils.show_env_info()
         utils.show_cfg_info(cfg)
     else:
@@ -62,7 +64,7 @@ def export(args, model=None, save_dir=None, use_ema=False):
     save_name = 'model'
     yaml_name = 'deploy.yaml'
 
-    if uniform_output_enabled:
+    if uniform_output_enabled and use_fd_inference == True:
         inference_model_path = os.path.join(save_dir, "inference", save_name)
         yml_file = os.path.join(save_dir, "inference", yaml_name)
         if use_ema:
@@ -95,21 +97,27 @@ def export(args, model=None, save_dir=None, use_ema=False):
     if cfg.dic.get("pdx_model_name", None):
         deploy_info["Global"] = {}
         deploy_info["Global"]["model_name"] = cfg.dic["pdx_model_name"]
-    if cfg.dic.get("hpi_config_path", None):
-        with open(cfg.dic["hpi_config_path"], "r") as fp:
-            hpi_config = yaml.load(fp, Loader=yaml.SafeLoader)
-        if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
-            hpi_config["Hpi"]["supported_backends"]["gpu"].remove(
-                "paddle_tensorrt")
-            del hpi_config['Hpi']['backend_config']['paddle_tensorrt']
-        if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
-            hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
-            del hpi_config['Hpi']['backend_config']['tensorrt']
-        hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
-        deploy_info["Hpi"] = hpi_config["Hpi"]
+    if cfg.dic.get("uniform_output_enabled", False):
+        dynamic_shapes = {
+            'x': [[1, 3, 128, 256], [1, 3, 512, 1024], [1, 3, 1024, 2048]]
+        }
+        supported_batch_size = [1, 100]
+
+        backend_keys = ['paddle_infer', 'tensorrt']
+        hpi_config = {
+            "backend_configs": {
+                key: {
+                    "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes":
+                    dynamic_shapes
+                }
+                for key in backend_keys
+            }
+        }
+        deploy_info["Hpi"] = hpi_config
     msg = '\n---------------Deploy Information---------------\n'
     msg += str(yaml.dump(deploy_info))
-    logger.info(msg)
+    if use_fd_inference == False:
+        logger.info(msg)
 
     with open(yml_file, 'w') as file:
         yaml.dump(deploy_info, file)
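For reference, the new hpi_config is built with a dict comprehension over backend_keys. The following is a minimal standalone sketch (all names copied from the diff; the script itself is not part of the repository) that evaluates the same expression and prints the resulting "Hpi" block as YAML, roughly as it would appear in deploy.yaml:

    import yaml

    # Shape ranges copied from the diff: three candidate shapes for input 'x'.
    dynamic_shapes = {
        'x': [[1, 3, 128, 256], [1, 3, 512, 1024], [1, 3, 1024, 2048]]
    }

    backend_keys = ['paddle_infer', 'tensorrt']
    hpi_config = {
        "backend_configs": {
            # The comprehension gives 'tensorrt' the key "dynamic_shapes" and
            # 'paddle_infer' the key "trt_dynamic_shapes"; both map to the
            # same shape ranges.
            key: {
                "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes":
                dynamic_shapes
            }
            for key in backend_keys
        }
    }

    print(yaml.dump({"Hpi": hpi_config}))

Running it prints a nested mapping with Hpi.backend_configs.paddle_infer.trt_dynamic_shapes and Hpi.backend_configs.tensorrt.dynamic_shapes, each listing the same three shapes for 'x'.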
