(torch) Use pure-python implementation of tree when in dynamo context #18614

Open
@kiukchung

Motivation

any_symbolic_tensors is called on nearly every op, and it internally uses tree.flatten() to check whether any positional or keyword argument to the op is a KerasTensor (i.e., a symbolic tensor). torchdynamo skips tracing tree (presumably because it has C bindings), which causes a graph break at each op. This results in poor jitted performance for the pytorch backend: because the graph breaks occur between ops, we lose opportunities for any significant compiler optimizations (e.g. operator fusion).

See #18569 for more details.

Proposal

NOTE: making any_symbolic_tensors() dynamo traceable does not guarantee that everything in keras will be dynamo compatible. Once we fix this, other issues may arise.

  1. Provide a dynamo traceable pure-python version of tree.flatten() and use that instead of tree.flatten() to prevent graph breaks at any_symbolic_tensors().
  2. If 1) is not enough, that is, if we still observe graph breaks (albeit less frequent) due to other usages of tree.*, then (as suggested by @fchollet in #18569, "(re)enable torch.compile in the pytorch trainer for train, predict, and eval") we need to create a keras.utils.tree that uses pure-python implementations when in dynamo context and replace usages of tree.* with keras.utils.tree.*.

My suggestion is to first do 1), then see whether 2) is needed, since 2) is a bigger change that we may not actually need.
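
To make 1) concrete, here is a rough sketch of what a pure-python flatten and the check built on it might look like. This is not the actual keras code: the `KerasTensor` class below is a hypothetical stand-in, the helper only handles dicts, lists, and tuples (no namedtuples or other container types), and sorting dict keys is my assumption about how to mirror tree.flatten() ordering. Whether dynamo traces this without graph breaks would still need to be verified.

```python
class KerasTensor:
    """Hypothetical stand-in for keras' symbolic tensor class."""

    def __init__(self, shape):
        self.shape = shape


def flatten(structure):
    """Return the leaves of nested dicts/lists/tuples as a flat list.

    Pure Python only, so there is no C extension for the tracer to skip.
    """
    if isinstance(structure, dict):
        leaves = []
        # Assumption: visit keys in sorted order for a deterministic result.
        for key in sorted(structure):
            leaves.extend(flatten(structure[key]))
        return leaves
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    # Anything else is treated as a leaf.
    return [structure]


def any_symbolic_tensors(args=None, kwargs=None):
    """Return True if any leaf in args or kwargs is a KerasTensor."""
    args = args or ()
    kwargs = kwargs or {}
    for leaf in flatten((args, kwargs)):
        if isinstance(leaf, KerasTensor):
            return True
    return False
```

For example, `any_symbolic_tensors((1, [KerasTensor((2,))]), {})` returns True, while `any_symbolic_tensors((1, 2), {"x": 3})` returns False. If we end up going with 2), the same helper could sit behind a keras.utils.tree wrapper that dispatches to it only in a dynamo context (recent PyTorch releases expose torch.compiler.is_compiling() for that kind of check, though I have not tested it here).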
