@@ -124,16 +124,20 @@ def export(args, model=None, save_dir=None, use_ema=False):
124
124
with open (yml_file , 'w' ) as file :
125
125
yaml .dump (deploy_info , file )
126
126
127
+ paddle_version = version .parse (paddle .__version__ )
127
128
if cfg .dic .get ('export_with_pir' , False ):
128
- paddle_version = version .parse (paddle .__version__ )
129
129
assert (paddle_version >= version .parse ('3.0.0b2' )
130
130
or paddle_version == version .parse ('0.0.0' )) and os .environ .get (
131
131
"FLAGS_enable_pir_api" , None ) not in ["0" , "False" ]
132
132
paddle .jit .save (model , inference_model_path )
133
133
else :
134
- model .forward .rollback ()
135
- with paddle .pir_utils .OldIrGuard ():
136
- model = paddle .jit .to_static (model , input_spec = input_spec )
134
+ if paddle_version >= version .parse (
135
+ '3.0.0b2' ) or paddle_version == version .parse ('0.0.0' ):
136
+ model .forward .rollback ()
137
+ with paddle .pir_utils .OldIrGuard ():
138
+ model = paddle .jit .to_static (model , input_spec = input_spec )
139
+ paddle .jit .save (model , inference_model_path )
140
+ else :
137
141
paddle .jit .save (model , inference_model_path )
138
142
139
143
logger .info (f'The inference model is saved in { save_dir } ' )
0 commit comments