Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 65 additions & 38 deletions model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Dict, NamedTuple, Tuple

import jax
from jax import tree_util as jax_tree_util
from orbax.experimental.model import core as obm
from orbax.experimental.model.tf2obm import tf_concrete_function_handle_pb2
Expand All @@ -30,7 +31,6 @@
)
TF_CONCRETE_FUNCTION_HANDLE_VERSION = '0.0.1'

_INPUT_NAME_PREFIX = 'input'
_OUTPUT_NAME_PREFIX = 'output'


Expand All @@ -46,9 +46,6 @@ def is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
)


_NamesAndSequence = Tuple[Sequence[str], Sequence[Any]]


def tf_concrete_function_name_to_obm_function(
name: str,
*,
Expand Down Expand Up @@ -88,10 +85,8 @@ def tf_concrete_function_name_to_obm_function(
input_signature = utils.get_input_signature(fn)
output_signature = utils.get_output_signature(fn)

input_names, _, _ = _get_flat_signature(input_signature, _INPUT_NAME_PREFIX)
output_names, _, _ = _get_flat_signature(
output_signature, _OUTPUT_NAME_PREFIX
)
input_names, _, _ = _flat_input_signature(fn)
output_names, _, _ = _flat_output_signature(fn)
unstructured_data = obm.manifest_pb2.UnstructuredData(
inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(
fn_name=name,
Expand Down Expand Up @@ -253,33 +248,69 @@ class SignatureFlat(NamedTuple):
tree_def: jax_tree_util.PyTreeDef


# We choose to rely solely on a concrete function's TF signature to
# determine its argument names, not using any other information (such
# as the argument names in the original Python `def`, or the `name`
# field in `TensorSpec`). Currently in TF SavedModel, if a concrete
# function's TF signature is a list, SavedModel may use the argument
# names in the original Python `def` to generate a keyword-based
# version of this function (which is needed for Servomatic which only
# supports keyword-based calling conventions). We think relying on
# this SavedModel behavior is a mistake and the user should make the
# TF signature a dict instead if they want to serve the function on
# Servomatic. If we find that there are too many users relying on this
# SavedModel behavior, we can revisit the decision here.
def _get_flat_signature(
    signature: utils.TfSignature, name_prefix: str
) -> SignatureFlat:
  """Gets the flattened signature.

  Args:
    signature: The TF signature.
    name_prefix: The prefix for generating names.

  Returns:
    A SignatureFlat object `(names, leaves, treedef)`.
  """
  leaves, tree_def = jax_tree_util.tree_flatten(signature)
  names = tuple(f'{name_prefix}_{i}' for i in range(len(leaves)))
  return SignatureFlat(names, leaves, tree_def)


def _flat_input_signature(
    fn: tf.types.experimental.ConcreteFunction,
) -> SignatureFlat:
  """Returns the flattened input signature of the given function.

  Args:
    fn: The TF concrete function whose input signature is flattened.

  Returns:
    A SignatureFlat object `(names, leaves, treedef)`. The names are the first
    `len(leaves)` input argument names recorded in the function's FunctionDef.

  Raises:
    ValueError: If the FunctionDef lists fewer input arguments than there are
      leaves in the flattened input signature.
  """
  leaves, tree_def = jax_tree_util.tree_flatten(utils.get_input_signature(fn))
  # The argument names in SavedModel's SignatureDef may not match the names in
  # the input signature due to internal name mangling, hence we're looking
  # it up in the FunctionDef.
  input_names = [arg.name for arg in fn.function_def.signature.input_arg]
  # There could be more arguments in the FunctionDef than in the input
  # signature, because it also contains the captured inputs appended to the
  # flattened list of the input arguments. Only having fewer is an error.
  if len(input_names) < len(leaves):
    raise ValueError(
        f'The number of input arguments in FunctionDef ({len(input_names)}) is'
        ' smaller than the number of leaves in the flattened input signature'
        f' ({len(leaves)})'
    )
  # Drop the trailing names, which belong to the captured inputs.
  return SignatureFlat(input_names[: len(leaves)], leaves, tree_def)
def _output_name(path: Sequence[Any]) -> str:
  """Returns the output name based on its path in the output signature.

  Args:
    path: The key path of one leaf of the output signature, as produced by
      flattening the signature with paths. Empty when the output is a single
      tensor.

  Returns:
    `f'{_OUTPUT_NAME_PREFIX}_0'` for an empty path,
    `f'{_OUTPUT_NAME_PREFIX}_{idx}'` for a sequence element, the stringified
    key for a dict entry, or the attribute name for an attribute entry.

  Raises:
    ValueError: If the path is nested more than one level deep, or if its only
      key is not a sequence, dict or attribute key.
  """
  if not path:
    # Scalar return value (single tensor).
    return f'{_OUTPUT_NAME_PREFIX}_0'

  # More than one level of nesting is not compatible with the TF output
  # format, so only single-key paths can be named.
  if len(path) == 1:
    key = path[0]
    if isinstance(key, jax_tree_util.SequenceKey):
      return f'{_OUTPUT_NAME_PREFIX}_{key.idx}'
    if isinstance(key, jax_tree_util.DictKey):
      # The order is stable as guaranteed by `jax.tree.flatten`.
      return str(key.key)
    if isinstance(key, jax_tree_util.GetAttrKey):
      return str(key.name)

  # The error is built only on failure and reports the offending path so the
  # caller can locate the unsupported output.
  raise ValueError(
      'Invalid output format. TF function output must be a single tensor or'
      ' a list of tensors, or a dict of tensors with string keys.'
      f' Got output path: {tuple(path)!r}.'
  )


def _flat_output_signature(
    fn: tf.types.experimental.ConcreteFunction,
) -> SignatureFlat:
  """Returns the flattened output signature of the given function.

  Args:
    fn: The TF concrete function whose output signature is flattened.

  Returns:
    A SignatureFlat object `(names, leaves, treedef)`, where each name is
    derived from the leaf's key path by `_output_name`. Names and leaves are
    empty when the output signature has no leaves.

  Raises:
    ValueError: Propagated from `_output_name` when a leaf's path does not
      match the supported output formats.
  """
  # Here we leverage the TF requirement that outputs of the function
  # used as the signature must be a single tensor, or a list of tensors,
  # or dict of tensors with string keys (potentially a flattened dataclass).
  leaves_with_path, tree_def = jax.tree.flatten_with_path(
      utils.get_output_signature(fn)
  )
  # Split the (path, leaf) pairs without `zip(*...)`, which cannot be unpacked
  # into two values when there are no leaves at all.
  output_names = [_output_name(path) for path, _ in leaves_with_path]
  leaves = tuple(leaf for _, leaf in leaves_with_path)
  return SignatureFlat(output_names, leaves, tree_def)


def to_keyword_only_fn(
Expand All @@ -293,12 +324,8 @@ def to_keyword_only_fn(
Returns:
The wrapped function (also a TF concrete function).
"""
input_names, input_leaves, input_def = _get_flat_signature(
utils.get_input_signature(f), _INPUT_NAME_PREFIX
)
output_names, _, _ = _get_flat_signature(
utils.get_output_signature(f), _OUTPUT_NAME_PREFIX
)
input_names, input_leaves, input_def = _flat_input_signature(f)
output_names, _, _ = _flat_output_signature(f)

if input_names is None and output_names is None:
return f
Expand Down
Loading
Loading