def lambdify(
    expr: Union[sp.Basic, List[sp.Basic]],
    models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None,
    extra_parameters: Optional[Sequence[paddle.Tensor]] = None,
    graph_filename: Optional[str] = None,
    create_graph: bool = True,
    retain_graph: Optional[bool] = None,
    fuse_derivative: bool = False,
) -> Union[ComposedNode, List[ComposedNode]]:
    """Convert sympy expression to callable function.

    Args:
        expr (Union[sp.Basic, List[sp.Basic]]): Sympy expression(s) to be converted.
            Will return callable functions in list if multiple expressions are given,
            else return one single callable function.
        models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for
            computing forward result in `LayerNode`. An `arch.ModelList` is unpacked
            into its sub-models. If None, no model is used, so every function node
            other than 'sdf' (and the detach function) raises ValueError.
            Defaults to None.
        extra_parameters (Optional[Sequence[paddle.Tensor]]): Extra learnable
            parameters, matched to sympy symbols by their `name`. Defaults to None.
        graph_filename (Optional[str]): If given as a string, such as 'momentum_x',
            it is used as a directory for saving the computational graph(s): the
            graph of a single `expr` goes to `os.path.join(graph_filename, 'expr')`,
            the i-th graph of a list of expressions goes to
            `os.path.join(graph_filename, f'expr_{i}')`. Defaults to None.
        create_graph (bool, optional): Whether to create the gradient graphs of
            the computing process. When it is True, higher order derivatives are
            supported to compute. When it is False, the gradient graphs of the
            computing process would be discarded. Defaults to True.
        retain_graph (Optional[bool]): Whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, in which way users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Defaults to None,
            which means it is equal to `create_graph`.
        fuse_derivative (bool, optional): Whether to fuse the derivative nodes.
            For example, if `expr` is 'Derivative(u, x) + Derivative(u, y)',
            it will compute grad(u, x) + grad(u, y) if fuse_derivative=False,
            else it will compute sum(grad(u, [x, y])) if fuse_derivative=True, which
            is more efficient in the backward graph. Defaults to False, as it is
            experimental and therefore not enabled by default when used independently.

    Returns:
        Union[ComposedNode, List[ComposedNode]]: Callable object(s) for computing expr
            with necessary input(s) data in dict given.

    Raises:
        ValueError: If a function's name matches the output keys of more than one
            model, or matches no model at all (except the 'sdf' function).
        NotImplementedError: If the expression contains an unsupported node type.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> import sympy as sp
        >>> a, b, c, x, y = sp.symbols("a b c x y")
        >>> u = sp.Function("u")(x, y)
        >>> v = sp.Function("v")(x, y)
        >>> z = -a + b * (c ** 2) + u * v + 2.3
        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 4, 16)
        >>> batch_size = 13
        >>> a_tensor = paddle.randn([batch_size, 1])
        >>> b_tensor = paddle.randn([batch_size, 1])
        >>> c_tensor = paddle.randn([batch_size, 1])
        >>> x_tensor = paddle.randn([batch_size, 1])
        >>> y_tensor = paddle.randn([batch_size, 1])
        >>> model_output_dict = model({"x": x_tensor, "y": y_tensor})
        >>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"]
        >>> z_tensor_manually = (
        ...     -a_tensor + b_tensor * (c_tensor ** 2)
        ...     + u_tensor * v_tensor + 2.3
        ... )
        >>> z_tensor_sympy = ppsci.lambdify(z, model)(
        ...     {
        ...         "a": a_tensor,
        ...         "b": b_tensor,
        ...         "c": c_tensor,
        ...         "x": x_tensor,
        ...         "y": y_tensor,
        ...     }
        ... )
        >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item()
        True
    """
    if not extra_parameters:
        extra_parameters = ()

    if isinstance(models, arch.ModelList):
        models = tuple(models.model_list[i] for i in range(len(models.model_list)))
    elif models is None:
        # No model given: use an empty tuple rather than (None,), otherwise
        # `model.output_keys` below would raise AttributeError on None.
        models = ()
    elif not isinstance(models, (tuple, list)):
        models = (models,)

    # Loop invariants shared by every expression/node conversion.
    parameter_names = frozenset(param.name for param in extra_parameters)
    operator_types = tuple(SYMPY_TO_PADDLE.keys()) + (sp.Add, sp.Mul)

    def _expr_to_callable_nodes(
        single_expr: sp.Basic, graph_filename_: Optional[str] = None
    ) -> List[Node]:
        """Convert sympy expression to a sequence of nodes in topological order.

        Args:
            single_expr (sp.Basic): Single sympy expression, such as "a+b*c".
            graph_filename_ (Optional[str]): Name joined onto `graph_filename` to
                build the path of the saved computational graph. Only used when
                `graph_filename` is a string. Defaults to None.

        Returns:
            List[Node]: Sequence of callable nodes.
        """
        # NOTE: Those simplify methods may complicate given expr instead, so not use here
        # simplify expression to reduce nodes in tree
        # expr = sp.nsimplify(expr)
        # expr = sp.expand(expr)
        # expr = sp.simplify(expr)

        # remove 1.0 from sympy expression tree
        single_expr = single_expr.subs(1.0, 1)

        # convert sympy expression tree to list of nodes in post-order
        sympy_nodes: List[sp.Basic] = _post_traverse(single_expr, [])

        # remove unnecessary symbol nodes already in input dict(except for parameter symbol)
        sympy_nodes = [
            node
            for node in sympy_nodes
            if (not node.is_Symbol) or (_cvt_to_key(node) in parameter_names)
        ]

        # remove duplicated node(s) with topological order kept
        sympy_nodes = list(dict.fromkeys(sympy_nodes))

        # convert sympy node to callable node
        callable_nodes = []
        for node in sympy_nodes:
            # Operator types are tested before the generic sp.Function branch on
            # purpose: entries of SYMPY_TO_PADDLE may themselves be sp.Function
            # subclasses.
            if isinstance(node, sp.Derivative):
                callable_nodes.append(DerivativeNode(node, create_graph, retain_graph))
            elif isinstance(node, operator_types):
                callable_nodes.append(OperatorNode(node))
            elif isinstance(node, sp.Function):
                if str(node.func) == equation.DETACH_FUNC_NAME:
                    callable_nodes.append(DetachNode(node))
                    logger.debug(f"Detected detach node {node}")
                    continue

                match_index = None
                for j, model in enumerate(models):
                    if str(node.func) not in model.output_keys:
                        continue
                    callable_nodes.append(LayerNode(node, model))
                    if match_index is not None:
                        raise ValueError(
                            f"Name of function: '{node}' should be unique along given"
                            f" models, but got same output_key: '{str(node.func)}' "
                            f"in given models[{match_index}] and models[{j}]."
                        )
                    match_index = j
                # NOTE: Skip 'sdf' function, which should be already generated in
                # given data_dict
                if match_index is None and str(node.func) != "sdf":
                    raise ValueError(
                        f"Node {node} can not match any model in given model(s)."
                    )
            elif node.is_Number or node.is_NumberSymbol:
                callable_nodes.append(ConstantNode(node))
            elif isinstance(node, sp.Symbol):
                # Only parameter symbols survive the filtering above.
                matched_params = [
                    param for param in extra_parameters if param.name == node.name
                ]
                callable_nodes.append(ParameterNode(node, *matched_params))
            else:
                raise NotImplementedError(
                    f"The node {node} is not supported in lambdify."
                )

        # NOTE: visualize computational graph using 'pygraphviz'
        if isinstance(graph_filename, str):
            _visualize_graph(sympy_nodes, os.path.join(graph_filename, graph_filename_))

        return callable_nodes

    def _find_fusion_candidates(
        nodes_group: List[List[Node]],
    ) -> List[List[int]]:
        """Find one set of derivative nodes that can be fused together.

        Returns `[[group_id, node_id], ...]`. The first entry is the earliest
        un-merged, non-sdf `DerivativeNode`. The remaining entries are every other
        un-merged `DerivativeNode` differentiating the same expression. Returns an
        empty list when no such set of two or more nodes exists.
        """
        for i, group_i in enumerate(nodes_group):
            for j, node in enumerate(group_i):
                # skip non-derivative node
                if not isinstance(node, DerivativeNode):
                    continue
                # skip sdf function since it is always already given in data_dict;
                # getattr because the differentiated expression is not necessarily
                # a named function (e.g. an unevaluated Derivative(u*v, x))
                if getattr(node.expr.args[0], "name", None) == "sdf":
                    continue
                # skip merged node
                if node.merged:
                    continue
                target = node.expr.args[0]
                candidate_pos = [[i, j]]
                for ii, group_ii in enumerate(nodes_group):
                    for jj, other in enumerate(group_ii):
                        if not isinstance(other, DerivativeNode):
                            continue
                        if i == ii and j == jj:
                            continue
                        if other.merged:
                            continue
                        # has same function item
                        if target == other.expr.args[0]:
                            candidate_pos.append([ii, jj])
                if len(candidate_pos) > 1:
                    return candidate_pos
        return []

    if isinstance(expr, sp.Basic):
        callable_nodes_group = [_expr_to_callable_nodes(expr, "expr")]
    else:
        callable_nodes_group = [
            _expr_to_callable_nodes(expr_i, f"expr_{i}")
            for i, expr_i in enumerate(expr)
        ]

    # [Optional] Fuse derivative nodes that differentiate the same function
    while fuse_derivative:
        candidate_pos = _find_fusion_candidates(callable_nodes_group)
        if not candidate_pos:
            # exit while loop if no more fused
            break

        # merge all candidate nodes into one or more FusedDerivativeNode node
        fused_node_seq = _fuse_derivative_nodes(
            [callable_nodes_group[gid][nid].expr for gid, nid in candidate_pos]
        )
        assert isinstance(
            fused_node_seq, list
        ), "'fused_node_seq' should be list of 'FusedDerivativeNode'"
        gid0, nid0 = candidate_pos[0]
        logger.debug(
            f"Fused {len(candidate_pos)} derivatives nodes: "
            f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
            f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
        )

        # mark merged node
        for gid, nid in candidate_pos:
            assert isinstance(callable_nodes_group[gid][nid], DerivativeNode)
            callable_nodes_group[gid][nid].merged = True

        # replace first mergable node with fused node sequence(packed in list)
        # then mask the rest merged node to None(except [gid0, nid0])
        for gid, nid in candidate_pos[1:]:
            # keep the end node of each group to avoid generating empty callable
            # node sequence, this will not effect performance since cache strategy
            # in Node.forward
            if nid != len(callable_nodes_group[gid]) - 1:
                callable_nodes_group[gid][nid] = None
        if nid0 == len(callable_nodes_group[gid0]) - 1:
            callable_nodes_group[gid0].insert(nid0, fused_node_seq)
        else:
            callable_nodes_group[gid0][nid0] = fused_node_seq

        # re-organize callable_nodes_group, remove None element and unpack list
        for gid, group in enumerate(callable_nodes_group):
            flattened = []
            for item in group:
                if isinstance(item, (Node, FusedDerivativeNode)):
                    flattened.append(item)
                elif isinstance(item, list) and isinstance(
                    item[0], FusedDerivativeNode
                ):
                    flattened.extend(item)
                else:
                    assert item is None, f"Unexpected element: {item}"
            callable_nodes_group[gid] = flattened

    # Compose callable nodes into one callable object
    if isinstance(expr, sp.Basic):
        return ComposedNode(callable_nodes_group[0])
    return [ComposedNode(callable_nodes) for callable_nodes in callable_nodes_group]