Skip to content

Commit 738b4a5

Browse files
author
Thiago Crepaldi
authored
Update ONNX's IO Adapter to support FakeTensor with ExportedProgram (#114407) (#115578)
Currently, the ONNX exporter using torch.nn.Module as input can support FakeTensor because the ONNX model stores all initializers When using torch.export.ExportedProgram as input, the initializers are lifted as inputs. In order to execute the ONNX model, we need to pass a reference to the non-fake model to the ONNXProgram.adapt_torch_inputs_to_onnx API, so that initializers can be fetched from the model and fed to the ONNX model as input ps: #115461 will track the API revision for the cases where additional `model_with_state_dict` are required to produce complete ONNX files exported with fake support. This is also tracked by the umbrella fake tensor issue #105464 FYI @BowenBao Pull Request resolved: #114407 Approved by: https://github.com/BowenBao
1 parent 4cf10bf commit 738b4a5

10 files changed

+231
-67
lines changed

test/onnx/onnx_test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,9 @@ def _compare_pytorch_onnx_with_ort(
439439
# ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
440440
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
441441
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
442-
ort_outputs = onnx_program(*input_args, **input_kwargs)
442+
ort_outputs = onnx_program(*input_args, model=ref_model, **input_kwargs)
443443
ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
444-
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)
444+
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_model, ref_outputs)
445445

446446
if len(ref_outputs) != len(ort_outputs):
447447
raise AssertionError(

test/onnx/pytorch_test_common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,13 @@ def wrapper(self, *args, **kwargs):
188188
return skip_dec
189189

190190

191-
def skip_dynamic_fx_test(reason: str):
191+
def skip_dynamic_fx_test(reason: str, skip_model_type=None):
192192
"""Skip dynamic exporting test.
193193
194194
Args:
195195
reason: The reason for skipping dynamic exporting test.
196+
skip_model_type (onnx_test_common.TorchModelType): The model type to skip dynamic exporting test for.
197+
When None, model type is not used to skip dynamic tests.
196198
197199
Returns:
198200
A decorator for skipping dynamic exporting test.
@@ -201,7 +203,9 @@ def skip_dynamic_fx_test(reason: str):
201203
def skip_dec(func):
202204
@functools.wraps(func)
203205
def wrapper(self, *args, **kwargs):
204-
if self.dynamic_shapes:
206+
if self.dynamic_shapes and (
207+
not skip_model_type or self.model_type == skip_model_type
208+
):
205209
raise unittest.SkipTest(
206210
f"Skip verify dynamic shapes test for FX. {reason}"
207211
)

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,23 @@ def func(x, b=1.0):
198198
),
199199
)
200200
onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)
201-
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 8.0)
202-
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0))
201+
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
202+
tensor_x, model=func, b=8.0
203+
)
204+
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
205+
func, func(tensor_x, 8.0)
206+
)
203207
ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args)
204208
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
205209
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
206210

207211
# test on different non-tensor input - xfail
208-
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, 9.0)
209-
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0))
212+
onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
213+
tensor_x, model=func, b=9.0
214+
)
215+
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
216+
func, func(tensor_x, 9.0)
217+
)
210218
_ = onnx_test_common.run_ort(onnx_program, onnx_format_args)
211219
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
212220
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
@@ -538,7 +546,7 @@ def forward(self, x):
538546
additional_test_inputs=[((x2,),)],
539547
)
540548

541-
@pytorch_test_common.xfail(
549+
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
542550
"RuntimeError: at::functionalization::impl::isFunctionalTensor(self_) INTERNAL ASSERT FAILED "
543551
"at '/path/to/pytorch/torch/csrc/autograd/python_torch_functions_manual.cpp':514, please report a bug to PyTorch."
544552
)
@@ -831,10 +839,10 @@ def _test_fx_symbolic_tracer_large_scale_exporter(
831839
kwargs = create_pytorch_only_kwargs()
832840
# Original outputs.
833841
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
834-
model(*args, **kwargs)
842+
model, model(*args, **kwargs)
835843
)
836844
# ORT outputs.
837-
args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args)
845+
args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, model=model)
838846

839847
# Drop Parameters and buffers added by fx_serialization.save_model_with_external_data
840848
args_not_none = args_not_none[: len(args) - len(kwargs)]
@@ -923,14 +931,24 @@ def create_pytorch_only_extra_kwargs():
923931
def _parameterized_class_attrs_and_values_with_fake_options():
924932
input_values = []
925933
input_values.extend(
926-
itertools.product((True, False), (True, False), (True, False), (True, False))
934+
itertools.product(
935+
(True, False),
936+
(True, False),
937+
(True, False),
938+
(True, False),
939+
(
940+
onnx_test_common.TorchModelType.TORCH_NN_MODULE,
941+
onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
942+
),
943+
)
927944
)
928945
return {
929946
"attrs": [
930947
"op_level_debug",
931948
"dynamic_shapes",
932949
"load_checkpoint_during_init",
933950
"export_within_fake_mode",
951+
"model_type",
934952
],
935953
"input_values": input_values,
936954
}
@@ -950,6 +968,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
950968
dynamic_shapes: bool
951969
load_checkpoint_during_init: bool
952970
export_within_fake_mode: bool
971+
model_type: onnx_test_common.TorchModelType
953972

954973
def setUp(self):
955974
super().setUp()
@@ -964,6 +983,7 @@ def _test_fake_tensor_mode_exporter(
964983
create_kwargs: Callable,
965984
load_checkpoint_during_init: bool,
966985
export_within_fake_mode: bool,
986+
model_type: onnx_test_common.TorchModelType,
967987
):
968988
"""Test helper for FakeTensorMode-enabled exporter.
969989
@@ -975,6 +995,8 @@ def _test_fake_tensor_mode_exporter(
975995
load_checkpoint_during_init: Whether to load a checkpoint during model initialization.
976996
(after or during model creation, but before exporting starts)
977997
export_within_fake_mode: Whether to call torch.onnx._dynamo_export within torch._subclasses.FakeTensorMode
998+
model_type: Type of user model. Used to determine whether the user model must be exported to
999+
torch.export.ExportedProgram before passing it to torch.onnx.dynamo_export
9781000
9791001
This test contains several steps.
9801002
@@ -990,13 +1012,17 @@ def _test_fake_tensor_mode_exporter(
9901012

9911013
# Create the toy model with real weight.
9921014
real_model = create_model()
1015+
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
1016+
if model_type == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
1017+
real_model = torch.export.export(
1018+
real_model, args=create_args(), kwargs=create_kwargs()
1019+
)
9931020

9941021
with tempfile.NamedTemporaryFile(
9951022
prefix=model_name, suffix=".pt"
9961023
) as tmp_checkpoint_file:
9971024
# Dump state_dict to a file to simulate how HuggingFace model is initialized.
9981025
# The file will be loaded via .load_state_dict(...)
999-
state_dict = real_model.state_dict()
10001026
torch.save(state_dict, tmp_checkpoint_file.name)
10011027

10021028
with torch.onnx.enable_fake_mode() as fake_context:
@@ -1014,6 +1040,13 @@ def _test_fake_tensor_mode_exporter(
10141040
)
10151041

10161042
if export_within_fake_mode:
1043+
if (
1044+
model_type
1045+
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
1046+
):
1047+
fake_model = torch.export.export(
1048+
fake_model, args=fake_args, kwargs=fake_kwargs
1049+
)
10171050
onnx_program = torch.onnx.dynamo_export(
10181051
fake_model,
10191052
*fake_args,
@@ -1022,6 +1055,13 @@ def _test_fake_tensor_mode_exporter(
10221055
)
10231056

10241057
if not export_within_fake_mode:
1058+
if (
1059+
model_type
1060+
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
1061+
):
1062+
fake_model = torch.export.export(
1063+
fake_model, args=fake_args, kwargs=fake_kwargs
1064+
)
10251065
onnx_program = torch.onnx.dynamo_export(
10261066
fake_model, *fake_args, **fake_kwargs, export_options=export_options
10271067
)
@@ -1038,10 +1078,12 @@ def _test_fake_tensor_mode_exporter(
10381078
kwargs = create_kwargs()
10391079
# Original outputs.
10401080
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
1041-
real_model(*args, **kwargs)
1081+
fake_model, real_model(*args, **kwargs)
10421082
)
10431083
# ORT outputs.
1044-
args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, **kwargs)
1084+
args_not_none = onnx_program.adapt_torch_inputs_to_onnx(
1085+
*args, model=real_model, **kwargs
1086+
)
10451087

10461088
ort_outputs = onnx_test_common.run_ort(
10471089
tmp_onnx_file.name,
@@ -1053,6 +1095,10 @@ def _test_fake_tensor_mode_exporter(
10531095
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
10541096
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
10551097

1098+
@pytorch_test_common.skip_dynamic_fx_test(
1099+
"AssertionError: Dynamic shape check failed for graph inputs",
1100+
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1101+
)
10561102
def test_fake_tensor_mode_simple(self):
10571103
def create_model() -> nn.Module:
10581104
class Model(torch.nn.Module):
@@ -1079,14 +1125,19 @@ def create_kwargs():
10791125
create_kwargs,
10801126
load_checkpoint_during_init=self.load_checkpoint_during_init,
10811127
export_within_fake_mode=self.export_within_fake_mode,
1128+
model_type=self.model_type,
10821129
)
10831130

1084-
@pytorch_test_common.xfail(
1131+
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
10851132
"[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
10861133
"has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
10871134
"https://github.com/huggingface/transformers/issues/21013"
10881135
"This can be addressed by using GPT2Config, but it is not now supported by FakeTensor exporting."
10891136
)
1137+
@pytorch_test_common.skip_dynamic_fx_test(
1138+
"AssertionError: Dynamic shape check failed for graph inputs",
1139+
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1140+
)
10901141
def test_large_scale_exporter_with_tiny_gpt2(self):
10911142
model_name = "sshleifer/tiny-gpt2"
10921143
device = "cpu"
@@ -1111,8 +1162,13 @@ def create_kwargs():
11111162
create_kwargs,
11121163
load_checkpoint_during_init=self.load_checkpoint_during_init,
11131164
export_within_fake_mode=self.export_within_fake_mode,
1165+
model_type=self.model_type,
11141166
)
11151167

1168+
@pytorch_test_common.skip_dynamic_fx_test(
1169+
"AssertionError: Dynamic shape check failed for graph inputs",
1170+
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1171+
)
11161172
def test_large_scale_exporter_with_toy_mlp(self):
11171173
class MLPModel(nn.Module):
11181174
def __init__(self):
@@ -1148,8 +1204,13 @@ def create_kwargs():
11481204
create_kwargs,
11491205
load_checkpoint_during_init=self.load_checkpoint_during_init,
11501206
export_within_fake_mode=self.export_within_fake_mode,
1207+
model_type=self.model_type,
11511208
)
11521209

1210+
@pytorch_test_common.skip_dynamic_fx_test(
1211+
"AssertionError: Dynamic shape check failed for graph inputs",
1212+
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1213+
)
11531214
def test_fake_tensor_mode_huggingface_google_t5(self):
11541215
config = transformers.T5Config(
11551216
vocab_size=8096, d_model=64, num_layers=2, num_heads=2
@@ -1179,8 +1240,13 @@ def create_model():
11791240
create_kwargs,
11801241
load_checkpoint_during_init=self.load_checkpoint_during_init,
11811242
export_within_fake_mode=self.export_within_fake_mode,
1243+
model_type=self.model_type,
11821244
)
11831245

1246+
@pytorch_test_common.skip_dynamic_fx_test(
1247+
"AssertionError: Dynamic shape check failed for graph inputs",
1248+
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1249+
)
11841250
def test_fake_tensor_mode_huggingface_openai_whisper(self):
11851251
config = transformers.WhisperConfig(
11861252
vocab_size=8096,
@@ -1231,6 +1297,7 @@ def create_kwargs():
12311297
create_kwargs,
12321298
load_checkpoint_during_init=self.load_checkpoint_during_init,
12331299
export_within_fake_mode=self.export_within_fake_mode,
1300+
model_type=self.model_type,
12341301
)
12351302

12361303
@pytorch_test_common.xfail(
@@ -1260,6 +1327,7 @@ def create_model():
12601327
create_kwargs,
12611328
load_checkpoint_during_init=self.load_checkpoint_during_init,
12621329
export_within_fake_mode=self.export_within_fake_mode,
1330+
model_type=self.model_type,
12631331
)
12641332

12651333
@pytorch_test_common.skip_dynamic_fx_test(
@@ -1287,6 +1355,7 @@ def create_model():
12871355
create_kwargs,
12881356
load_checkpoint_during_init=self.load_checkpoint_during_init,
12891357
export_within_fake_mode=self.export_within_fake_mode,
1358+
model_type=self.model_type,
12901359
)
12911360

12921361

test/onnx/torch_export/test_torch_export_with_onnxruntime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ def _compare_onnx_and_torch_exported_program(
3131
# NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
3232
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
3333
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
34-
onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
34+
onnx_outputs = onnx_exported_program(
35+
*input_args, model=torch_exported_program, **input_kwargs
36+
)
3537
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
3638
torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
37-
torch_outputs
39+
torch_exported_program, torch_outputs
3840
)
3941
if len(torch_outputs_onnx_format) != len(onnx_outputs):
4042
raise AssertionError(

0 commit comments

Comments
 (0)