Skip to content

Commit 16c6fae

Browse files
add version control for export (#3877)
1 parent 605f764 commit 16c6fae

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

paddleseg/core/export.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,20 @@ def export(args, model=None, save_dir=None, use_ema=False):
124124
with open(yml_file, 'w') as file:
125125
yaml.dump(deploy_info, file)
126126

127+
paddle_version = version.parse(paddle.__version__)
127128
if cfg.dic.get('export_with_pir', False):
128-
paddle_version = version.parse(paddle.__version__)
129129
assert (paddle_version >= version.parse('3.0.0b2')
130130
or paddle_version == version.parse('0.0.0')) and os.environ.get(
131131
"FLAGS_enable_pir_api", None) not in ["0", "False"]
132132
paddle.jit.save(model, inference_model_path)
133133
else:
134-
model.forward.rollback()
135-
with paddle.pir_utils.OldIrGuard():
136-
model = paddle.jit.to_static(model, input_spec=input_spec)
134+
if paddle_version >= version.parse(
135+
'3.0.0b2') or paddle_version == version.parse('0.0.0'):
136+
model.forward.rollback()
137+
with paddle.pir_utils.OldIrGuard():
138+
model = paddle.jit.to_static(model, input_spec=input_spec)
139+
paddle.jit.save(model, inference_model_path)
140+
else:
137141
paddle.jit.save(model, inference_model_path)
138142

139143
logger.info(f'The inference model is saved in {save_dir}')

0 commit comments

Comments
 (0)