Skip to content

[Feature&Fix] Add sympy_to_func #505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 30, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refine sym_to_func.py
  • Loading branch information
HydrogenSulfate committed Aug 26, 2023
commit 05cf6bb31067c12c813ac63c2ce61544f9ebb827
48 changes: 37 additions & 11 deletions ppsci/utils/sym_to_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,15 @@ class LayerNode(Node):
model (nn.Layer): NN model for computing forward result in this node.
"""

def __init__(self, expr: sp.core.function.UndefinedFunction, model: nn.Layer):
def __init__(
self,
expr: sp.core.function.UndefinedFunction,
model: nn.Layer,
detach_keys: Optional[Tuple[str, ...]] = None,
):
super().__init__(expr)
self.model = model
self.detach_keys = detach_keys

def forward(self, data_dict: Dict):
# use cache
Expand All @@ -206,6 +212,11 @@ def forward(self, data_dict: Dict):
output_dict = self.model(data_dict)
data_dict.update(output_dict)

# detach Tensor(s) if specified
if self.detach_keys:
for key in self.detach_keys:
data_dict[key] = data_dict[key].detach()

return data_dict


Expand Down Expand Up @@ -292,6 +303,7 @@ def _post_traverse(cur_node: sp.Basic, nodes: List[sp.Basic]) -> List[sp.Basic]:
def sympy_to_function(
expr: sp.Expr,
models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None,
detach_keys: Tuple[str, ...] = None,
) -> ComposedNode:
"""Convert sympy expression to callable function.

Expand Down Expand Up @@ -343,10 +355,11 @@ def sympy_to_function(
True
"""

# NOTE: Those simplify methods seem complicate given expr instead, so not use them here
# simplify expression to reduce nodes in tree
expr = sp.nsimplify(expr)
expr = sp.expand(expr)
expr = sp.simplify(expr)
# expr = sp.nsimplify(expr)
# expr = sp.expand(expr)
# expr = sp.simplify(expr)

# convert sympy expression tree to list of nodes in postorder
sympy_nodes = []
Expand All @@ -358,21 +371,34 @@ def sympy_to_function(
# remove duplicates with topo-order kept
sympy_nodes = list(dict.fromkeys(sympy_nodes))

# convert sympy node to callable node
if not isinstance(models, (tuple, list)):
models = (models,)
if detach_keys is None:
detach_keys = ()

# convert sympy node to callable node
callable_nodes = []
for i, node in enumerate(sympy_nodes):
if isinstance(node.func, sp.core.function.UndefinedFunction):
match = False
for model in models:
match_index = None
for j, model in enumerate(models):
if str(node.func.name) in model.output_keys:
callable_nodes.append(LayerNode(node, model))
if match:
callable_nodes.append(
LayerNode(
node,
model,
tuple(
key for key in detach_keys if key in model.output_keys
),
)
)
if match_index is not None:
raise ValueError(
f"Function {node} can match at least 2 output key of models, which is forbidden."
f"Name of function({node}) should be unique along given models,"
f" but got same output_key({node.func.name}) in models[{match_index}]"
f" and models[{j}]."
)
match = True
match_index = j
elif (
isinstance(node, tuple(PADDLE_FUNC_MAP.keys()))
or node.is_Add
Expand Down