@@ -198,15 +198,23 @@ def func(x, b=1.0):
             ),
         )
         onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)
-        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 8.0)
-        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0))
+        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
+            tensor_x, model=func, b=8.0
+        )
+        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
+            func, func(tensor_x, 8.0)
+        )
         ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args)
         for ref_output, ort_output in zip(ref_outputs, ort_outputs):
             torch.testing.assert_close(ref_output, torch.tensor(ort_output))
 
         # test on different non-tensor input - xfail
-        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 9.0)
-        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0))
+        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
+            tensor_x, model=func, b=9.0
+        )
+        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
+            func, func(tensor_x, 9.0)
+        )
         _ = onnx_test_common.run_ort(onnx_program, onnx_format_args)
         for ref_output, ort_output in zip(ref_outputs, ort_outputs):
             torch.testing.assert_close(ref_output, torch.tensor(ort_output))
@@ -538,7 +546,7 @@ def forward(self, x):
             additional_test_inputs=[((x2,),)],
         )
 
-    @pytorch_test_common.xfail(
+    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
         "RuntimeError: at::functionalization::impl::isFunctionalTensor(self_) INTERNAL ASSERT FAILED "
         "at '/path/to/pytorch/torch/csrc/autograd/python_torch_functions_manual.cpp':514, please report a bug to PyTorch."
     )
@@ -831,10 +839,10 @@ def _test_fx_symbolic_tracer_large_scale_exporter(
         kwargs = create_pytorch_only_kwargs()
         # Original outputs.
         ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-            model(*args, **kwargs)
+            model, model(*args, **kwargs)
         )
         # ORT outputs.
-        args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args)
+        args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, model=model)
 
         # Drop Parameters and buffers added by fx_serialization.save_model_with_external_data
         args_not_none = args_not_none[: len(args) - len(kwargs)]
@@ -923,14 +931,24 @@ def create_pytorch_only_extra_kwargs():
 def _parameterized_class_attrs_and_values_with_fake_options():
     input_values = []
     input_values.extend(
-        itertools.product((True, False), (True, False), (True, False), (True, False))
+        itertools.product(
+            (True, False),
+            (True, False),
+            (True, False),
+            (True, False),
+            (
+                onnx_test_common.TorchModelType.TORCH_NN_MODULE,
+                onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+            ),
+        )
     )
     return {
         "attrs": [
             "op_level_debug",
             "dynamic_shapes",
             "load_checkpoint_during_init",
             "export_within_fake_mode",
+            "model_type",
         ],
         "input_values": input_values,
     }
@@ -950,6 +968,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     dynamic_shapes: bool
     load_checkpoint_during_init: bool
     export_within_fake_mode: bool
+    model_type: onnx_test_common.TorchModelType
 
     def setUp(self):
         super().setUp()
@@ -964,6 +983,7 @@ def _test_fake_tensor_mode_exporter(
         create_kwargs: Callable,
         load_checkpoint_during_init: bool,
         export_within_fake_mode: bool,
+        model_type: onnx_test_common.TorchModelType,
     ):
         """Test helper for FakeTensorMode-enabled exporter.
 
@@ -975,6 +995,8 @@ def _test_fake_tensor_mode_exporter(
             load_checkpoint_during_init: Whether to load a checkpoint during model initialization.
                 (after or during model creation, but before exporting starts)
             export_within_fake_mode: Whether to call torch.onnx._dynamo_export within torch._subclasses.FakeTensorMode
+            model_type: Type of user model. Used to determine whether the user model must be exported to
+                torch.export.ExportedProgram before passing it to torch.onnx.dynamo_export
 
         This test contains several steps.
 
@@ -990,13 +1012,17 @@ def _test_fake_tensor_mode_exporter(
 
         # Create the toy model with real weight.
         real_model = create_model()
+        state_dict = real_model.state_dict()  # concrete (non-fake) state_dict
+        if model_type == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
+            real_model = torch.export.export(
+                real_model, args=create_args(), kwargs=create_kwargs()
+            )
 
         with tempfile.NamedTemporaryFile(
             prefix=model_name, suffix=".pt"
         ) as tmp_checkpoint_file:
             # Dump state_dict to a file to simulate how HuggingFace model is initialized.
             # The file will be loaded via .load_state_dict(...)
-            state_dict = real_model.state_dict()
             torch.save(state_dict, tmp_checkpoint_file.name)
 
             with torch.onnx.enable_fake_mode() as fake_context:
@@ -1014,6 +1040,13 @@ def _test_fake_tensor_mode_exporter(
                 )
 
                 if export_within_fake_mode:
+                    if (
+                        model_type
+                        == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+                    ):
+                        fake_model = torch.export.export(
+                            fake_model, args=fake_args, kwargs=fake_kwargs
+                        )
                     onnx_program = torch.onnx.dynamo_export(
                         fake_model,
                         *fake_args,
@@ -1022,6 +1055,13 @@ def _test_fake_tensor_mode_exporter(
                     )
 
             if not export_within_fake_mode:
+                if (
+                    model_type
+                    == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+                ):
+                    fake_model = torch.export.export(
+                        fake_model, args=fake_args, kwargs=fake_kwargs
+                    )
                 onnx_program = torch.onnx.dynamo_export(
                     fake_model, *fake_args, **fake_kwargs, export_options=export_options
                 )
@@ -1038,10 +1078,12 @@ def _test_fake_tensor_mode_exporter(
                 kwargs = create_kwargs()
                 # Original outputs.
                 ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-                    real_model(*args, **kwargs)
+                    fake_model, real_model(*args, **kwargs)
                 )
                 # ORT outputs.
-                args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, **kwargs)
+                args_not_none = onnx_program.adapt_torch_inputs_to_onnx(
+                    *args, model=real_model, **kwargs
+                )
 
                 ort_outputs = onnx_test_common.run_ort(
                     tmp_onnx_file.name,
@@ -1053,6 +1095,10 @@ def _test_fake_tensor_mode_exporter(
                 for ref_output, ort_output in zip(ref_outputs, ort_outputs):
                     torch.testing.assert_close(ref_output, torch.tensor(ort_output))
 
+    @pytorch_test_common.skip_dynamic_fx_test(
+        "AssertionError: Dynamic shape check failed for graph inputs",
+        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+    )
     def test_fake_tensor_mode_simple(self):
         def create_model() -> nn.Module:
             class Model(torch.nn.Module):
@@ -1079,14 +1125,19 @@ def create_kwargs():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
-    @pytorch_test_common.xfail(
+    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
         "[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
         "has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
         "https://github.com/huggingface/transformers/issues/21013"
         "This can be addressed by using GPT2Config, but it is not now supported by FakeTensor exporting."
     )
+    @pytorch_test_common.skip_dynamic_fx_test(
+        "AssertionError: Dynamic shape check failed for graph inputs",
+        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+    )
     def test_large_scale_exporter_with_tiny_gpt2(self):
         model_name = "sshleifer/tiny-gpt2"
         device = "cpu"
@@ -1111,8 +1162,13 @@ def create_kwargs():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
+    @pytorch_test_common.skip_dynamic_fx_test(
+        "AssertionError: Dynamic shape check failed for graph inputs",
+        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+    )
     def test_large_scale_exporter_with_toy_mlp(self):
         class MLPModel(nn.Module):
             def __init__(self):
@@ -1148,8 +1204,13 @@ def create_kwargs():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
+    @pytorch_test_common.skip_dynamic_fx_test(
+        "AssertionError: Dynamic shape check failed for graph inputs",
+        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+    )
     def test_fake_tensor_mode_huggingface_google_t5(self):
         config = transformers.T5Config(
             vocab_size=8096, d_model=64, num_layers=2, num_heads=2
@@ -1179,8 +1240,13 @@ def create_model():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
+    @pytorch_test_common.skip_dynamic_fx_test(
+        "AssertionError: Dynamic shape check failed for graph inputs",
+        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+    )
     def test_fake_tensor_mode_huggingface_openai_whisper(self):
         config = transformers.WhisperConfig(
             vocab_size=8096,
@@ -1231,6 +1297,7 @@ def create_kwargs():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
     @pytorch_test_common.xfail(
@@ -1260,6 +1327,7 @@ def create_model():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
     @pytorch_test_common.skip_dynamic_fx_test(
@@ -1287,6 +1355,7 @@ def create_model():
             create_kwargs,
             load_checkpoint_during_init=self.load_checkpoint_during_init,
             export_within_fake_mode=self.export_within_fake_mode,
+            model_type=self.model_type,
         )
 
 
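Usage note (not part of the diff above): the patch changes the ONNXProgram adapters so that the originating model is passed explicitly, via a `model=` keyword on `adapt_torch_inputs_to_onnx` and a leading model argument on `adapt_torch_outputs_to_onnx`. A minimal sketch of the updated call pattern follows, assuming `torch.onnx.dynamo_export` is available; the toy callable's body and the input shape are illustrative, only its signature comes from the first hunk.

# Minimal sketch of the post-change adapter API; names and body below are illustrative assumptions.
import torch

def func(x, b=1.0):
    # Toy callable matching the signature shown in the first hunk; the body is an assumption.
    return x + b

tensor_x = torch.rand(2, 3)
onnx_program = torch.onnx.dynamo_export(func, tensor_x, 8.0)

# The adapters now take the originating model explicitly.
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, model=func, b=8.0)
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func, func(tensor_x, 8.0))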