Skip to content

Commit 29c30b1

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Fix serialize nn_module_stack (#104721)
Summary: Some serialized nn_module_stacks contain nested commas, something like: `(getitem(L['module'],0),torch.nn.modules.linear.Linear)` Fixing the parsing so that we can deserialize the string in the format of: `(local identifier, module type)` Test Plan: CI Differential Revision: D47252881 Pull Request resolved: #104721 Approved by: https://github.com/zhxchen17
1 parent 6a3d5f1 commit 29c30b1

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

torch/_export/serde/serialize.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,21 @@ def deserialize_meta_func(serialized_target: str):
236236
key = kv[:key_idx]
237237
assert kv[key_idx + 1] == "("
238238
assert kv[-1] == ")"
239-
values = kv[key_idx + 2: -1].split(",")
240-
assert len(values) == 2
241-
module = deserialize_meta_func(values[1])
242-
nn_module_stack[key] = (values[0], module)
239+
240+
paren = 0
241+
comma_idx = None
242+
for i, c in enumerate(kv[key_idx + 2:-1]):
243+
if c == "," and paren == 0:
244+
assert comma_idx is None
245+
comma_idx = i + key_idx + 2
246+
elif c == "(":
247+
paren += 1
248+
elif c == ")":
249+
paren -= 1
250+
251+
assert comma_idx is not None
252+
module = deserialize_meta_func(kv[comma_idx + 1:-1])
253+
nn_module_stack[key] = (kv[key_idx + 2:comma_idx], module)
243254
ret["nn_module_stack"] = nn_module_stack
244255

245256
if source_fn_str := metadata.get("source_fn"):

0 commit comments

Comments
 (0)