Skip to content

Commit b78f1e3

Browse files
committed
make etrecord support export program
Differential Revision: [D77965102](https://our.internmc.facebook.com/intern/diff/D77965102/) ghstack-source-id: 294973706 Pull Request resolved: #12288
1 parent ed9c4de commit b78f1e3

File tree

5 files changed

+120
-16
lines changed

5 files changed

+120
-16
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class StrEnum(str, Enum):
4545

4646
class ETRecordReservedFileNames(StrEnum):
4747
ETRECORD_IDENTIFIER = "ETRECORD_V0"
48+
EXPORTED_PROGRAM = "exported_program"
4849
EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
4950
ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
5051
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
@@ -55,6 +56,7 @@ class ETRecordReservedFileNames(StrEnum):
5556

5657
@dataclass
5758
class ETRecord:
59+
exported_program: Optional[ExportedProgram] = None
5860
edge_dialect_program: Optional[ExportedProgram] = None
5961
graph_map: Optional[Dict[str, ExportedProgram]] = None
6062
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
@@ -71,17 +73,20 @@ def _handle_exported_program(
7173
assert isinstance(ep, ExportedProgram)
7274
serialized_artifact = serialize(ep)
7375
assert isinstance(serialized_artifact.exported_program, bytes)
76+
77+
method_name = f"/{method_name}" if method_name != "" else ""
78+
7479
etrecord_zip.writestr(
75-
f"{module_name}/{method_name}", serialized_artifact.exported_program
80+
f"{module_name}{method_name}", serialized_artifact.exported_program
7681
)
7782
etrecord_zip.writestr(
78-
f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict
83+
f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict
7984
)
8085
etrecord_zip.writestr(
81-
f"{module_name}/{method_name}_constants", serialized_artifact.constants
86+
f"{module_name}{method_name}_constants", serialized_artifact.constants
8287
)
8388
etrecord_zip.writestr(
84-
f"{module_name}/{method_name}_example_inputs",
89+
f"{module_name}{method_name}_example_inputs",
8590
serialized_artifact.example_inputs,
8691
)
8792

@@ -188,7 +193,10 @@ def generate_etrecord(
188193
ExecutorchProgramManager,
189194
BundledProgram,
190195
],
191-
export_modules: Optional[
196+
exported_program: Optional[
197+
Union[ExportedProgram, Dict[str, ExportedProgram]]
198+
] = None,
199+
extra_recorded_export_modules: Optional[
192200
Dict[
193201
str,
194202
Union[
@@ -202,7 +210,7 @@ def generate_etrecord(
202210
"""
203211
Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.
204212
The objects that will be serialized to an `ETRecord` are all the graph modules present
205-
in the `export_modules` dict, the graph module present in the edge dialect program object,
213+
in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object,
206214
and also the graph module present in the ExecuTorch program object, which
207215
is the closest graph module representation of what is eventually run on the device.
208216
In addition to all the graph modules, we also serialize the program buffer, which the users
@@ -213,7 +221,8 @@ def generate_etrecord(
213221
et_record: Path to where the `ETRecord` file will be saved to.
214222
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
215223
executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
216-
export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
224+
exported_program: Optional graph module for this model returned by the call to `torch.export` from nn.Module.
225+
extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
217226
value being the corresponding exported module. The exported graph modules can be either the
218227
output of `torch.export()` or `exir.to_edge()`.
219228
@@ -229,15 +238,28 @@ def generate_etrecord(
229238
# is an etrecord when it's used later in the Developer Tools.
230239
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")
231240

232-
if export_modules is not None:
233-
for module_name, export_module in export_modules.items():
241+
if exported_program is not None:
242+
# If multiple exported programs are provided, only saved forward method
243+
if isinstance(exported_program, dict) and "forward" in exported_program:
244+
exported_program = exported_program["forward"]
245+
246+
if isinstance(exported_program, ExportedProgram):
247+
_handle_exported_program(
248+
etrecord_zip,
249+
ETRecordReservedFileNames.EXPORTED_PROGRAM,
250+
"",
251+
exported_program,
252+
)
253+
254+
if extra_recorded_export_modules is not None:
255+
for module_name, export_module in extra_recorded_export_modules.items():
234256
contains_reserved_name = any(
235257
reserved_name in module_name
236258
for reserved_name in ETRecordReservedFileNames
237259
)
238260
if contains_reserved_name:
239261
raise RuntimeError(
240-
f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
262+
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
241263
)
242264
_handle_export_module(etrecord_zip, export_module, module_name)
243265

@@ -318,6 +340,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
318340
graph_map: Dict[str, ExportedProgram] = {}
319341
debug_handle_map = None
320342
delegate_map = None
343+
exported_program = None
321344
edge_dialect_program = None
322345
reference_outputs = None
323346
representative_inputs = None
@@ -347,6 +370,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
347370
etrecord_zip.read(f"{entry}_example_inputs"),
348371
)
349372
edge_dialect_program = deserialize(serialized_artifact)
373+
elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
374+
serialized_artifact = SerializedArtifact(
375+
etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
376+
etrecord_zip.read(f"{entry}_state_dict"),
377+
etrecord_zip.read(f"{entry}_constants"),
378+
etrecord_zip.read(f"{entry}_example_inputs"),
379+
)
380+
exported_program = deserialize(serialized_artifact)
350381
elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
351382
# @lint-ignore PYTHONPICKLEISBAD
352383
reference_outputs = pickle.loads(
@@ -383,6 +414,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
383414
graph_map[serialized_file] = deserialize(serialized_artifact)
384415

385416
return ETRecord(
417+
exported_program=exported_program,
386418
edge_dialect_program=edge_dialect_program,
387419
graph_map=graph_map,
388420
_debug_handle_map=debug_handle_map,

devtools/etrecord/tests/etrecord_test.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ def test_etrecord_generation(self):
100100
tmpdirname + "/etrecord.bin",
101101
edge_output,
102102
et_output,
103-
{
103+
extra_recorded_export_modules={
104104
"aten_dialect_output": captured_output,
105105
},
106106
)
107107

108108
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
109+
109110
self.check_graph_closeness(
110111
etrecord.graph_map["aten_dialect_output/forward"],
111112
captured_output.exported_program.graph_module,
@@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self):
184185
tmpdirname + "/etrecord.bin",
185186
edge_output,
186187
et_output,
187-
{"fail_test_case": et_output},
188+
extra_recorded_export_modules={"fail_test_case": et_output},
188189
)
189190

190191
def test_etrecord_reserved_name(self):
@@ -196,5 +197,76 @@ def test_etrecord_reserved_name(self):
196197
tmpdirname + "/etrecord.bin",
197198
edge_output,
198199
et_output,
199-
{reserved_name: captured_output.exported_program.graph_module},
200+
extra_recorded_export_modules={
201+
reserved_name: captured_output.exported_program.graph_module
202+
},
200203
)
204+
205+
def test_etrecord_generation_with_exported_program(self):
206+
"""Test that exported program can be recorded and parsed back correctly."""
207+
captured_output, edge_output, et_output = self.get_test_model()
208+
original_exported_program = captured_output.exported_program
209+
210+
with tempfile.TemporaryDirectory() as tmpdirname:
211+
# Generate ETRecord with exported program
212+
generate_etrecord(
213+
tmpdirname + "/etrecord.bin",
214+
edge_output,
215+
et_output,
216+
exported_program=original_exported_program,
217+
)
218+
219+
# Parse ETRecord back
220+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
221+
222+
# Validate that the parsed exported program matches the original
223+
self.assertIsNotNone(etrecord.exported_program)
224+
self.check_graph_closeness(
225+
etrecord.exported_program,
226+
original_exported_program.graph_module,
227+
)
228+
229+
# Validate other components are still present
230+
self.check_graph_closeness(
231+
etrecord.edge_dialect_program,
232+
edge_output.exported_program.graph_module,
233+
)
234+
self.assertEqual(
235+
etrecord._debug_handle_map,
236+
json.loads(json.dumps(et_output.debug_handle_map)),
237+
)
238+
239+
def test_etrecord_generation_with_exported_program_dict(self):
240+
"""Test that exported program dictionary can be recorded and parsed back correctly."""
241+
captured_output, edge_output, et_output = self.get_test_model()
242+
original_exported_program = captured_output.exported_program
243+
exported_program_dict = {"forward": original_exported_program}
244+
245+
with tempfile.TemporaryDirectory() as tmpdirname:
246+
# Generate ETRecord with exported program dictionary
247+
generate_etrecord(
248+
tmpdirname + "/etrecord.bin",
249+
edge_output,
250+
et_output,
251+
exported_program=exported_program_dict,
252+
)
253+
254+
# Parse ETRecord back
255+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
256+
257+
# Validate that the parsed exported program matches the original
258+
self.assertIsNotNone(etrecord.exported_program)
259+
self.check_graph_closeness(
260+
etrecord.exported_program,
261+
original_exported_program.graph_module,
262+
)
263+
264+
# Validate other components are still present
265+
self.check_graph_closeness(
266+
etrecord.edge_dialect_program,
267+
edge_output.exported_program.graph_module,
268+
)
269+
self.assertEqual(
270+
etrecord._debug_handle_map,
271+
json.loads(json.dumps(et_output.debug_handle_map)),
272+
)

devtools/inspector/tests/inspector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self):
327327
tmpdirname + "/etrecord.bin",
328328
edge_output,
329329
et_output,
330-
{
330+
extra_recorded_export_modules={
331331
"aten_dialect_output": captured_output,
332332
},
333333
)

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_gen_graphs_from_etrecord(self):
5252
tmpdirname + "/etrecord.bin",
5353
edge_output,
5454
et_output,
55-
{
55+
extra_recorded_export_modules={
5656
"aten_dialect_output": captured_output,
5757
},
5858
)

examples/devtools/scripts/gen_sample_etrecord.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None):
4141
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
4242
edge_dialect_program=edge_program,
4343
executorch_program=et_program,
44-
export_modules={
44+
extra_recorded_export_modules={
4545
"aten_dialect_output": aten_dialect,
4646
},
4747
)

0 commit comments

Comments
 (0)