Description
Motivation
`any_symbolic_tensors()` is called on pretty much all ops, and it internally uses `tree.flatten()` to check whether any positional or keyword argument to the op is a `KerasTensor` (i.e. a symbolic tensor). torchdynamo skips tracing `tree` (presumably because it has C bindings), therefore causing a graph break at every op. This results in poor jitted performance for the PyTorch backend: since the graph breaks occur between each op, we lose opportunities for any significant compiler optimizations (e.g. operator fusion).
See #18569 for more details.
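
For reference, here is a minimal sketch of how such per-op breaks show up, assuming PyTorch >= 2.1 where `torch._dynamo.explain(fn)(*args)` reports graph and break counts; the explicit `torch._dynamo.graph_break()` call stands in for the untraceable `tree.flatten()`:

```python
import torch

def op(x):
    # Stand-in for a Keras op: the explicit graph_break() simulates
    # dynamo skipping the C-bound tree.flatten() call inside
    # any_symbolic_tensors().
    torch._dynamo.graph_break()
    return torch.relu(x) + 1.0

def model(x):
    # Three chained "ops" compile into separate graphs instead of one,
    # so the compiler cannot fuse the relu/add work across op boundaries.
    return op(op(op(x)))

explanation = torch._dynamo.explain(model)(torch.randn(8))
print(f"graphs: {explanation.graph_count}, breaks: {explanation.graph_break_count}")
```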
Proposal
NOTE: making `any_symbolic_tensors()` dynamo-traceable does not guarantee everything in Keras will be dynamo compatible. Once we fix this, other issues may arise.
1. Provide a dynamo-traceable pure-Python version of `tree.flatten()` and use that instead of `tree.flatten()` to prevent graph breaks at `any_symbolic_tensors()` (see the sketch after this list).
2. If 1) is not enough, that is, if we still observe graph breaks (albeit not as frequent) due to other usages of `tree.*`, then (as suggested by @fchollet in #18569) we need to create a `keras.utils.tree` module that uses pure-Python implementations when in a dynamo context, and replace usages of `tree.*` with `keras.utils.tree.*`.
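
For illustration, a minimal sketch of what 1) could look like; this is a hypothetical helper, not the actual implementation, and only handles lists, tuples, and dicts. Because it is plain Python recursion with no C bindings, dynamo can inline it instead of breaking the graph:

```python
def flatten(structure):
    """Flatten a nested structure of lists/tuples/dicts into a list of leaves."""
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    if isinstance(structure, dict):
        # Sort keys for a deterministic leaf order, matching the
        # convention of tree/nest-style libraries.
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    return [structure]

# any_symbolic_tensors() could then check for symbolic tensors without
# triggering a graph break, e.g.:
#   any(isinstance(x, KerasTensor) for x in flatten((args, kwargs)))
```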
My suggestion is to do 1) first, then see whether 2) is needed, as 2) is a bigger change that we may not actually need.