diff --git a/docs/zh/api/probability.md b/docs/zh/api/probability.md index e4915291ff..4e1df7acc8 100644 --- a/docs/zh/api/probability.md +++ b/docs/zh/api/probability.md @@ -1,3 +1,5 @@ +# Probability(概率编程) 模块 + ::: ppsci.probability handler: python options: diff --git a/docs/zh/api/utils.md b/docs/zh/api/utils.md index e4bc9ac6f4..8ce7878a32 100644 --- a/docs/zh/api/utils.md +++ b/docs/zh/api/utils.md @@ -19,5 +19,6 @@ - load_checkpoint - load_pretrain - save_checkpoint + - lambdify show_root_heading: false heading_level: 3 diff --git a/examples/aneurysm/aneurysm.py b/examples/aneurysm/aneurysm.py index 922f7f4831..2d34ae0dc0 100644 --- a/examples/aneurysm/aneurysm.py +++ b/examples/aneurysm/aneurysm.py @@ -132,7 +132,7 @@ def inlet_w_ref_func(_in): ) igc_outlet = ppsci.constraint.IntegralConstraint( equation["NormalDotVec"].equations, - {"normal_dot_vel": 2.54}, + {"normal_dot_vec": 2.54}, geom["outlet_geo"], { **train_dataloader_cfg, @@ -141,12 +141,12 @@ def inlet_w_ref_func(_in): "integral_batch_size": 310, }, ppsci.loss.IntegralLoss("sum"), - weight_dict={"normal_dot_vel": 0.1}, + weight_dict={"normal_dot_vec": 0.1}, name="igc_outlet", ) igc_integral = ppsci.constraint.IntegralConstraint( equation["NormalDotVec"].equations, - {"normal_dot_vel": -2.54}, + {"normal_dot_vec": -2.54}, geom["integral_geo"], { **train_dataloader_cfg, @@ -155,7 +155,7 @@ def inlet_w_ref_func(_in): "integral_batch_size": 310, }, ppsci.loss.IntegralLoss("sum"), - weight_dict={"normal_dot_vel": 0.1}, + weight_dict={"normal_dot_vec": 0.1}, name="igc_integral", ) # wrap constraints together diff --git a/examples/bracket/bracket.py b/examples/bracket/bracket.py index ca3f237f1f..f2dfcc9c0b 100644 --- a/examples/bracket/bracket.py +++ b/examples/bracket/bracket.py @@ -127,15 +127,15 @@ support_interior_constraint = ppsci.constraint.InteriorConstraint( equation["LinearElasticity"].equations, { - "equilibrium_x": 0, - "equilibrium_y": 0, - "equilibrium_z": 0, "stress_disp_xx": 0, "stress_disp_yy": 0, "stress_disp_zz": 0, "stress_disp_xy": 0, "stress_disp_xz": 0, "stress_disp_yz": 0, + "equilibrium_x": 0, + "equilibrium_y": 0, + "equilibrium_z": 0, }, geom["geo"], {**train_dataloader_cfg, "batch_size": 2048}, @@ -149,30 +149,30 @@ & (z < BOUNDS_SUPPORT_Z[1]) ), weight_dict={ - "equilibrium_x": "sdf", - "equilibrium_y": "sdf", - "equilibrium_z": "sdf", "stress_disp_xx": "sdf", "stress_disp_yy": "sdf", "stress_disp_zz": "sdf", "stress_disp_xy": "sdf", "stress_disp_xz": "sdf", "stress_disp_yz": "sdf", + "equilibrium_x": "sdf", + "equilibrium_y": "sdf", + "equilibrium_z": "sdf", }, name="support_interior", ) bracket_interior_constraint = ppsci.constraint.InteriorConstraint( equation["LinearElasticity"].equations, { - "equilibrium_x": 0, - "equilibrium_y": 0, - "equilibrium_z": 0, "stress_disp_xx": 0, "stress_disp_yy": 0, "stress_disp_zz": 0, "stress_disp_xy": 0, "stress_disp_xz": 0, "stress_disp_yz": 0, + "equilibrium_x": 0, + "equilibrium_y": 0, + "equilibrium_z": 0, }, geom["geo"], {**train_dataloader_cfg, "batch_size": 1024}, @@ -186,15 +186,15 @@ & (z < BOUNDS_BRACKET_Z[1]) ), weight_dict={ - "equilibrium_x": "sdf", - "equilibrium_y": "sdf", - "equilibrium_z": "sdf", "stress_disp_xx": "sdf", "stress_disp_yy": "sdf", "stress_disp_zz": "sdf", "stress_disp_xy": "sdf", "stress_disp_xz": "sdf", "stress_disp_yz": "sdf", + "equilibrium_x": "sdf", + "equilibrium_y": "sdf", + "equilibrium_z": "sdf", }, name="bracket_interior", ) diff --git a/examples/laplace/laplace2d.py b/examples/laplace/laplace2d.py index 5ba1719d2a..fe8eacd3b3 100644 --- a/examples/laplace/laplace2d.py +++ b/examples/laplace/laplace2d.py @@ -29,7 +29,7 @@ EVAL_FREQ = 200 # set output directory - OUTPUT_DIR = "./output/laplace2d" if not args.output_dir else args.output_dir + OUTPUT_DIR = "./output_laplace2d" if not args.output_dir else args.output_dir logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") # set model diff --git a/examples/pipe/poiseuille_flow.py b/examples/pipe/poiseuille_flow.py index 76b28e1708..8b270bf542 100644 --- a/examples/pipe/poiseuille_flow.py +++ b/examples/pipe/poiseuille_flow.py @@ -133,9 +133,7 @@ def output_trans_p(input, out): # set euqation equation = { - "NavierStokes": ppsci.equation.NavierStokes( - nu=lambda out: out["nu"], rho=RHO, dim=2, time=False - ) + "NavierStokes": ppsci.equation.NavierStokes(nu="nu", rho=RHO, dim=2, time=False) } # set constraint diff --git a/mkdocs.yml b/mkdocs.yml index 5ee2ea9ce7..4f16b63616 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -87,6 +87,7 @@ nav: - ppsci.validate: zh/api/validate.md - ppsci.visualize: zh/api/visualize.md - ppsci.experimental: zh/api/experimental.md + - ppsci.probability: zh/api/probability.md - 使用指南: zh/user_guide.md - 开发与复现指南: - 开发指南: zh/development.md diff --git a/ppsci/__init__.py b/ppsci/__init__.py index 3877401fa8..5ddca0615d 100644 --- a/ppsci/__init__.py +++ b/ppsci/__init__.py @@ -29,6 +29,7 @@ from ppsci.utils.checker import run_check # isort:skip from ppsci.utils.checker import run_check_mesh # isort:skip +from ppsci.utils import lambdify # isort:skip __all__ = [ "arch", @@ -47,4 +48,5 @@ "experimental", "run_check", "run_check_mesh", + "lambdify", ] diff --git a/ppsci/constraint/boundary_constraint.py b/ppsci/constraint/boundary_constraint.py index 205c314283..afbad13c16 100644 --- a/ppsci/constraint/boundary_constraint.py +++ b/ppsci/constraint/boundary_constraint.py @@ -23,7 +23,6 @@ import numpy as np import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -86,14 +85,12 @@ def __init__( weight_dict: Optional[Dict[str, Union[float, Callable]]] = None, name: str = "BC", ): - self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) + self.output_expr = { + k: v for k, v in output_expr.items() if k in self.output_keys + } # "area" will be kept in "output_dict" for computation. if isinstance(geom, geometry.Mesh): self.output_keys += ["area"] @@ -137,9 +134,6 @@ def __init__( weight = {key: np.ones_like(next(iter(label.values()))) for key in label} if weight_dict is not None: for key, value in weight_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) - if isinstance(value, (int, float)): weight[key] = np.full_like(next(iter(label.values())), value) elif isinstance(value, sympy.Basic): diff --git a/ppsci/constraint/initial_constraint.py b/ppsci/constraint/initial_constraint.py index cfcc89ec1b..351af60c74 100644 --- a/ppsci/constraint/initial_constraint.py +++ b/ppsci/constraint/initial_constraint.py @@ -23,7 +23,6 @@ import numpy as np import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -89,14 +88,12 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "IC", ): - self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) + self.output_expr = { + k: v for k, v in output_expr.items() if k in self.output_keys + } # "area" will be kept in "output_dict" for computation. if isinstance(geom.geometry, geometry.Mesh): self.output_keys += ["area"] @@ -117,8 +114,6 @@ def __init__( # prepare label label = {} for key, value in label_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) if isinstance(value, (int, float)): label[key] = np.full_like(next(iter(input.values())), value) elif isinstance(value, sympy.Basic): @@ -142,8 +137,6 @@ def __init__( weight = {key: np.ones_like(next(iter(label.values()))) for key in label} if weight_dict is not None: for key, value in weight_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) if isinstance(value, (int, float)): weight[key] = np.full_like(next(iter(label.values())), value) elif isinstance(value, sympy.Basic): diff --git a/ppsci/constraint/integral_constraint.py b/ppsci/constraint/integral_constraint.py index fedce730a5..63d8314fa3 100644 --- a/ppsci/constraint/integral_constraint.py +++ b/ppsci/constraint/integral_constraint.py @@ -24,7 +24,6 @@ import numpy as np import paddle import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -86,14 +85,12 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "IgC", ): - self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) + self.output_expr = { + k: v for k, v in output_expr.items() if k in self.output_keys + } # "area" will be kept in "output_dict" for computation. if isinstance(geom, geometry.Mesh): self.output_keys += ["area"] @@ -149,9 +146,6 @@ def __init__( weight = {key: np.ones_like(next(iter(label.values()))) for key in label} if weight_dict is not None: for key, value in weight_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) - if isinstance(value, (int, float)): weight[key] = np.full_like(next(iter(label.values())), value) elif isinstance(value, sympy.Basic): diff --git a/ppsci/constraint/interior_constraint.py b/ppsci/constraint/interior_constraint.py index 1a8474d745..a333c82db3 100644 --- a/ppsci/constraint/interior_constraint.py +++ b/ppsci/constraint/interior_constraint.py @@ -23,7 +23,6 @@ import numpy as np import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -86,14 +85,12 @@ def __init__( weight_dict: Optional[Dict[str, Union[Callable, float]]] = None, name: str = "EQ", ): - self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) + self.output_expr = { + k: v for k, v in output_expr.items() if k in self.output_keys + } # "area" will be kept in "output_dict" for computation. if isinstance(geom, geometry.Mesh): self.output_keys += ["area"] @@ -114,8 +111,6 @@ def __init__( # prepare label label = {} for key, value in label_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) if isinstance(value, (int, float)): label[key] = np.full_like(next(iter(input.values())), value) elif isinstance(value, sympy.Basic): diff --git a/ppsci/constraint/periodic_constraint.py b/ppsci/constraint/periodic_constraint.py index cfbad8796e..6aace571ec 100644 --- a/ppsci/constraint/periodic_constraint.py +++ b/ppsci/constraint/periodic_constraint.py @@ -24,7 +24,6 @@ import numpy as np import paddle import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -73,13 +72,11 @@ def __init__( weight_dict: Optional[Dict[str, Callable]] = None, name: str = "PeriodicBC", ): - self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.input_keys = geom.dim_keys - self.output_keys = list(output_expr.keys()) + self.output_keys = tuple(output_expr.keys()) + self.output_expr = { + k: v for k, v in output_expr.items() if k in self.output_keys + } # "area" will be kept in "output_dict" for computation. if isinstance(geom, geometry.Mesh): self.output_keys += ["area"] @@ -143,9 +140,6 @@ def __init__( weight = {key: np.ones_like(next(iter(label.values()))) for key in label} if weight_dict is not None: for key, value in weight_dict.items(): - if isinstance(value, str): - value = sp_parser.parse_expr(value) - if isinstance(value, (int, float)): weight[key] = np.full_like(next(iter(label.values())), value) elif isinstance(value, sympy.Basic): diff --git a/ppsci/constraint/supervised_constraint.py b/ppsci/constraint/supervised_constraint.py index a0c34d8bec..84b8816222 100644 --- a/ppsci/constraint/supervised_constraint.py +++ b/ppsci/constraint/supervised_constraint.py @@ -60,19 +60,20 @@ def __init__( output_expr: Optional[Dict[str, Callable]] = None, name: str = "Sup", ): - self.output_expr = output_expr - # build dataset _dataset = dataset.build_dataset(dataloader_cfg["dataset"]) self.input_keys = _dataset.input_keys self.output_keys = ( - list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + tuple(output_expr.keys()) + if output_expr is not None + else _dataset.label_keys ) + self.output_expr = output_expr if self.output_expr is None: self.output_expr = { - key: lambda out, k=key: out[k] for key in self.output_keys + key: (lambda out, k=key: out[k]) for key in self.output_keys } # construct dataloader with dataset and dataloader_cfg diff --git a/ppsci/data/dataset/array_dataset.py b/ppsci/data/dataset/array_dataset.py index ad31e707b2..7f06217029 100644 --- a/ppsci/data/dataset/array_dataset.py +++ b/ppsci/data/dataset/array_dataset.py @@ -29,9 +29,9 @@ class NamedArrayDataset(io.Dataset): Args: input (Dict[str, np.ndarray]): Input dict. label (Dict[str, np.ndarray]): Label dict. - weight (Optional[Dict[str, np.ndarray]], optional): Weight dict. - transforms (Optional[vision.Compose], optional): Compose object contains sample wise - transform(s). + weight (Optional[Dict[str, np.ndarray]]): Weight dict. Defaults to None. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). Defaults to None. Examples: >>> import ppsci diff --git a/ppsci/equation/__init__.py b/ppsci/equation/__init__.py index d1c95a6e6f..47526d4d60 100644 --- a/ppsci/equation/__init__.py +++ b/ppsci/equation/__init__.py @@ -16,6 +16,7 @@ from ppsci.equation.fpde import FractionalPoisson from ppsci.equation.ide import Volterra +from ppsci.equation.pde import DETACH_FUNC_NAME from ppsci.equation.pde import PDE from ppsci.equation.pde import Biharmonic from ppsci.equation.pde import Laplace @@ -29,6 +30,7 @@ __all__ = [ "PDE", + "DETACH_FUNC_NAME", "Biharmonic", "Laplace", "LinearElasticity", diff --git a/ppsci/equation/pde/__init__.py b/ppsci/equation/pde/__init__.py index 65addab794..1ff84a31a4 100644 --- a/ppsci/equation/pde/__init__.py +++ b/ppsci/equation/pde/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ppsci.equation.pde.base import DETACH_FUNC_NAME from ppsci.equation.pde.base import PDE from ppsci.equation.pde.biharmonic import Biharmonic from ppsci.equation.pde.laplace import Laplace @@ -23,6 +24,7 @@ __all__ = [ "PDE", + "DETACH_FUNC_NAME", "Biharmonic", "Laplace", "LinearElasticity", diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index 707f7dadbd..a9cfa599a8 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -17,12 +17,15 @@ from typing import Callable from typing import Dict from typing import List +from typing import Optional from typing import Tuple import paddle import sympy from paddle import nn +DETACH_FUNC_NAME = "detach" + class PDE: """Base class for Partial Differential Equation""" @@ -30,11 +33,12 @@ class PDE: def __init__(self): super().__init__() self.equations = {} - # for PDE which has learnable parameter(s) self.learnable_parameters = nn.ParameterList() - def create_symbols(self, symbol_str) -> Tuple[sympy.Symbol, ...]: + self.detach_keys: Optional[Tuple[str, ...]] = None + + def create_symbols(self, symbol_str: str) -> Tuple[sympy.Symbol, ...]: """Create symbols Args: @@ -45,7 +49,9 @@ def create_symbols(self, symbol_str) -> Tuple[sympy.Symbol, ...]: """ return sympy.symbols(symbol_str) - def create_function(self, name, invars) -> sympy.Function: + def create_function( + self, name: str, invars: Tuple[sympy.Symbol, ...] + ) -> sympy.Function: """Create named function depending on given invars. Args: @@ -55,7 +61,13 @@ def create_function(self, name, invars) -> sympy.Function: Returns: sympy.Function: Named sympy function. """ - return sympy.Function(name)(*invars) + expr = sympy.Function(name)(*invars) + + # wrap `expression(...)` to `detach(expression(...))` + # if name of expression is in given detach_keys + if self.detach_keys and name in self.detach_keys: + expr = sympy.Function(DETACH_FUNC_NAME)(expr) + return expr def add_equation(self, name: str, equation: Callable): """Add an equation. diff --git a/ppsci/equation/pde/biharmonic.py b/ppsci/equation/pde/biharmonic.py index e4c6db997c..8c79651a1c 100644 --- a/ppsci/equation/pde/biharmonic.py +++ b/ppsci/equation/pde/biharmonic.py @@ -14,7 +14,10 @@ from __future__ import annotations -from ppsci.autodiff import hessian +from typing import Optional +from typing import Tuple +from typing import Union + from ppsci.equation.pde import base @@ -27,27 +30,41 @@ class Biharmonic(base.PDE): Args: dim (int): Dimension of equation. - q (float): Load. - D (float): Rigidity. + q (Union[float, str]): Load. + D (Union[float, str]): Rigidity. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci >>> pde = ppsci.equation.Biharmonic(2, -1.0, 1.0) """ - def __init__(self, dim: int, q: float, D: float): + def __init__( + self, + dim: int, + q: Union[float, str], + D: Union[float, str], + detach_keys: Optional[Tuple[str, ...]] = None, + ): super().__init__() + self.detach_keys = detach_keys + + invars = self.create_symbols("x y z")[:dim] + u = self.create_function("u", invars) + + if isinstance(q, str): + q = self.create_function("q", invars) + if isinstance(D, str): + D = self.create_function("D", invars) + self.dim = dim self.q = q self.D = D - def biharmonic_compute_func(out): - u = out["u"] - biharmonic = -self.q / self.D - invars = ("x", "y", "z")[: self.dim] - for invar_i in invars: - for invar_j in invars: - biharmonic += hessian(hessian(u, out[invar_i]), out[invar_j]) - return biharmonic + biharmonic = -self.q / self.D + for invar_i in invars: + for invar_j in invars: + biharmonic += u.diff(invar_i, 2).diff(invar_j, 2) - self.add_equation("biharmonic", biharmonic_compute_func) + self.add_equation("biharmonic", biharmonic) diff --git a/ppsci/equation/pde/laplace.py b/ppsci/equation/pde/laplace.py index 509b493c0c..ad63bdafbe 100644 --- a/ppsci/equation/pde/laplace.py +++ b/ppsci/equation/pde/laplace.py @@ -14,7 +14,9 @@ from __future__ import annotations -from ppsci.autodiff import hessian +from typing import Optional +from typing import Tuple + from ppsci.equation.pde import base @@ -27,23 +29,25 @@ class Laplace(base.PDE): Args: dim (int): Dimension of equation. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci >>> pde = ppsci.equation.Laplace(2) """ - def __init__(self, dim: int): + def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): super().__init__() + self.detach_keys = detach_keys + + invars = self.create_symbols("x y z")[:dim] + u = self.create_function("u", invars) + self.dim = dim - def laplace_compute_func(out): - x, y = out["x"], out["y"] - u = out["u"] - laplace = hessian(u, x) + hessian(u, y) - if self.dim == 3: - z = out["z"] - laplace += hessian(u, z) - return laplace + laplace = 0 + for invar in invars: + laplace += u.diff(invar, 2) - self.add_equation("laplace", laplace_compute_func) + self.add_equation("laplace", laplace) diff --git a/ppsci/equation/pde/linear_elasticity.py b/ppsci/equation/pde/linear_elasticity.py index f69c61b43b..3b906207f6 100644 --- a/ppsci/equation/pde/linear_elasticity.py +++ b/ppsci/equation/pde/linear_elasticity.py @@ -15,9 +15,11 @@ from __future__ import annotations from typing import Optional +from typing import Tuple +from typing import Union + +import sympy as sp -from ppsci.autodiff import hessian -from ppsci.autodiff import jacobian from ppsci.equation.pde import base @@ -37,13 +39,15 @@ class LinearElasticity(base.PDE): $$ Args: - E (Optional[float]): The Young's modulus. Defaults to None. - nu (Optional[float]): The Poisson's ratio. Defaults to None. - lambda_ (Optional[float]): Lamé's first parameter. Defaults to None. - mu (Optional[float]): Lamé's second parameter (shear modulus). Defaults to None. - rho (float, optional): Mass density. Defaults to 1. + E (Optional[Union[float, str]]): The Young's modulus. Defaults to None. + nu (Optional[Union[float, str]]): The Poisson's ratio. Defaults to None. + lambda_ (Optional[Union[float, str]]): Lamé's first parameter. Defaults to None. + mu (Optional[Union[float, str]]): Lamé's second parameter (shear modulus). Defaults to None. + rho (Union[float, str], optional): Mass density. Defaults to 1. dim (int, optional): Dimension of the linear elasticity (2 or 3). Defaults to 3. time (bool, optional): Whether contains time data. Defaults to False. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci @@ -54,221 +58,121 @@ class LinearElasticity(base.PDE): def __init__( self, - E: Optional[float] = None, - nu: Optional[float] = None, - lambda_: Optional[float] = None, - mu: Optional[float] = None, - rho: float = 1, + E: Optional[Union[float, str]] = None, + nu: Optional[Union[float, str]] = None, + lambda_: Optional[Union[float, str]] = None, + mu: Optional[Union[float, str]] = None, + rho: Union[float, str] = 1, dim: int = 3, time: bool = False, + detach_keys: Optional[Tuple[str, ...]] = None, ): super().__init__() + self.detach_keys = detach_keys + self.dim = dim + self.time = time + + t, x, y, z = self.create_symbols("t x y z") + normal_x, normal_y, normal_z = self.create_symbols("normal_x normal_y normal_z") + invars = (x, y) + if time: + invars = (t,) + invars + if self.dim == 3: + invars += (z,) + + u = self.create_function("u", invars) + v = self.create_function("v", invars) + w = self.create_function("w", invars) if dim == 3 else sp.Number(0) + + sigma_xx = self.create_function("sigma_xx", invars) + sigma_yy = self.create_function("sigma_yy", invars) + sigma_xy = self.create_function("sigma_xy", invars) + sigma_zz = ( + self.create_function("sigma_zz", invars) if dim == 3 else sp.Number(0) + ) + sigma_xz = ( + self.create_function("sigma_xz", invars) if dim == 3 else sp.Number(0) + ) + sigma_yz = ( + self.create_function("sigma_yz", invars) if dim == 3 else sp.Number(0) + ) + + # compute lambda and mu if lambda_ is None: - nu = float(nu) - E = float(E) + if isinstance(nu, str): + nu = self.create_function(nu)(invars) + if isinstance(E, str): + E = self.create_function(E)(invars) lambda_ = nu * E / ((1 + nu) * (1 - 2 * nu)) mu = E / (2 * (1 + nu)) + else: + if isinstance(lambda_, str): + lambda_ = self.create_function(lambda_)(invars) + if isinstance(mu, str): + mu = self.create_function(mu)(invars) + + if isinstance(rho, str): + rho = self.create_function(rho)(invars) self.E = E self.nu = nu self.lambda_ = lambda_ self.mu = mu self.rho = rho - self.dim = dim - self.time = time - - # Stress equations - def stress_disp_xx_compute_func(out): - x, y, u, v = ( - out["x"], - out["y"], - out["u"], - out["v"], - ) - sigma_xx = out["sigma_xx"] - stress_disp_xx = ( - self.lambda_ * (jacobian(u, x) + jacobian(v, y)) - + 2 * self.mu * jacobian(u, x) - - sigma_xx - ) - if self.dim == 3: - z, w = out["z"], out["w"] - stress_disp_xx += self.lambda_ * jacobian(w, z) - return stress_disp_xx - - self.add_equation("stress_disp_xx", stress_disp_xx_compute_func) - - def stress_disp_yy_compute_func(out): - x, y, u, v = ( - out["x"], - out["y"], - out["u"], - out["v"], - ) - sigma_yy = out["sigma_yy"] - stress_disp_yy = ( - self.lambda_ * (jacobian(u, x) + jacobian(v, y)) - + 2 * self.mu * jacobian(v, y) - - sigma_yy - ) - if self.dim == 3: - z, w = out["z"], out["w"] - stress_disp_yy += self.lambda_ * jacobian(w, z) - return stress_disp_yy - - self.add_equation("stress_disp_yy", stress_disp_yy_compute_func) - - if self.dim == 3: - - def stress_disp_zz_compute_func(out): - x, y, z, u, v, w = ( - out["x"], - out["y"], - out["z"], - out["u"], - out["v"], - out["w"], - ) - sigma_zz = out["sigma_zz"] - stress_disp_zz = ( - self.lambda_ * (jacobian(u, x) + jacobian(v, y) + jacobian(w, z)) - + 2 * self.mu * jacobian(w, z) - - sigma_zz - ) - return stress_disp_zz - - self.add_equation("stress_disp_zz", stress_disp_zz_compute_func) - - def stress_disp_xy_compute_func(out): - x, y, u, v = out["x"], out["y"], out["u"], out["v"] - sigma_xy = out["sigma_xy"] - stress_disp_xy = self.mu * (jacobian(u, y) + jacobian(v, x)) - sigma_xy - return stress_disp_xy - - self.add_equation("stress_disp_xy", stress_disp_xy_compute_func) + # compute stress equations + stress_disp_xx = ( + lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) + + 2 * mu * u.diff(x) + - sigma_xx + ) + stress_disp_yy = ( + lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) + + 2 * mu * v.diff(y) + - sigma_yy + ) + stress_disp_zz = ( + lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) + + 2 * mu * w.diff(z) + - sigma_zz + ) + stress_disp_xy = mu * (u.diff(y) + v.diff(x)) - sigma_xy + stress_disp_xz = mu * (u.diff(z) + w.diff(x)) - sigma_xz + stress_disp_yz = mu * (v.diff(z) + w.diff(y)) - sigma_yz + + # compute equilibrium equations + equilibrium_x = rho * ((u.diff(t)).diff(t)) - ( + sigma_xx.diff(x) + sigma_xy.diff(y) + sigma_xz.diff(z) + ) + equilibrium_y = rho * ((v.diff(t)).diff(t)) - ( + sigma_xy.diff(x) + sigma_yy.diff(y) + sigma_yz.diff(z) + ) + equilibrium_z = rho * ((w.diff(t)).diff(t)) - ( + sigma_xz.diff(x) + sigma_yz.diff(y) + sigma_zz.diff(z) + ) + + # compute traction equations + traction_x = normal_x * sigma_xx + normal_y * sigma_xy + normal_z * sigma_xz + traction_y = normal_x * sigma_xy + normal_y * sigma_yy + normal_z * sigma_yz + traction_z = normal_x * sigma_xz + normal_y * sigma_yz + normal_z * sigma_zz + + # add stress equations + self.add_equation("stress_disp_xx", stress_disp_xx) + self.add_equation("stress_disp_yy", stress_disp_yy) + self.add_equation("stress_disp_xy", stress_disp_xy) if self.dim == 3: + self.add_equation("stress_disp_zz", stress_disp_zz) + self.add_equation("stress_disp_xz", stress_disp_xz) + self.add_equation("stress_disp_yz", stress_disp_yz) - def stress_disp_xz_compute_func(out): - x, z, u, w = out["x"], out["z"], out["u"], out["w"] - sigma_xz = out["sigma_xz"] - stress_disp_xz = self.mu * (jacobian(u, z) + jacobian(w, x)) - sigma_xz - return stress_disp_xz - - self.add_equation("stress_disp_xz", stress_disp_xz_compute_func) - - def stress_disp_yz_compute_func(out): - y, z, v, w = out["y"], out["z"], out["v"], out["w"] - sigma_yz = out["sigma_yz"] - stress_disp_yz = self.mu * (jacobian(v, z) + jacobian(w, y)) - sigma_yz - return stress_disp_yz - - self.add_equation("stress_disp_yz", stress_disp_yz_compute_func) - - # Equations of equilibrium - def equilibrium_x_compute_func(out): - x, y = out["x"], out["y"] - sigma_xx, sigma_xy = out["sigma_xx"], out["sigma_xy"] - equilibrium_x = -jacobian(sigma_xx, x) - jacobian(sigma_xy, y) - if self.dim == 3: - z, sigma_xz = out["z"], out["sigma_xz"] - equilibrium_x -= jacobian(sigma_xz, z) - if self.time: - t, u = out["t"], out["u"] - equilibrium_x += self.rho * hessian(u, t) - return equilibrium_x - - self.add_equation("equilibrium_x", equilibrium_x_compute_func) - - def equilibrium_y_compute_func(out): - x, y = out["x"], out["y"] - sigma_xy, sigma_yy = ( - out["sigma_xy"], - out["sigma_yy"], - ) - equilibrium_y = -jacobian(sigma_xy, x) - jacobian(sigma_yy, y) - if self.dim == 3: - z, sigma_yz = out["z"], out["sigma_yz"] - equilibrium_y -= jacobian(sigma_yz, z) - if self.time: - t, v = out["t"], out["v"] - equilibrium_y += self.rho * hessian(v, t) - return equilibrium_y - - self.add_equation("equilibrium_y", equilibrium_y_compute_func) - + # add equilibrium equations + self.add_equation("equilibrium_x", equilibrium_x) + self.add_equation("equilibrium_y", equilibrium_y) if self.dim == 3: + self.add_equation("equilibrium_z", equilibrium_z) - def equilibrium_z_compute_func(out): - x, y, z = out["x"], out["y"], out["z"] - sigma_xz, sigma_yz, sigma_zz = ( - out["sigma_xz"], - out["sigma_yz"], - out["sigma_zz"], - ) - equilibrium_z = ( - -jacobian(sigma_xz, x) - - jacobian(sigma_yz, y) - - jacobian(sigma_zz, z) - ) - if self.time: - t, w = out["t"], out["w"] - equilibrium_z += self.rho * hessian(w, t) - return equilibrium_z - - self.add_equation("equilibrium_z", equilibrium_z_compute_func) - - # Traction equations - def traction_x_compute_func(out): - normal_x, normal_y = ( - out["normal_x"], - out["normal_y"], - ) - sigma_xx, sigma_xy = ( - out["sigma_xx"], - out["sigma_xy"], - ) - traction_x = normal_x * sigma_xx + normal_y * sigma_xy - if self.dim == 3: - normal_z, sigma_xz = out["normal_z"], out["sigma_xz"] - traction_x += normal_z * sigma_xz - return traction_x - - self.add_equation("traction_x", traction_x_compute_func) - - def traction_y_compute_func(out): - normal_x, normal_y = ( - out["normal_x"], - out["normal_y"], - ) - sigma_xy, sigma_yy = ( - out["sigma_xy"], - out["sigma_yy"], - ) - traction_y = normal_x * sigma_xy + normal_y * sigma_yy - if self.dim == 3: - normal_z, sigma_yz = out["normal_z"], out["sigma_yz"] - traction_y += normal_z * sigma_yz - return traction_y - - self.add_equation("traction_y", traction_y_compute_func) - + # add traction equations + self.add_equation("traction_x", traction_x) + self.add_equation("traction_y", traction_y) if self.dim == 3: - - def traction_z_compute_func(out): - normal_x, normal_y, normal_z = ( - out["normal_x"], - out["normal_y"], - out["normal_z"], - ) - sigma_xz, sigma_yz, sigma_zz = ( - out["sigma_xz"], - out["sigma_yz"], - out["sigma_zz"], - ) - traction_z = ( - normal_x * sigma_xz + normal_y * sigma_yz + normal_z * sigma_zz - ) - return traction_z - - self.add_equation("traction_z", traction_z_compute_func) + self.add_equation("traction_z", traction_z) diff --git a/ppsci/equation/pde/navier_stokes.py b/ppsci/equation/pde/navier_stokes.py index 946ed1dd40..6446e9d139 100644 --- a/ppsci/equation/pde/navier_stokes.py +++ b/ppsci/equation/pde/navier_stokes.py @@ -14,11 +14,12 @@ from __future__ import annotations -from typing import Callable +from typing import Optional +from typing import Tuple from typing import Union -from ppsci.autodiff import hessian -from ppsci.autodiff import jacobian +import sympy as sp + from ppsci.equation.pde import base @@ -53,96 +54,90 @@ class NavierStokes(base.PDE): $$ Args: - nu (Union[float, Callable]): Dynamic viscosity. - rho (float): Density. + nu (Union[float, str]): Dynamic viscosity. + rho (Union[float, str]): Density. dim (int): Dimension of equation. time (bool): Whether the euqation is time-dependent. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci >>> pde = ppsci.equation.NavierStokes(0.1, 1.0, 3, False) """ - def __init__(self, nu: Union[float, Callable], rho: float, dim: int, time: bool): + def __init__( + self, + nu: Union[float, str], + rho: Union[float, str], + dim: int, + time: bool, + detach_keys: Optional[Tuple[str, ...]] = None, + ): super().__init__() - self.nu = nu - self.rho = rho + self.detach_keys = detach_keys self.dim = dim self.time = time - def continuity_compute_func(out): - x, y = out["x"], out["y"] - u, v = out["u"], out["v"] - continuity = jacobian(u, x) + jacobian(v, y) - if self.dim == 3: - z, w = out["z"], out["w"] - continuity += jacobian(w, z) - return continuity - - self.add_equation("continuity", continuity_compute_func) - - def momentum_x_compute_func(out): - nu = self.nu(out) if callable(self.nu) else self.nu - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - momentum_x = ( - u * jacobian(u, x) - + v * jacobian(u, y) - - nu * hessian(u, x) - - nu * hessian(u, y) - + 1 / rho * jacobian(p, x) - ) - if self.time: - t = out["t"] - momentum_x += jacobian(u, t) - if self.dim == 3: - z, w = out["z"], out["w"] - momentum_x += w * jacobian(u, z) - momentum_x -= nu * hessian(u, z) - return momentum_x - - self.add_equation("momentum_x", momentum_x_compute_func) - - def momentum_y_compute_func(out): - nu = self.nu(out) if callable(self.nu) else self.nu - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - momentum_y = ( - u * jacobian(v, x) - + v * jacobian(v, y) - - nu * hessian(v, x) - - nu * hessian(v, y) - + 1 / rho * jacobian(p, y) - ) - if self.time: - t = out["t"] - momentum_y += jacobian(v, t) - if self.dim == 3: - z, w = out["z"], out["w"] - momentum_y += w * jacobian(v, z) - momentum_y -= nu * hessian(v, z) - return momentum_y + t, x, y, z = self.create_symbols("t x y z") + invars = (x, y) + if time: + invars = (t,) + invars + if dim == 3: + invars += (z,) - self.add_equation("momentum_y", momentum_y_compute_func) + if isinstance(nu, str): + nu = self.create_function(nu, invars) + if isinstance(rho, str): + rho = self.create_function(rho, invars) - if self.dim == 3: + self.nu = nu + self.rho = rho - def momentum_z_compute_func(out): - nu = self.nu(out) if callable(self.nu) else self.nu - x, y, z = out["x"], out["y"], out["z"] - u, v, w, p = out["u"], out["v"], out["w"], out["p"] - momentum_z = ( - u * jacobian(w, x) - + v * jacobian(w, y) - + w * jacobian(w, z) - - nu * hessian(w, x) - - nu * hessian(w, y) - - nu * hessian(w, z) - + 1 / rho * jacobian(p, z) - ) - if self.time: - t = out["t"] - momentum_z += jacobian(w, t) - return momentum_z - - self.add_equation("momentum_z", momentum_z_compute_func) + u = self.create_function("u", invars) + v = self.create_function("v", invars) + w = self.create_function("w", invars) if dim == 3 else sp.Number(0) + p = self.create_function("p", invars) + + continuity = u.diff(x) + v.diff(y) + w.diff(z) + momentum_x = ( + u.diff(t) + + u * u.diff(x) + + v * u.diff(y) + + w * u.diff(z) + - ( + (nu * u.diff(x)).diff(x) + + (nu * u.diff(y)).diff(y) + + (nu * u.diff(z)).diff(z) + ) + + 1 / rho * p.diff(x) + ) + momentum_y = ( + v.diff(t) + + u * v.diff(x) + + v * v.diff(y) + + w * v.diff(z) + - ( + (nu * v.diff(x)).diff(x) + + (nu * v.diff(y)).diff(y) + + (nu * v.diff(z)).diff(z) + ) + + 1 / rho * p.diff(y) + ) + momentum_z = ( + w.diff(t) + + u * w.diff(x) + + v * w.diff(y) + + w * w.diff(z) + - ( + (nu * w.diff(x)).diff(x) + + (nu * w.diff(y)).diff(y) + + (nu * w.diff(z)).diff(z) + ) + + 1 / rho * p.diff(z) + ) + self.add_equation("continuity", continuity) + self.add_equation("momentum_x", momentum_x) + self.add_equation("momentum_y", momentum_y) + if self.dim == 3: + self.add_equation("momentum_z", momentum_z) diff --git a/ppsci/equation/pde/normal_dot_vec.py b/ppsci/equation/pde/normal_dot_vec.py index b71555fd04..c20efd2ffe 100644 --- a/ppsci/equation/pde/normal_dot_vec.py +++ b/ppsci/equation/pde/normal_dot_vec.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Optional from typing import Tuple from ppsci.equation.pde import base @@ -29,22 +30,28 @@ class NormalDotVec(base.PDE): Args: vec_keys (Tuple[str, ...]): Keys for vectors, such as ("u", "v", "w") for velocity vector. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci >>> pde = ppsci.equation.NormalDotVec(("u", "v", "w")) """ - def __init__(self, vec_keys: Tuple[str, ...]): + def __init__( + self, vec_keys: Tuple[str, ...], detach_keys: Optional[Tuple[str, ...]] = None + ): super().__init__() - self.vec_keys = vec_keys - self.normal_keys = ("normal_x", "normal_y", "normal_z") + self.detach_keys = detach_keys + if not vec_keys: + raise ValueError(f"len(vec_keys)({len(vec_keys)}) should be larger than 0.") - def normal_dot_vel_compute_func(out): - normal_dot_vel = 0 - for i, vec_key in enumerate(vec_keys): - normal_dot_vel += out[vec_key] * out[self.normal_keys[i]] + self.vec_keys = vec_keys + vec_vars = self.create_symbols(" ".join(vec_keys)) + normals = self.create_symbols("normal_x normal_y normal_z") - return normal_dot_vel + normal_dot_vec = 0 + for (normal, vec) in zip(normals, vec_vars): + normal_dot_vec += normal * vec - self.equations["normal_dot_vel"] = normal_dot_vel_compute_func + self.add_equation("normal_dot_vec", normal_dot_vec) diff --git a/ppsci/equation/pde/poisson.py b/ppsci/equation/pde/poisson.py index f2c6b9e02a..8cb7c62e7c 100644 --- a/ppsci/equation/pde/poisson.py +++ b/ppsci/equation/pde/poisson.py @@ -14,7 +14,9 @@ from __future__ import annotations -from ppsci.autodiff import hessian +from typing import Optional +from typing import Tuple + from ppsci.equation.pde import base @@ -27,21 +29,23 @@ class Poisson(base.PDE): Args: dim (int): Dimension of equation. + detach_keys(Optional[Tuple[str, ...]]): Keys used for detach during computing. + Defaults to None. Examples: >>> import ppsci >>> pde = ppsci.equation.Poisson(2) """ - def __init__(self, dim: int): + def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): super().__init__() + self.detach_keys = detach_keys + invars = self.create_symbols("x y z")[:dim] + p = self.create_function("p", invars) self.dim = dim - def poisson_compute_func(out): - invars = ("x", "y", "z")[: self.dim] - poisson = 0 - for invar in invars: - poisson += hessian(out["p"], out[invar]) - return poisson + poisson = 0 + for invar in invars: + poisson += p.diff(invar, 2) - self.add_equation("poisson", poisson_compute_func) + self.add_equation("poisson", poisson) diff --git a/ppsci/equation/pde/viv.py b/ppsci/equation/pde/viv.py index fad58a3b30..0e37a721c4 100644 --- a/ppsci/equation/pde/viv.py +++ b/ppsci/equation/pde/viv.py @@ -15,10 +15,9 @@ from __future__ import annotations import paddle +import sympy as sp from paddle.nn import initializer -from ppsci.autodiff import hessian -from ppsci.autodiff import jacobian from ppsci.equation.pde import base @@ -45,25 +44,21 @@ def __init__(self, rho: float, k1: float, k2: float): self.k1 = paddle.create_parameter( shape=[], dtype=paddle.get_default_dtype(), + name="k1", default_initializer=initializer.Constant(k1), ) self.k2 = paddle.create_parameter( shape=[], dtype=paddle.get_default_dtype(), + name="k2", default_initializer=initializer.Constant(k2), ) self.learnable_parameters.append(self.k1) self.learnable_parameters.append(self.k2) - def f_compute_func(out): - eta, t = out["eta"], out["t_f"] - eta__t = jacobian(eta, t) - eta__t__t = hessian(eta, t) - f = ( - self.rho * eta__t__t - + paddle.exp(self.k1) * eta__t - + paddle.exp(self.k2) * eta - ) - return f - - self.add_equation("f", f_compute_func) + t_f = self.create_symbols("t_f") + eta = self.create_function("eta", (t_f,)) + k1 = self.create_symbols(self.k1.name) + k2 = self.create_symbols(self.k2.name) + f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta + self.add_equation("f", f) diff --git a/ppsci/geometry/geometry.py b/ppsci/geometry/geometry.py index f511dc730e..3e087c70cf 100644 --- a/ppsci/geometry/geometry.py +++ b/ppsci/geometry/geometry.py @@ -64,7 +64,13 @@ def uniform_points(self, n: int, boundary=True): ) return self.random_points(n) - def sample_interior(self, n, random="pseudo", criteria=None, evenly=False): + def sample_interior( + self, + n, + random="pseudo", + criteria=None, + evenly=False, + ): """Sample random points in the geometry and return those meet criteria.""" x = np.empty(shape=(n, self.ndim), dtype=paddle.get_default_dtype()) _size, _ntry, _nsuc = 0, 0, 0 @@ -103,6 +109,7 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False): else: sdf_dict = {} x_dict = misc.convert_to_dict(x, self.dim_keys) + return {**x_dict, **sdf_dict} def sample_boundary(self, n, random="pseudo", criteria=None, evenly=False): diff --git a/ppsci/geometry/mesh.py b/ppsci/geometry/mesh.py index f43dab31a2..0c10e47ce8 100644 --- a/ppsci/geometry/mesh.py +++ b/ppsci/geometry/mesh.py @@ -438,11 +438,11 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False): points, areas = self.random_points(n, random, criteria) x_dict = misc.convert_to_dict(points, self.dim_keys) - area_dict = misc.convert_to_dict(areas, ["area"]) + area_dict = misc.convert_to_dict(areas, ("area",)) # NOTE: add negtive to the sdf values because weight should be positive. sdf = -self.sdf_func(points) - sdf_dict = misc.convert_to_dict(sdf, ["sdf"]) + sdf_dict = misc.convert_to_dict(sdf, ("sdf",)) return {**x_dict, **area_dict, **sdf_dict} diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index cf77218cd0..baa288fdf2 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -27,6 +27,7 @@ import numpy as np import paddle import paddle.distributed as dist +import sympy as sp import visualdl as vdl from packaging import version from paddle import amp @@ -313,6 +314,40 @@ def __init__( # use loss aggregator, use summation if None self.loss_aggregator = loss_aggregator + # convert sympy to callable object if exist + extra_parameters = [] + for equation in self.equation.values(): + extra_parameters += list(equation.learnable_parameters) + + def convert_expr( + container_dict: Dict[ + str, + Union[ + ppsci.constraint.Constraint, + ppsci.validate.Validator, + ppsci.visualize.Visualizer, + ], + ] + ) -> None: + for container in container_dict.values(): + for name, expr in container.output_expr.items(): + if isinstance(expr, sp.Basic): + container.output_expr[name] = ppsci.lambdify( + expr, + self.model, + extra_parameters, + # os.path.join(self.output_dir, container.name, expr), # HACK: Activate it for DEBUG. + ) + + if self.constraint: + convert_expr(self.constraint) + + if self.validator: + convert_expr(self.validator) + + if self.visualizer: + convert_expr(self.visualizer) + @staticmethod def from_config(cfg: Dict[str, Any]) -> Solver: """Initialize solver from given config. diff --git a/ppsci/utils/__init__.py b/ppsci/utils/__init__.py index 1d341c40b1..e6e327ffe6 100644 --- a/ppsci/utils/__init__.py +++ b/ppsci/utils/__init__.py @@ -32,6 +32,7 @@ from ppsci.utils.save_load import load_checkpoint from ppsci.utils.save_load import load_pretrain from ppsci.utils.save_load import save_checkpoint +from ppsci.utils.symbolic import lambdify __all__ = [ "initializer", @@ -54,4 +55,5 @@ "load_checkpoint", "load_pretrain", "save_checkpoint", + "lambdify", ] diff --git a/ppsci/utils/expression.py b/ppsci/utils/expression.py index eecbb84223..bcf866ae53 100644 --- a/ppsci/utils/expression.py +++ b/ppsci/utils/expression.py @@ -27,6 +27,7 @@ import paddle from ppsci import constraint from ppsci import validate + from ppsci import arch from ppsci.autodiff import clear @@ -54,7 +55,7 @@ def train_forward( self, expr_dicts: Tuple[Dict[str, Callable], ...], input_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], - model: nn.Layer, + model: arch.Arch, constraint: Dict[str, "constraint.Constraint"], label_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], weight_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], @@ -65,7 +66,7 @@ def train_forward( Args: expr_dicts (Tuple[Dict[str, Callable], ...]): Tuple of expression dicts. input_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of input dicts. - model (nn.Layer): NN model. + model (arch.Arch): NN model. constraint (Dict[str, "constraint.Constraint"]): Constraint dict. label_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of label dicts. weight_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of weight dicts. @@ -76,17 +77,13 @@ def train_forward( output_dicts = [] for i, expr_dict in enumerate(expr_dicts): # model forward - if callable(next(iter(expr_dict.values()))): - output_dict = model(input_dicts[i]) + output_dict = model(input_dicts[i]) # equation forward + data_dict = {k: v for k, v in input_dicts[i].items()} + data_dict.update(output_dict) for name, expr in expr_dict.items(): - if name not in label_dicts[i]: - continue - if callable(expr): - output_dict[name] = expr({**output_dict, **input_dicts[i]}) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") + output_dict[name] = expr(data_dict) # put field 'area' into output_dict if "area" in input_dicts[i]: @@ -113,7 +110,7 @@ def eval_forward( self, expr_dict: Dict[str, Callable], input_dict: Dict[str, "paddle.Tensor"], - model: nn.Layer, + model: arch.Arch, validator: "validate.Validator", label_dict: Dict[str, "paddle.Tensor"], weight_dict: Dict[str, "paddle.Tensor"], @@ -124,7 +121,7 @@ def eval_forward( Args: expr_dict (Dict[str, Callable]): Expression dict. input_dict (Dict[str, paddle.Tensor]): Input dict. - model (nn.Layer): NN model. + model (arch.Arch): NN model. validator (validate.Validator): Validator. label_dict (Dict[str, paddle.Tensor]): Label dict. weight_dict (Dict[str, paddle.Tensor]): Weight dict. @@ -134,17 +131,13 @@ def eval_forward( given validator. """ # model forward - if callable(next(iter(expr_dict.values()))): - output_dict = model(input_dict) + output_dict = model(input_dict) # equation forward + data_dict = {k: v for k, v in input_dict.items()} + data_dict.update(output_dict) for name, expr in expr_dict.items(): - if name not in label_dict: - continue - if callable(expr): - output_dict[name] = expr({**output_dict, **input_dict}) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") + output_dict[name] = expr(data_dict) # put field 'area' into output_dict if "area" in input_dict: @@ -165,7 +158,7 @@ def visu_forward( self, expr_dict: Optional[Dict[str, Callable]], input_dict: Dict[str, "paddle.Tensor"], - model: nn.Layer, + model: arch.Arch, ) -> Dict[str, "paddle.Tensor"]: """Forward computation for visualization, including model forward and equation forward. @@ -173,7 +166,7 @@ def visu_forward( Args: expr_dict (Optional[Dict[str, Callable]]): Expression dict. input_dict (Dict[str, paddle.Tensor]): Input dict. - model (nn.Layer): NN model. + model (arch.Arch): NN model. Returns: Dict[str, paddle.Tensor]: Result dict for given expression dict. @@ -183,11 +176,10 @@ def visu_forward( if isinstance(expr_dict, dict): # equation forward + data_dict = {k: v for k, v in input_dict.items()} + data_dict.update(output_dict) for name, expr in expr_dict.items(): - if callable(expr): - output_dict[name] = expr({**output_dict, **input_dict}) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") + output_dict[name] = expr(data_dict) # clear differentiation cache clear() diff --git a/ppsci/utils/initializer.py b/ppsci/utils/initializer.py index afc38a25bc..de8a992f24 100644 --- a/ppsci/utils/initializer.py +++ b/ppsci/utils/initializer.py @@ -92,9 +92,9 @@ def norm_cdf(x): # Transform to proper mean, std _tensor = paddle.multiply( - _tensor, paddle.to_tensor(std * math.sqrt(2.0), _tensor.dtype) + _tensor, paddle.to_tensor(std * math.sqrt(2.0), tensor.dtype) ) - _tensor = paddle.add(_tensor, paddle.to_tensor(mean, _tensor.dtype)) + _tensor = paddle.add(_tensor, paddle.to_tensor(mean, tensor.dtype)) # Clamp to ensure it"s in the proper range _tensor = paddle.clip(_tensor, min=a, max=b) @@ -438,9 +438,10 @@ def linear_init_(module: nn.Layer) -> None: Args: module (nn.Layer): Linear Layer to be initialized. """ - bound = 1 / math.sqrt(module.weight.shape[0]) - uniform_(module.weight, -bound, bound) + kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight, reverse=True) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 uniform_(module.bias, -bound, bound) @@ -450,7 +451,9 @@ def conv_init_(module: nn.Layer) -> None: Args: module (nn.Layer): Convolution Layer to be initialized. """ - bound = 1 / np.sqrt(np.prod(module.weight.shape[1:])) - uniform_(module.weight, -bound, bound) + kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: - uniform_(module.bias, -bound, bound) + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight, reverse=False) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + uniform_(module.bias, -bound, bound) diff --git a/ppsci/utils/sym_to_func.py b/ppsci/utils/sym_to_func.py deleted file mode 100644 index 15809171d9..0000000000 --- a/ppsci/utils/sym_to_func.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Sympy to python function conversion module -""" - -from __future__ import annotations - -import functools -from typing import TYPE_CHECKING -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -import paddle -import sympy as sp -from paddle import nn -from typing_extensions import TypeAlias - -from ppsci.autodiff import hessian -from ppsci.autodiff import jacobian - -if TYPE_CHECKING: - from ppsci import arch - - -__all__ = [ - "sympy_to_function", -] - - -PADDLE_FUNC_MAP = { - sp.sin: paddle.sin, - sp.cos: paddle.cos, - sp.exp: paddle.exp, - sp.Pow: paddle.pow, - sp.log: paddle.log, - sp.tan: paddle.tan, - sp.Max: paddle.maximum, - sp.Min: paddle.minimum, - sp.Abs: paddle.abs, - sp.Heaviside: functools.partial(paddle.heaviside, y=paddle.zeros([])), -} - -SYMPY_BUILTIN_FUNC: TypeAlias = Union[ - sp.sin, - sp.cos, - sp.exp, - sp.Pow, - sp.log, - sp.tan, - sp.Max, - sp.Min, - sp.Abs, - sp.Heaviside, -] - - -def _cvt_to_key(expr: sp.Basic) -> str: - """Convert sympy expression to a string key, mainly as retrieval key in dict. - - Args: - expr (sp.Basic): Sympy expression. - - Returns: - str: Converted string key. - """ - if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): - if hasattr(expr, "name"): - # use name of custom function instead of itself. - return expr.name - else: - return str(expr) - elif isinstance(expr, sp.Derivative): - # convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y" - expr_str = expr.args[0].name - for symbol, order in expr.args[1:]: - expr_str += f"__{symbol}" * order - return expr_str - else: - return str(expr) - - -class Node(nn.Layer): - """The base class of the node in expression tree. - - Args: - expr (sp.Basic): Sympy expression. - """ - - def __init__(self, expr: sp.Basic): - super().__init__() - self.expr = expr - self.key = _cvt_to_key(self.expr) - - def forward(self, **kwargs): - raise NotImplementedError("Node.forward is not implemented") - - def __str__(self): - return f"{self.__class__.__name__}(expr: {self.expr}, expr_type: {type(self.expr)})" - - def __repr__(self): - return f"{self.__class__.__name__}(expr: {self.expr})" - - -class OperatorNode(Node): - """Class for operator node in converted expression tree. - - Args: - expr (SYMPY_BUILTIN_FUNC): Sympy expression. - """ - - def __init__(self, expr: SYMPY_BUILTIN_FUNC): - super().__init__(expr) - # preprocess childs' key instead of processing at run-time - # which can reduce considerable overhead of time for calling "_cvt_to_key" - if self.expr.func == sp.Derivative: - self.childs = [_cvt_to_key(self.expr.args[0])] + [ - (_cvt_to_key(arg), order) for (arg, order) in self.expr.args[1:] - ] - else: - self.childs = [_cvt_to_key(arg) for arg in self.expr.args] - - if self.expr.func == sp.Add: - self._operator_func = self._add_operator_func - elif self.expr.func == sp.Mul: - self._operator_func = self._mul_operator_func - elif self.expr.func == sp.Derivative: - self._operator_func = self._derivate_operator_func - else: - if self.expr.func == sp.Heaviside: - self._operator_func = self._heaviside_operator_func - self._compute_func = PADDLE_FUNC_MAP[sp.Heaviside] - else: - self._operator_func = self._vanilla_operator_func - self._compute_func = PADDLE_FUNC_MAP[self.expr.func] - - def forward(self, data_dict: Dict): - # use cache - if self.key in data_dict: - return data_dict - - return self._operator_func(data_dict) - - def _add_operator_func(self, data_dict): - data_dict[self.key] = sum([data_dict[child] for child in self.childs]) - return data_dict - - def _mul_operator_func(self, data_dict): - data_dict[self.key] = data_dict[self.childs[0]] - for child in self.childs[1:]: - data_dict[self.key] *= data_dict[child] - return data_dict - - def _derivate_operator_func(self, data_dict): - data_dict[self.key] = data_dict[self.childs[0]] - for child, order in self.childs[1:]: - if order & 1: - data_dict[self.key] = jacobian(data_dict[self.key], data_dict[child]) - order -= 1 - while order > 0: - data_dict[self.key] = hessian(data_dict[self.key], data_dict[child]) - order -= 2 - return data_dict - - def _heaviside_operator_func(self, data_dict): - data_dict[self.key] = self._compute_func(data_dict[self.childs[0]]) - return data_dict - - def _vanilla_operator_func(self, data_dict): - data_dict[self.key] = self._compute_func( - *tuple(data_dict[child] for child in self.childs) - ) - return data_dict - - -class LayerNode(Node): - """Class for layer node in converted expression tree. - - Args: - expr (sp.core.function.UndefinedFunction): Sympy expression. - model (nn.Layer): NN model for computing forward result in this node. - """ - - def __init__( - self, - expr: sp.core.function.UndefinedFunction, - model: arch.Arch, - detach_keys: Optional[Tuple[str, ...]] = None, - ): - super().__init__(expr) - self.model = model - self.detach_keys = detach_keys - - def forward(self, data_dict: Dict): - # use cache - if self.key in data_dict: - return data_dict - - output_dict = self.model(data_dict) - data_dict.update(output_dict) - - # detach Tensor(s) if specified - if self.detach_keys: - for key in self.detach_keys: - data_dict[key] = data_dict[key].detach() - - return data_dict - - -class ConstantNode(Node): - """Class for constant variable node in converted expression tree. - - Args: - expr (Union[sp.Number, sp.NumberSymbol]): Number expression. - """ - - def __init__(self, expr: Union[sp.Number, sp.NumberSymbol]): - super().__init__(expr) - if ( - self.expr.is_Float - or self.expr.is_Integer - or self.expr.is_Boolean - or self.expr.is_Rational - ): - self.expr = float(self.expr) - else: - raise TypeError( - f"expr({expr}) should be Float/Integer/Boolean/Rational, but got {type(self.expr)}" - ) - self.expr = paddle.to_tensor(self.expr) - - def forward(self, data_dict: Dict): - # use cache - if self.key in data_dict: - return data_dict - - data_dict[self.key] = self.expr - return data_dict - - -class ComposedNode(nn.Layer): - """ - Compose list of several callable objects together. - """ - - def __init__(self, funcs: List[Node]): - super().__init__() - self.funcs = funcs - - def forward(self, data_dict: Dict): - # call all funcs in order - for func in self.funcs: - data_dict = func(data_dict) - - # return result of last node(root node) for target - return data_dict[self.funcs[-1].key] - - -def _post_traverse(cur_node: sp.Basic, nodes: List[sp.Basic]) -> List[sp.Basic]: - """Traverse sympy expression tree in postorder. - - Args: - cur_node (sp.Basic): Sympy expression of current node. - nodes (List[sp.Basic]): Node list storing all tree nodes in postorder. - - Returns: - List[sp.Basic]: Node list storing all tree nodes in postorder. - """ - # traverse into sub-nodes - if isinstance(cur_node, sp.core.function.UndefinedFunction): - nodes.append(cur_node) - elif isinstance(cur_node, sp.Function): - for arg in cur_node.args: - nodes = _post_traverse(arg, nodes) - nodes.append(cur_node) - elif isinstance(cur_node, sp.Derivative): - nodes = _post_traverse(cur_node.args[0], nodes) - nodes.append(cur_node) - elif isinstance(cur_node, sp.Symbol): - return nodes - elif isinstance(cur_node, sp.Number): - nodes.append(cur_node) - else: - for arg in cur_node.args: - nodes = _post_traverse(arg, nodes) - nodes.append(cur_node) - return nodes - - -def sympy_to_function( - expr: sp.Expr, - models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None, - detach_keys: Tuple[str, ...] = None, -) -> ComposedNode: - """Convert sympy expression to callable function. - - Args: - expr (sp.Expr): Sympy expression to be converted. - models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for computing forward result in `LayerNode`. - - Returns: - ComposedNode: Callable object for computing expr with necessary input(s) data in dict given. - - Examples: - >>> import paddle - >>> import sympy as sp - >>> from ppsci import arch - >>> from ppsci.utils import sym_to_func - - >>> a, b, c, x, y = sp.symbols("a b c x y") - >>> u = sp.Function("u")(x, y) - >>> v = sp.Function("v")(x, y) - >>> z = -a + b * (c ** 2) + u * v + 2.3 - - >>> model = arch.MLP(("x", "y"), ("u", "v"), 4, 16) - - >>> batch_size = 13 - >>> a_tensor = paddle.randn([batch_size, 1]) - >>> b_tensor = paddle.randn([batch_size, 1]) - >>> c_tensor = paddle.randn([batch_size, 1]) - >>> x_tensor = paddle.randn([batch_size, 1]) - >>> y_tensor = paddle.randn([batch_size, 1]) - - >>> model_output_dict = model({"x": x_tensor, "y": y_tensor}) - >>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"] - - >>> z_tensor_manually = ( - ... -a_tensor + b_tensor * (c_tensor ** 2) - ... + u_tensor * v_tensor + 2.3 - ... ) - >>> z_tensor_sympy = sym_to_func.sympy_to_function(z, model)( - ... { - ... "a": a_tensor, - ... "b": b_tensor, - ... "c": c_tensor, - ... "x": x_tensor, - ... "y": y_tensor, - ... } - ... ) - - >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item() - True - """ - - # NOTE: Those simplify methods seem complicate given expr instead, so not use them here - # simplify expression to reduce nodes in tree - # expr = sp.nsimplify(expr) - # expr = sp.expand(expr) - # expr = sp.simplify(expr) - - # convert sympy expression tree to list of nodes in postorder - sympy_nodes = [] - sympy_nodes = _post_traverse(expr, sympy_nodes) - - # remove unnecessary symbol node for already in input dict - sympy_nodes = [node for node in sympy_nodes if not node.is_Symbol] - - # remove duplicates with topo-order kept - sympy_nodes = list(dict.fromkeys(sympy_nodes)) - - if models is None: - models = () - if detach_keys is None: - detach_keys = () - if not isinstance(models, (tuple, list)): - models = (models,) - - # convert sympy node to callable node - callable_nodes = [] - for i, node in enumerate(sympy_nodes): - if isinstance(node.func, sp.core.function.UndefinedFunction): - match_index = None - for j, model in enumerate(models): - if str(node.func.name) in model.output_keys: - callable_nodes.append( - LayerNode( - node, - model, - tuple( - key for key in detach_keys if key in model.output_keys - ), - ) - ) - if match_index is not None: - raise ValueError( - f"Name of function({node}) should be unique along given models," - f" but got same output_key({node.func.name}) in models[{match_index}]" - f" and models[{j}]." - ) - match_index = j - elif ( - isinstance(node, tuple(PADDLE_FUNC_MAP.keys())) - or node.is_Add - or node.is_Mul - or node.is_Derivative - or node.is_Pow - ): - callable_nodes.append(OperatorNode(node)) - elif node.is_Number or node.is_NumberSymbol: - callable_nodes.append(ConstantNode(node)) - else: - raise NotImplementedError( - f"The node {node} is not supported in sympy_to_function." - ) - - # Compose callable nodes into one callable object - return ComposedNode(callable_nodes) diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py new file mode 100644 index 0000000000..06ef4c3b80 --- /dev/null +++ b/ppsci/utils/symbolic.py @@ -0,0 +1,631 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sympy to python function conversion module +""" + +from __future__ import annotations + +import functools +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle +import sympy as sp +from paddle import nn +from typing_extensions import TypeAlias + +from ppsci import arch +from ppsci import equation +from ppsci.autodiff import hessian +from ppsci.autodiff import jacobian + +__all__ = [ + "lambdify", +] + + +DATA_DICT: TypeAlias = Dict[str, paddle.Tensor] + +SYMPY_BUILTIN_FUNC: TypeAlias = Union[ + sp.sin, + sp.sinh, + sp.asin, + sp.cos, + sp.acos, + sp.cosh, + sp.tan, + sp.atan, + sp.atan2, + sp.acosh, + sp.asinh, + sp.tanh, + sp.atanh, + sp.erf, + sp.loggamma, + sp.exp, + sp.Pow, + sp.log, + sp.Max, + sp.Min, + sp.Abs, + sp.Heaviside, + sp.sign, + sp.ceiling, + sp.floor, + sp.Add, + sp.Mul, +] + +SYMPT_TO_PADDLE = { + sp.sin: paddle.sin, + sp.sinh: paddle.sinh, + sp.asin: paddle.asin, + sp.cos: paddle.cos, + sp.acos: paddle.acos, + sp.cosh: paddle.cosh, + sp.tan: paddle.tan, + sp.atan: paddle.atan, + sp.atan2: paddle.atan2, + sp.acosh: paddle.acosh, + sp.asinh: paddle.asinh, + sp.tanh: paddle.tanh, + sp.atanh: paddle.atanh, + sp.erf: paddle.erf, + sp.loggamma: paddle.lgamma, + sp.exp: paddle.exp, + sp.Pow: paddle.pow, + sp.log: paddle.log, + sp.Max: paddle.maximum, + sp.Min: paddle.minimum, + sp.Abs: paddle.abs, + sp.Heaviside: functools.partial(paddle.heaviside, y=paddle.zeros([])), + sp.sign: paddle.sign, + sp.ceiling: paddle.ceil, + sp.floor: paddle.floor, + # NOTE: sp.Add and sp.Mul is not included here for unalignment with sympy + # and are implemented manually. +} + + +def _cvt_to_key(expr: sp.Basic) -> str: + """Convert sympy expression to a string key, mainly as retrieval key in dict. + + Args: + expr (sp.Basic): Sympy expression. + + Returns: + str: Converted string key. + """ + if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): + if hasattr(expr, "name"): + # use name of custom function instead of itself. + return expr.name + else: + return str(expr) + elif isinstance(expr, sp.Derivative): + # convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y" + expr_str = expr.args[0].name + for symbol, order in expr.args[1:]: + expr_str += f"__{symbol}" * order + return expr_str + else: + return str(expr) + + +class Node(nn.Layer): + """The base class of the node in expression tree. + + Args: + expr (sp.Basic): Sympy expression. + """ + + def __init__(self, expr: sp.Basic): + super().__init__() + self.expr = expr + self.key = _cvt_to_key(self.expr) + + def forward(self, **kwargs): + raise NotImplementedError("Node.forward is not implemented") + + def __str__(self): + return ( + f"{self.__class__.__name__}(expr: {self.expr}, " + f"expr_type: {type(self.expr)})" + ) + + def __repr__(self): + return f"{self.__class__.__name__}(expr: {self.expr})" + + +class DetachNode(nn.Layer): + """Class for detach operation in converted expression tree. + + Args: + expr (sp.Basic): Sympy expression. + """ + + def __init__(self, expr: sp.Basic): + super().__init__() + self.expr = expr + self.key = _cvt_to_key(self.expr) + self.child = _cvt_to_key(self.expr.args[0]) + + def forward(self, data_dict: DATA_DICT): + if self.key in data_dict: + return data_dict + + data_dict[self.key] = data_dict[self.child].detach() + return data_dict + + +class OperatorNode(Node): + """Class for operator node in converted expression tree. + + Args: + expr (SYMPY_BUILTIN_FUNC): Sympy expression. + """ + + def __init__(self, expr: SYMPY_BUILTIN_FUNC): + super().__init__(expr) + # preprocess childs' key instead of processing at run-time in forward + # which can reduce considerable overhead of time for calling "_cvt_to_key" + if self.expr.func == sp.Derivative: + self.childs = [_cvt_to_key(self.expr.args[0])] + [ + (_cvt_to_key(arg), order) for (arg, order) in self.expr.args[1:] + ] + else: + self.childs = [_cvt_to_key(arg) for arg in self.expr.args] + + if self.expr.func == sp.Add: + self._apply_func = self._add_operator_func + elif self.expr.func == sp.Mul: + self._apply_func = self._mul_operator_func + elif self.expr.func == sp.Derivative: + self._apply_func = self._derivate_operator_func + elif self.expr.func == sp.Heaviside: + self._apply_func = self._heaviside_operator_func + self._auxiliary_func = SYMPT_TO_PADDLE[sp.Heaviside] + elif self.expr.func == sp.Min: + self._apply_func = self._minimum_operator_func + elif self.expr.func == sp.Max: + self._apply_func = self._maximum_operator_func + else: + self._apply_func = self._vanilla_operator_func + self._auxiliary_func = SYMPT_TO_PADDLE[self.expr.func] + + def forward(self, data_dict: DATA_DICT): + # use cache + if self.key in data_dict: + return data_dict + + return self._apply_func(data_dict) + + def _add_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = data_dict[self.childs[0]] + for p in self.childs[1:]: + data_dict[self.key] += data_dict[p] + return data_dict + + def _mul_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = data_dict[self.childs[0]] + for child in self.childs[1:]: + data_dict[self.key] *= data_dict[child] + return data_dict + + def _derivate_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = data_dict[self.childs[0]] + for child, order in self.childs[1:]: + if order & 1: + data_dict[self.key] = jacobian(data_dict[self.key], data_dict[child]) + order -= 1 + for _ in range(0, order, 2): + data_dict[self.key] = hessian(data_dict[self.key], data_dict[child]) + order -= 2 + return data_dict + + def _heaviside_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = self._auxiliary_func(data_dict[self.childs[0]]) + return data_dict + + def _minimum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = paddle.minimum( + data_dict[self.childs[0]], data_dict[self.childs[1]] + ) + for i in range(2, len(self.childs)): + data_dict[self.key] = paddle.minimum( + data_dict[data_dict[self.key]], + data_dict[data_dict[self.childs[i]]], + ) + return data_dict + + def _maximum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = paddle.maximum( + data_dict[self.childs[0]], data_dict[self.childs[1]] + ) + for i in range(2, len(self.childs)): + data_dict[self.key] = paddle.maximum( + data_dict[data_dict[self.key]], + data_dict[data_dict[self.childs[i]]], + ) + return data_dict + + def _vanilla_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = self._auxiliary_func( + *tuple(data_dict[child] for child in self.childs) + ) + return data_dict + + +class LayerNode(Node): + """Class for layer node in converted expression tree. + + Args: + expr (sp.core.function.UndefinedFunction): Sympy expression. + model (arch.Arch): NN model for computing forward result in this node. + """ + + def __init__( + self, + expr: sp.core.function.UndefinedFunction, + model: arch.Arch, + ): + super().__init__(expr) + self.model = model + + def forward(self, data_dict: DATA_DICT) -> DATA_DICT: + # use cache + if self.key in data_dict: + return data_dict + + output_dict = self.model(data_dict) + data_dict.update(output_dict) + + return data_dict + + +class ConstantNode(Node): + """Class for constant variable node in converted expression tree. + + Args: + expr (Union[sp.Number, sp.NumberSymbol]): Number expression. + """ + + def __init__(self, expr: Union[sp.Number, sp.NumberSymbol]): + super().__init__(expr) + if ( + self.expr.is_Float + or self.expr.is_Integer + or self.expr.is_Boolean + or self.expr.is_Rational + ): + self.expr = float(self.expr) + else: + raise TypeError( + "expr({expr}) should be Float/Integer/Boolean/Rational, " + f"but got {type(self.expr)}" + ) + self.expr = paddle.to_tensor(self.expr) + + def forward(self, data_dict: DATA_DICT) -> DATA_DICT: + # use cache + if self.key in data_dict: + return data_dict + + data_dict[self.key] = self.expr + return data_dict + + +class ParameterNode(Node): + """Class for constant variable node in converted expression tree. + + Args: + expr (sp.Symbol): Parameter expression. + parameter (paddle.framework.io.EagerParamBase): Parameter tensor. + """ + + def __init__(self, expr: sp.Symbol, parameter: paddle.framework.io.EagerParamBase): + super().__init__(expr) + self.parameter = parameter + + def forward(self, data_dict: DATA_DICT) -> DATA_DICT: + data_dict[self.key] = self.parameter + return data_dict + + +class ComposedNode(nn.Layer): + """ + Compose list of several callable objects together. + """ + + def __init__(self, callable_nodes: List[Node]): + super().__init__() + self.callable_nodes = callable_nodes + + def forward(self, data_dict: DATA_DICT) -> DATA_DICT: + # call all callable_nodes in order + for func in self.callable_nodes: + data_dict = func(data_dict) + + # return result of last node(root node) for target + return data_dict[self.callable_nodes[-1].key] + + +def _post_traverse(cur_node: sp.Basic, nodes: List[sp.Basic]) -> List[sp.Basic]: + """Traverse sympy expression tree in postorder. + + Args: + cur_node (sp.Basic): Sympy expression of current node. + nodes (List[sp.Basic]): Node list storing all tree nodes in postorder. + + Returns: + List[sp.Basic]: Node list storing all tree nodes in postorder. + """ + # traverse into sub-nodes + if isinstance(cur_node, sp.Function): + for arg in cur_node.args: + nodes = _post_traverse(arg, nodes) + nodes.append(cur_node) + elif isinstance(cur_node, sp.Derivative): + nodes = _post_traverse(cur_node.args[0], nodes) + nodes.append(cur_node) + elif isinstance(cur_node, sp.Symbol): + nodes.append(cur_node) + return nodes + elif isinstance(cur_node, sp.Number): + nodes.append(cur_node) + else: + for arg in cur_node.args: + nodes = _post_traverse(arg, nodes) + nodes.append(cur_node) + return nodes + + +def _visualize_graph(nodes: List[sp.Basic], graph_filename: str): + try: + import pygraphviz + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install pygraphviz by steps below:\n" + "1. apt-get install graphviz graphviz-dev\n" + "2. python -m pip install pygraphviz" + ) + + SYMPY_BUILTIN_NAME = { + sp.sin: "sin", + sp.sinh: "sinh", + sp.asin: "asin", + sp.cos: "cos", + sp.acos: "acos", + sp.cosh: "cosh", + sp.tan: "tan", + sp.atan: "atan", + sp.atan2: "atan2", + sp.acosh: "acosh", + sp.asinh: "asinh", + sp.tanh: "tanh", + sp.atanh: "atanh", + sp.erf: "erf", + sp.loggamma: "loggamma", + sp.exp: "exp", + sp.Pow: "Pow", + sp.log: "log", + sp.Max: "Max", + sp.Min: "Min", + sp.Abs: "Abs", + sp.Heaviside: "Heaviside", + sp.sign: "sign", + sp.ceiling: "ceiling", + sp.floor: "floor", + sp.Add: "Add", + sp.Mul: "Mul", + } + naming_counter = {k: 0 for k in SYMPY_BUILTIN_NAME} + + def get_operator_name(node): + ret = f"{SYMPY_BUILTIN_NAME[node.func]}_{naming_counter[node.func]}" + naming_counter[node.func] += 1 + return ret + + graph = pygraphviz.AGraph(directed=True, rankdir="TB") + C_FUNC = "#9196f1" # purple color function node + C_DATA = "#feb64d" # oringe color for data node + C_EDGE = "#000000" # black color for edge + + def add_edge(u: str, v: str, u_color: str = C_DATA, v_color: str = C_DATA): + """Add an edge from `u` to `v`. + + Args: + u (str): Name of begin node u. + v (str): Name of end node v. + u_color (str, optional): _description_. Defaults to C_DATA. + v_color (str, optional): _description_. Defaults to C_DATA. + """ + graph.add_node(u, style="filled", shape="ellipse", color=u_color) + graph.add_node(v, style="filled", shape="ellipse", color=v_color) + graph.add_edge(u, v, color=C_EDGE, style="solid", penwidth=0.5, arrowsize=0.5) + + for node in nodes: + if isinstance(node, tuple(SYMPY_BUILTIN_NAME.keys())): + operator_str = get_operator_name(node) + for arg in node.args: + add_edge(_cvt_to_key(arg), operator_str, v_color=C_FUNC) + add_edge(operator_str, _cvt_to_key(node), u_color=C_FUNC) + if isinstance(node, sp.Function): + for arg in node.args: + add_edge(_cvt_to_key(arg), str(node), v_color=C_FUNC) + add_edge(str(node), _cvt_to_key(node), u_color=C_FUNC) + elif isinstance(node, sp.Derivative): + add_edge(str(node), _cvt_to_key(node), u_color=C_FUNC) + add_edge(_cvt_to_key(node.args[0]), str(node), v_color=C_FUNC) + for arg in node.args[1:]: + add_edge(_cvt_to_key(arg[0]), str(node), v_color=C_FUNC) + + # export graph to image + from ppsci.utils import logger + + graph.layout() + image_path = f"{graph_filename}.png" + dot_path = f"{graph_filename}.dot" + graph.draw(image_path, prog="dot") + graph.write(dot_path) + logger.message( + f"Computational graph has been writen to {image_path} and {dot_path}. " + "dot file can be visualized at https://dreampuf.github.io/GraphvizOnline/" + ) + + +def lambdify( + expr: sp.Expr, + models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None, + extra_parameters: Optional[Sequence[paddle.Tensor]] = None, + graph_filename: Optional[str] = None, +) -> ComposedNode: + """Convert sympy expression to callable function. + + Args: + expr (sp.Expr): Sympy expression to be converted. + models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for + computing forward result in `LayerNode`. + extra_parameters (Optional[nn.ParameterList]): Extra learnable parameters. + Defaults to None. + graph_filename (Optional[str]): Save computational graph to `graph_filename.png` + for given `expr`, if `graph_filename` is not None and a valid string, + such as 'momentum_x'. Defaults to None. + + Returns: + ComposedNode: Callable object for computing expr with necessary input(s) data + in dict given. + + Examples: + >>> import paddle + >>> import ppsci + >>> import sympy as sp + + >>> a, b, c, x, y = sp.symbols("a b c x y") + >>> u = sp.Function("u")(x, y) + >>> v = sp.Function("v")(x, y) + >>> z = -a + b * (c ** 2) + u * v + 2.3 + + >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 4, 16) + + >>> batch_size = 13 + >>> a_tensor = paddle.randn([batch_size, 1]) + >>> b_tensor = paddle.randn([batch_size, 1]) + >>> c_tensor = paddle.randn([batch_size, 1]) + >>> x_tensor = paddle.randn([batch_size, 1]) + >>> y_tensor = paddle.randn([batch_size, 1]) + + >>> model_output_dict = model({"x": x_tensor, "y": y_tensor}) + >>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"] + + >>> z_tensor_manually = ( + ... -a_tensor + b_tensor * (c_tensor ** 2) + ... + u_tensor * v_tensor + 2.3 + ... ) + >>> z_tensor_sympy = ppsci.lambdify(z, model)( + ... { + ... "a": a_tensor, + ... "b": b_tensor, + ... "c": c_tensor, + ... "x": x_tensor, + ... "y": y_tensor, + ... } + ... ) + + >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item() + True + """ + + # NOTE: Those simplify methods may complicate given expr instead, so not use here + # simplify expression to reduce nodes in tree + # expr = sp.nsimplify(expr) + # expr = sp.expand(expr) + # expr = sp.simplify(expr) + + # remove 1.0 from sympy expression tree + expr = expr.subs(1.0, 1) + + # convert sympy expression tree to list of nodes in postorder + sympy_nodes = [] + sympy_nodes = _post_traverse(expr, sympy_nodes) + + # remove unnecessary symbol nodes already in input dict(except for paramter symbol) + if not extra_parameters: + extra_parameters = () + _parameter_names = tuple(param.name for param in extra_parameters) + sympy_nodes = [ + node + for node in sympy_nodes + if (not node.is_Symbol) or (_cvt_to_key(node) in _parameter_names) + ] + + # remove duplicates with topo-order kept + sympy_nodes = list(dict.fromkeys(sympy_nodes)) + + if isinstance(models, arch.ModelList): + models = tuple(models.model_list[i] for i in range(len(models.model_list))) + if not isinstance(models, (tuple, list)): + models = (models,) + + # convert sympy node to callable node + callable_nodes = [] + for i, node in enumerate(sympy_nodes): + if isinstance( + node, tuple(SYMPT_TO_PADDLE.keys()) + (sp.Add, sp.Mul, sp.Derivative) + ): + callable_nodes.append(OperatorNode(node)) + elif isinstance(node, sp.Function): + if node.name == equation.DETACH_FUNC_NAME: + callable_nodes.append(DetachNode(node)) + else: + match_index = None + for j, model in enumerate(models): + if str(node.func.name) in model.output_keys: + callable_nodes.append( + LayerNode( + node, + model, + ) + ) + if match_index is not None: + raise ValueError( + f"Name of function({node}) should be unique along given" + f" models, but got same output_key({node.func.name}) " + f"in models[{match_index}] and models[{j}]." + ) + match_index = j + elif node.is_Number or node.is_NumberSymbol: + callable_nodes.append(ConstantNode(node)) + elif isinstance(node, sp.Symbol): + callable_nodes.append( + ParameterNode( + node, + *[param for param in extra_parameters if param.name == node.name], + ) + ) + else: + raise NotImplementedError(f"The node {node} is not supported in lambdify.") + + # NOTE: Visualize computational graph using 'pygraphviz' + if isinstance(graph_filename, str): + _visualize_graph(sympy_nodes, graph_filename) + + # Compose callable nodes into one callable object + return ComposedNode(callable_nodes) diff --git a/ppsci/validate/geo_validator.py b/ppsci/validate/geo_validator.py index 9f46051941..6741baddc7 100644 --- a/ppsci/validate/geo_validator.py +++ b/ppsci/validate/geo_validator.py @@ -23,7 +23,6 @@ import numpy as np import paddle import sympy -from sympy.parsing import sympy_parser as sp_parser from typing_extensions import Literal from ppsci import geometry @@ -85,13 +84,9 @@ def __init__( name: Optional[str] = None, ): self.output_expr = output_expr - for label_name, expr in self.output_expr.items(): - if isinstance(expr, str): - self.output_expr[label_name] = sp_parser.parse_expr(expr) - self.label_dict = label_dict self.input_keys = geom.dim_keys - self.output_keys = list(label_dict.keys()) + self.output_keys = tuple(label_dict.keys()) nx = dataloader_cfg["total_size"] self.num_timestamps = 1 diff --git a/ppsci/validate/sup_validator.py b/ppsci/validate/sup_validator.py index a45a28f983..56f2b9a50a 100644 --- a/ppsci/validate/sup_validator.py +++ b/ppsci/validate/sup_validator.py @@ -75,7 +75,9 @@ def __init__( self.input_keys = _dataset.input_keys self.output_keys = ( - list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + tuple(output_expr.keys()) + if output_expr is not None + else _dataset.label_keys ) if self.output_expr is None: diff --git a/test/equation/test_biharmonic.py b/test/equation/test_biharmonic.py index 314393844e..8e1d6c2be0 100644 --- a/test/equation/test_biharmonic.py +++ b/test/equation/test_biharmonic.py @@ -1,7 +1,9 @@ import paddle import pytest -from paddle import nn +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation __all__ = [] @@ -29,13 +31,10 @@ def test_biharmonic(dim): input_data = paddle.concat([x, y, z], axis=1) # build NN model - model = nn.Sequential( - nn.Linear(len(input_dims), len(output_dims)), - nn.Tanh(), - ) + model = arch.MLP(input_dims, output_dims, 2, 16) # manually generate output - u = model(input_data) + u = model.forward_tensor(input_data) # use self-defined jacobian and hessian def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor": @@ -57,6 +56,12 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor": # compute result using built-in Biharmonic module biharmonic_equation = equation.Biharmonic(dim=dim, q=q, D=D) + for name, expr in biharmonic_equation.equations.items(): + if isinstance(expr, sp.Basic): + biharmonic_equation.equations[name] = ppsci.lambdify( + expr, + model, + ) data_dict = { "x": x, "y": y, diff --git a/test/equation/test_laplace.py b/test/equation/test_laplace.py index ce41e47cee..6c438df3e4 100644 --- a/test/equation/test_laplace.py +++ b/test/equation/test_laplace.py @@ -1,7 +1,9 @@ import paddle import pytest -from paddle import nn +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation __all__ = [] @@ -26,13 +28,10 @@ def test_l1loss_mean(dim): input_data = paddle.concat([x, y, z], axis=1) # build NN model - model = nn.Sequential( - nn.Linear(len(input_dims), len(output_dims)), - nn.Tanh(), - ) + model = arch.MLP(input_dims, output_dims, 2, 16) # manually generate output - u = model(input_data) + u = model.forward_tensor(input_data) # use self-defined jacobian and hessian def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor": @@ -48,6 +47,13 @@ def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor": # compute result using built-in Laplace module laplace_equation = equation.Laplace(dim=dim) + for name, expr in laplace_equation.equations.items(): + if isinstance(expr, sp.Basic): + laplace_equation.equations[name] = ppsci.lambdify( + expr, + model, + ) + data_dict = { "x": x, "y": y, diff --git a/test/equation/test_linear_elasticity.py b/test/equation/test_linear_elasticity.py index c444effa37..973e3df104 100644 --- a/test/equation/test_linear_elasticity.py +++ b/test/equation/test_linear_elasticity.py @@ -1,7 +1,9 @@ import paddle import pytest -from paddle import nn +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation @@ -123,7 +125,32 @@ def traction_z_expected_result( ], ) def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time): + paddle.seed(42) batch_size = 13 + input_dims = ("x", "y", "z")[:dim] + if time: + input_dims += ("t",) + output_dims = ( + ( + "u", + "v", + "sigma_xx", + "sigma_yy", + "sigma_xy", + ) + if dim == 2 + else ( + "u", + "v", + "w", + "sigma_xx", + "sigma_yy", + "sigma_xy", + "sigma_zz", + "sigma_xz", + "sigma_yz", + ) + ) x = paddle.randn([batch_size, 1]) y = paddle.randn([batch_size, 1]) z = paddle.randn([batch_size, 1]) if dim == 3 else None @@ -145,12 +172,14 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time): if dim == 3: input_data = paddle.concat([input_data, z], axis=1) - model = nn.Sequential( - nn.Linear(input_data.shape[1], 9 if dim == 3 else 5), - nn.Tanh(), - ) + model = arch.MLP(input_dims, output_dims, 2, 16) + + # model = nn.Sequential( + # nn.Linear(input_data.shape[1], 9 if dim == 3 else 5), + # nn.Tanh(), + # ) - output = model(input_data) + output = model.forward_tensor(input_data) u, v, *other_outputs = paddle.split(output, num_or_sections=output.shape[1], axis=1) @@ -201,15 +230,20 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time): linear_elasticity = equation.LinearElasticity( E=E, nu=nu, lambda_=lambda_, mu=mu, rho=rho, dim=dim, time=time ) - + for name, expr in linear_elasticity.equations.items(): + if isinstance(expr, sp.Basic): + linear_elasticity.equations[name] = ppsci.lambdify( + expr, + model, + ) data_dict = { + "t": t, "x": x, "y": y, + "z": z, "u": u, "v": v, - "z": z, "w": w, - "t": t, "sigma_xx": sigma_xx, "sigma_xy": sigma_xy, "sigma_xz": sigma_xz, @@ -220,6 +254,14 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time): "normal_y": normal_y, "normal_z": normal_z, } + if not time: + data_dict.pop("t") + if dim == 2: + data_dict.pop("w") + data_dict.pop("sigma_xz") + data_dict.pop("sigma_yz") + data_dict.pop("sigma_zz") + data_dict.pop("normal_z") test_output_names = [ "stress_disp_xx", @@ -267,7 +309,7 @@ def test_linear_elasticity(E, nu, lambda_, mu, rho, dim, time): ) for name in test_output_names: - assert paddle.allclose(expected_output[name], test_output[name]) + assert paddle.allclose(expected_output[name], test_output[name], atol=1e-7) if __name__ == "__main__": diff --git a/test/equation/test_navier_stokes.py b/test/equation/test_navier_stokes.py index 9888fb5e10..0279374ac8 100644 --- a/test/equation/test_navier_stokes.py +++ b/test/equation/test_navier_stokes.py @@ -1,7 +1,9 @@ import paddle import pytest -from paddle import nn +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation @@ -26,8 +28,8 @@ def momentum_x_compute_func( momentum_x = ( u * jacobian(u, x) + v * jacobian(u, y) - - nu / rho * hessian(u, x) - - nu / rho * hessian(u, y) + - nu * hessian(u, x) + - nu * hessian(u, y) + 1 / rho * jacobian(p, x) ) @@ -35,7 +37,7 @@ def momentum_x_compute_func( momentum_x += jacobian(u, t) if dim == 3: momentum_x += w * jacobian(u, z) - momentum_x -= nu / rho * hessian(u, z) + momentum_x -= nu * hessian(u, z) return momentum_x @@ -45,8 +47,8 @@ def momentum_y_compute_func( momentum_y = ( u * jacobian(v, x) + v * jacobian(v, y) - - nu / rho * hessian(v, x) - - nu / rho * hessian(v, y) + - nu * hessian(v, x) + - nu * hessian(v, y) + 1 / rho * jacobian(p, y) ) @@ -54,7 +56,7 @@ def momentum_y_compute_func( momentum_y += jacobian(v, t) if dim == 3: momentum_y += w * jacobian(v, z) - momentum_y -= nu / rho * hessian(v, z) + momentum_y -= nu * hessian(v, z) return momentum_y @@ -65,9 +67,9 @@ def momentum_z_compute_func( u * jacobian(w, x) + v * jacobian(w, y) + w * jacobian(w, z) - - nu / rho * hessian(w, x) - - nu / rho * hessian(w, y) - - nu / rho * hessian(w, z) + - nu * hessian(w, x) + - nu * hessian(w, y) + - nu * hessian(w, z) + 1 / rho * jacobian(p, z) ) if time: @@ -91,40 +93,33 @@ def test_navierstokes(nu, rho, dim, time): y = paddle.randn([batch_size, 1]) x.stop_gradient = False y.stop_gradient = False - input_dims = 2 + + input_dims = ("x", "y") + output_dims = ("u", "v", "p") if dim == 2 else ("u", "v", "w", "p") inputs = (x, y) + if time: t = paddle.randn([batch_size, 1]) t.stop_gradient = False inputs = (t,) + inputs - input_dims += 1 + input_dims = ("t",) + input_dims if dim == 3: z = paddle.randn([batch_size, 1]) z.stop_gradient = False inputs = inputs + (z,) - input_dims += 1 + input_dims = input_dims + ("z",) input_data = paddle.concat(inputs, axis=1) - """ - Use the relatively simple Multilayer Perceptron - to represent the mapping function from (t, x, y, z) to (u, v, w, p): - f(x, y) = (u, v, p) or - f(t, x, y) = (u, v, p) or - f(t, x, y, z) = (u, v, w, p) - """ - model = nn.Sequential( - nn.Linear(input_dims, 3 if dim == 2 else 4), - nn.Tanh(), - ) + model = arch.MLP(input_dims, output_dims, 2, 16) # manually generate output - output = model(input_data) + output = model.forward_tensor(input_data) if dim == 2: - u, v, p = paddle.split(output, num_or_sections=output.shape[1], axis=1) + u, v, p = paddle.split(output, num_or_sections=len(output_dims), axis=1) w, z = None, None else: - u, v, w, p = paddle.split(output, num_or_sections=output.shape[1], axis=1) + u, v, w, p = paddle.split(output, num_or_sections=len(output_dims), axis=1) if not time: t = None expected_continuity = continuity_compute_func(x=x, y=y, u=u, v=v, dim=dim, w=w, z=z) @@ -141,6 +136,12 @@ def test_navierstokes(nu, rho, dim, time): # compute result using NavierStokes class navier_stokes_equation = equation.NavierStokes(nu=nu, rho=rho, dim=dim, time=time) + for name, expr in navier_stokes_equation.equations.items(): + if isinstance(expr, sp.Basic): + navier_stokes_equation.equations[name] = ppsci.lambdify( + expr, + model, + ) data_dict = {"x": x, "y": y, "u": u, "v": v, "p": p} if time: @@ -156,9 +157,7 @@ def test_navierstokes(nu, rho, dim, time): ] if dim == 3: - test_output_names.append( - "momentum_z", - ) + test_output_names.append("momentum_z") test_output = {} for name in test_output_names: @@ -174,7 +173,7 @@ def test_navierstokes(nu, rho, dim, time): # check result whether is equal for name in test_output_names: - assert paddle.allclose(expected_output[name], test_output[name]) + assert paddle.allclose(expected_output[name], test_output[name]), f"{name}" if __name__ == "__main__": diff --git a/test/equation/test_normal_dot_vec.py b/test/equation/test_normal_dot_vec.py index 6b930a2715..e701d2ea68 100644 --- a/test/equation/test_normal_dot_vec.py +++ b/test/equation/test_normal_dot_vec.py @@ -1,6 +1,9 @@ import paddle import pytest +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation @@ -13,15 +16,34 @@ def compute_func(x: tuple, y: tuple): def test_normal_dot_vel(): batch_size = 13 - u = paddle.randn([batch_size, 1]) - v = paddle.randn([batch_size, 1]) - w = paddle.randn([batch_size, 1]) + x = paddle.randn([batch_size, 1]) + y = paddle.randn([batch_size, 1]) + z = paddle.randn([batch_size, 1]) + input_dims = ("x", "y", "z") + output_dims = ("u", "v", "w") + model = arch.MLP(input_dims, output_dims, 2, 16) + output_dict = model( + { + "x": x, + "y": y, + "z": z, + } + ) + u = output_dict["u"] + v = output_dict["v"] + w = output_dict["w"] normal_x = paddle.randn([batch_size, 1]) normal_y = paddle.randn([batch_size, 1]) normal_z = paddle.randn([batch_size, 1]) - pde = equation.NormalDotVec(("u", "v", "w")) + norm_doc_vec = equation.NormalDotVec(output_dims) + for name, expr in norm_doc_vec.equations.items(): + if isinstance(expr, sp.Basic): + norm_doc_vec.equations[name] = ppsci.lambdify( + expr, + model, + ) out = { "u": u, "v": v, @@ -32,7 +54,9 @@ def test_normal_dot_vel(): } expected_result = compute_func((u, v, w), (normal_x, normal_y, normal_z)) - assert paddle.allclose(pde.equations["normal_dot_vel"](out), expected_result) + assert paddle.allclose( + norm_doc_vec.equations["normal_dot_vec"](out), expected_result + ) if __name__ == "__main__": diff --git a/test/equation/test_poisson.py b/test/equation/test_poisson.py index 502acb3103..ca86d98db2 100644 --- a/test/equation/test_poisson.py +++ b/test/equation/test_poisson.py @@ -14,8 +14,10 @@ import paddle import pytest -from paddle import nn +import sympy as sp +import ppsci +from ppsci import arch from ppsci import equation __all__ = [] @@ -40,13 +42,10 @@ def test_poisson(dim): input_data = paddle.concat([x, y, z], axis=1) # build NN model - model = nn.Sequential( - nn.Linear(len(input_dims), len(output_dims)), - nn.Tanh(), - ) + model = arch.MLP(input_dims, output_dims, 2, 16) # manually generate output - p = model(input_data) + p = model.forward_tensor(input_data) def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor: return paddle.grad(y, x, create_graph=True)[0] @@ -61,6 +60,13 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor: # compute result using built-in Laplace module poisson_equation = equation.Poisson(dim=dim) + for name, expr in poisson_equation.equations.items(): + if isinstance(expr, sp.Basic): + poisson_equation.equations[name] = ppsci.lambdify( + expr, + model, + ) + data_dict = { "x": x, "y": y, diff --git a/test/equation/test_viv.py b/test/equation/test_viv.py index b5567727d3..2dfc4f7781 100644 --- a/test/equation/test_viv.py +++ b/test/equation/test_viv.py @@ -1,8 +1,10 @@ import paddle import pytest -from paddle import nn +import sympy as sp from paddle.nn import initializer +import ppsci +from ppsci import arch from ppsci.equation.pde import Vibration @@ -11,13 +13,15 @@ def test_vibration(rho, k1, k2): """Test for Vibration equation.""" batch_size = 13 rho = rho - k1 = paddle.create_parameter( + k11 = paddle.create_parameter( shape=[], dtype=paddle.get_default_dtype(), + name="k11", default_initializer=initializer.Constant(k1), ) - k2 = paddle.create_parameter( + k22 = paddle.create_parameter( shape=[], + name="k22", dtype=paddle.get_default_dtype(), default_initializer=initializer.Constant(k2), ) @@ -27,13 +31,12 @@ def test_vibration(rho, k1, k2): eta.stop_gradient = False t_f.stop_gradient = False input_data = paddle.concat([eta, t_f], axis=1) - model = nn.Sequential( - nn.Linear(2, 1), - nn.Tanh(), - ) + input_dims = ("eta", "t_f") + output_dims = ("f",) + model = arch.MLP(input_dims, output_dims, 2, 16) # manually generate output - eta = model(input_data) + eta = model.forward_tensor(input_data) def jacobian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor: return paddle.grad(y, x, create_graph=True)[0] @@ -43,12 +46,19 @@ def hessian(y: paddle.Tensor, x: paddle.Tensor) -> paddle.Tensor: expected_result = ( rho * hessian(eta, t_f) - + paddle.exp(k1) * jacobian(eta, t_f) - + paddle.exp(k2) * eta + + paddle.exp(k11) * jacobian(eta, t_f) + + paddle.exp(k22) * eta ) # compute result using Vibration class vibration_equation = Vibration(rho=rho, k1=k1, k2=k2) + for name, expr in vibration_equation.equations.items(): + if isinstance(expr, sp.Basic): + vibration_equation.equations[name] = ppsci.lambdify( + expr, + model, + vibration_equation.learnable_parameters, + ) data_dict = {"eta": eta, "t_f": t_f} test_result = vibration_equation.equations["f"](data_dict) # check result whether is equal diff --git a/test/utils/speed_test_navier_stokes.py b/test/utils/speed_test_navier_stokes.py deleted file mode 100644 index 838a38c740..0000000000 --- a/test/utils/speed_test_navier_stokes.py +++ /dev/null @@ -1,476 +0,0 @@ -import time as time_module - -import paddle -import sympy as sp - -from ppsci import arch -from ppsci import equation -from ppsci.autodiff import clear -from ppsci.autodiff import hessian as H -from ppsci.autodiff import jacobian as J -from ppsci.utils import sym_to_func - - -class NavierStokes_sympy: - def __init__(self, nu, rho, dim, time): - # set params - self.dim = dim - self.time = time - - # coordinates - x, y, z = sp.Symbol("x"), sp.Symbol("y"), sp.Symbol("z") - - # time - t = sp.Symbol("t") - - # make input variables - input_variables = {"x": x, "y": y, "z": z, "t": t} - if self.dim == 2: - input_variables.pop("z") - if not self.time: - input_variables.pop("t") - - # velocity componets - u = sp.Function("u")(*input_variables) - v = sp.Function("v")(*input_variables) - if self.dim == 3: - w = sp.Function("w")(*input_variables) - else: - w = sp.Number(0) - - # pressure - p = sp.Function("p")(*input_variables) - - # kinematic viscosity - if isinstance(nu, str): - nu = sp.Function(nu)(*input_variables) - elif isinstance(nu, (float, int)): - nu = sp.Number(nu) - - # density - if isinstance(rho, str): - rho = sp.Function(rho)(*input_variables) - elif isinstance(rho, (float, int)): - rho = sp.Number(rho) - - # dynamic viscosity - mu = rho * nu - - # set equations - self.equations = {} - self.equations["continuity"] = ( - rho.diff(t) + (rho * u).diff(x) + (rho * v).diff(y) + (rho * w).diff(z) - ) - - curl = sp.Number(0) if rho.diff(x) == 0 else u.diff(x) + v.diff(y) + w.diff(z) - self.equations["momentum_x"] = ( - (rho * u).diff(t) - + ( - u * ((rho * u).diff(x)) - + v * ((rho * u).diff(y)) - + w * ((rho * u).diff(z)) - + rho * u * (curl) - ) - + p.diff(x) - - (-2 / 3 * mu * (curl)).diff(x) - - (mu * u.diff(x)).diff(x) - - (mu * u.diff(y)).diff(y) - - (mu * u.diff(z)).diff(z) - - (mu * (curl).diff(x)) - ) - self.equations["momentum_y"] = ( - (rho * v).diff(t) - + ( - u * ((rho * v).diff(x)) - + v * ((rho * v).diff(y)) - + w * ((rho * v).diff(z)) - + rho * v * (curl) - ) - + p.diff(y) - - (-2 / 3 * mu * (curl)).diff(y) - - (mu * v.diff(x)).diff(x) - - (mu * v.diff(y)).diff(y) - - (mu * v.diff(z)).diff(z) - - (mu * (curl).diff(y)) - ) - self.equations["momentum_z"] = ( - (rho * w).diff(t) - + ( - u * ((rho * w).diff(x)) - + v * ((rho * w).diff(y)) - + w * ((rho * w).diff(z)) - + rho * w * (curl) - ) - + p.diff(z) - - (-2 / 3 * mu * (curl)).diff(z) - - (mu * w.diff(x)).diff(x) - - (mu * w.diff(y)).diff(y) - - (mu * w.diff(z)).diff(z) - - (mu * (curl).diff(z)) - ) - - if self.dim == 2: - self.equations.pop("momentum_z") - - -class ZeroEquation_sympy: - def __init__( - self, nu, max_distance, rho=1, dim=3, time=True - ): # TODO add density into model - # set params - self.dim = dim - self.time = time - - # model coefficients - self.max_distance = max_distance - self.karman_constant = 0.419 - self.max_distance_ratio = 0.09 - - # coordinates - x, y, z = sp.Symbol("x"), sp.Symbol("y"), sp.Symbol("z") - - # time - t = sp.Symbol("t") - - # make input variables - input_variables = {"x": x, "y": y, "z": z, "t": t} - if self.dim == 2: - input_variables.pop("z") - if not self.time: - input_variables.pop("t") - - # velocity componets - u = sp.Function("u")(*input_variables) - v = sp.Function("v")(*input_variables) - if self.dim == 3: - w = sp.Function("w")(*input_variables) - else: - w = sp.Number(0) - - # density - if type(rho) is str: - rho = sp.Function(rho)(*input_variables) - elif type(rho) in [float, int]: - rho = sp.Number(rho) - - # wall distance - normal_distance = sp.Function("sdf")(*input_variables) - - # mixing length - mixing_length = sp.Min( - self.karman_constant * normal_distance, - self.max_distance_ratio * self.max_distance, - ) - G = ( - 2 * u.diff(x) ** 2 - + 2 * v.diff(y) ** 2 - + 2 * w.diff(z) ** 2 - + (u.diff(y) + v.diff(x)) ** 2 - + (u.diff(z) + w.diff(x)) ** 2 - + (v.diff(z) + w.diff(y)) ** 2 - ) - - # set equations - self.equations = {} - self.equations["nu"] = nu + rho * mixing_length**2 * sp.sqrt(G) - - -def compute_with_sympy(input_dicts, nu, rho, dim, time, model): - """Test for navier_stokes equation.""" - # define input/output keys - ze = ZeroEquation_sympy(nu=nu, rho=rho, dim=dim, max_distance=3.4, time=time) - nu_sympy = ze.equations["nu"] - - input_keys = ("x", "y", "z")[:dim] - if time: - input_keys = ("t",) + input_keys - - output_keys = ("u", "v") - if dim == 3: - output_keys += ("w",) - output_keys += ("p",) - - # prepare input data in dict - cost_list = [] - # prepare python function expressions and sympy-expression in dict - sympy_expr_dict = NavierStokes_sympy(nu_sympy, rho, dim, time).equations - for target, expr in sympy_expr_dict.items(): - sympy_expr_dict[target] = sym_to_func.sympy_to_function( - expr, - [ - model, - ], - ) - for i, input_dict in enumerate(input_dicts): - input_dict = input_dicts[i] - - # compute equation with funciton converted from sympy - output_dict_sympy = {k: v for k, v in input_dict.items()} - tmp = {k: v for k, v in output_dict_sympy.items()} - beg = time_module.perf_counter() - for name, expr in sympy_expr_dict.items(): - output = expr(tmp) - output_dict_sympy[name] = output - for key in model.output_keys: - output_dict_sympy[key] = tmp[key] - clear() - end = time_module.perf_counter() - cost_list.append(end - beg) - - # test for result - print( - f"compute_with_sympy overhead: {sum(cost_list[10:]) / len(cost_list[10:]):.5f}" - ) - return output_dict_sympy - - -def compute_with_pyfunc(input_dicts, nu, rho, dim, time, model): - def continuity_f(out): - x, y = out["x"], out["y"] - u, v = out["u"], out["v"] - return 1.0 * J(u, x) + 1.0 * J(v, y) - - def momentum_x_f(out): - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - if time: - t = out["t"] - return ( - -( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(u, x) - - ( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(u, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(u, y) + 2 * J(J(v, x), y)) / 2 - + 2 * J(u, x) * J(J(u, x), y) - + 2 * J(v, y) * H(v, y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__y"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(u, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(v, x) + 2 * J(J(u, x), y)) / 2 - + 2 * J(u, x) * H(u, x) - + 2 * J(v, y) * J(J(v, x), y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__x"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(u, x) - + (1.0 * u * J(u, x) + 1.0 * v * J(u, y) + J(p, x)) - + (J(u, t) if time else 0) - ) - - def momentum_y_f(out): - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - if time: - t = out["t"] - return ( - -( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(v, x) - - ( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(v, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(u, y) + 2 * J(J(v, x), y)) / 2 - + 2 * J(u, x) * J(J(u, x), y) - + 2 * J(v, y) * H(v, y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__y"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(v, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(v, x) + 2 * J(J(u, x), y)) / 2 - + 2 * J(u, x) * H(u, x) - + 2 * J(v, y) * J(J(v, x), y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__x"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(v, x) - + (1.0 * u * J(v, x) + 1.0 * v * J(v, y) + J(p, y)) - + (J(v, t) if time else 0) - ) - - """Test for navier_stokes equation.""" - # define input/output keys - - # prepare input data in dict - cost_list = [] - for i, input_dict in enumerate(input_dicts): - input_dict = input_dicts[i] - - # prepare python function expressions in dict - functional_expr_dict = equation.NavierStokes(nu, rho, dim, time).equations - functional_expr_dict["continuity"] = continuity_f - functional_expr_dict["momentum_x"] = momentum_x_f - functional_expr_dict["momentum_y"] = momentum_y_f - - # compute equation with python function - output_dict_functional = model(input_dict) - beg = time_module.perf_counter() - for name, expr in functional_expr_dict.items(): - if callable(expr): - output_dict_functional[name] = expr( - {**output_dict_functional, **input_dict} - ) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") - clear() - end = time_module.perf_counter() - cost_list.append(end - beg) - - # test for result - print( - f"compute_with_pyfunc overhead: {sum(cost_list[10:]) / len(cost_list[10:]):.5f}" - ) - return output_dict_functional - - -if __name__ == "__main__": - input_keys = ("t", "x", "y") - output_keys = ("u", "v", "p") - nu = 2 - rho = 1 - dim = 2 - time = True - model = arch.MLP(input_keys, output_keys, 4, 50) - - batch_size = 2048 - input_dicts = [] - for i in range(50): - input_dict = {} - for var in input_keys: - input_dict[var] = paddle.randn([batch_size, 1]) - input_dict[var].stop_gradient = False - if var != "t": - input_dict[f"sdf__{var}"] = paddle.randn([batch_size, 1]) - input_dict[f"normal__{var}"] = paddle.randn([batch_size, 1]) - - input_dict[f"sdf__{var}"].stop_gradient = False - input_dict[f"normal__{var}"].stop_gradient = False - - input_dict["sdf"] = paddle.randn([batch_size, 1]) - input_dict["sdf"].stop_gradient = False - input_dicts.append(input_dict) - - output_dict_sympy = compute_with_sympy( - input_dicts, nu=nu, rho=rho, dim=dim, time=time, model=model - ) - output_dict_pyfunc = compute_with_pyfunc( - input_dicts, nu=nu, rho=rho, dim=dim, time=time, model=model - ) - - for key in output_dict_pyfunc: - if not paddle.allclose( - output_dict_sympy[key], output_dict_pyfunc[key], atol=1e-7 - ): - print(f"{key} {output_dict_sympy[key]}\n{output_dict_pyfunc[key]}") - else: - print(f"{key} check pass") diff --git a/test/utils/test_linear_elasticity_sympy.py b/test/utils/test_linear_elasticity_sympy.py deleted file mode 100644 index aeddf7f53f..0000000000 --- a/test/utils/test_linear_elasticity_sympy.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import pytest -from sympy import Function -from sympy import Number -from sympy import Symbol - -import ppsci -from ppsci import equation -from ppsci.autodiff import clear -from ppsci.utils import sym_to_func - -__all__ = [] - - -class LinearElasticity_sympy: - def __init__( - self, E=None, nu=None, lambda_=None, mu=None, rho=1, dim=3, time=False - ): - - # set params - self.dim = dim - self.time = time - - # coordinates - x, y, z = Symbol("x"), Symbol("y"), Symbol("z") - normal_x, normal_y, normal_z = ( - Symbol("normal_x"), - Symbol("normal_y"), - Symbol("normal_z"), - ) - - # time - t = Symbol("t") - - # make input variables - input_variables = {"x": x, "y": y, "z": z, "t": t} - if self.dim == 2: - input_variables.pop("z") - if not self.time: - input_variables.pop("t") - - # displacement componets - u = Function("u")(*input_variables) - v = Function("v")(*input_variables) - sigma_xx = Function("sigma_xx")(*input_variables) - sigma_yy = Function("sigma_yy")(*input_variables) - sigma_xy = Function("sigma_xy")(*input_variables) - if self.dim == 3: - w = Function("w")(*input_variables) - sigma_zz = Function("sigma_zz")(*input_variables) - sigma_xz = Function("sigma_xz")(*input_variables) - sigma_yz = Function("sigma_yz")(*input_variables) - else: - w = Number(0) - sigma_zz = Number(0) - sigma_xz = Number(0) - sigma_yz = Number(0) - - # material properties - if lambda_ is None: - if isinstance(nu, str): - nu = Function(nu)(*input_variables) - elif isinstance(nu, (float, int)): - nu = Number(nu) - if isinstance(E, str): - E = Function(E)(*input_variables) - elif isinstance(E, (float, int)): - E = Number(E) - lambda_ = nu * E / ((1 + nu) * (1 - 2 * nu)) - mu = E / (2 * (1 + nu)) - else: - if isinstance(lambda_, str): - lambda_ = Function(lambda_)(*input_variables) - elif isinstance(lambda_, (float, int)): - lambda_ = Number(lambda_) - if isinstance(mu, str): - mu = Function(mu)(*input_variables) - elif isinstance(mu, (float, int)): - mu = Number(mu) - if isinstance(rho, str): - rho = Function(rho)(*input_variables) - elif isinstance(rho, (float, int)): - rho = Number(rho) - - # set equations - self.equations = {} - - # Stress equations - self.equations["stress_disp_xx"] = ( - lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) - + 2 * mu * u.diff(x) - - sigma_xx - ) - self.equations["stress_disp_yy"] = ( - lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) - + 2 * mu * v.diff(y) - - sigma_yy - ) - self.equations["stress_disp_zz"] = ( - lambda_ * (u.diff(x) + v.diff(y) + w.diff(z)) - + 2 * mu * w.diff(z) - - sigma_zz - ) - self.equations["stress_disp_xy"] = mu * (u.diff(y) + v.diff(x)) - sigma_xy - self.equations["stress_disp_xz"] = mu * (u.diff(z) + w.diff(x)) - sigma_xz - self.equations["stress_disp_yz"] = mu * (v.diff(z) + w.diff(y)) - sigma_yz - - # Equations of equilibrium - self.equations["equilibrium_x"] = rho * ((u.diff(t)).diff(t)) - ( - sigma_xx.diff(x) + sigma_xy.diff(y) + sigma_xz.diff(z) - ) - self.equations["equilibrium_y"] = rho * ((v.diff(t)).diff(t)) - ( - sigma_xy.diff(x) + sigma_yy.diff(y) + sigma_yz.diff(z) - ) - self.equations["equilibrium_z"] = rho * ((w.diff(t)).diff(t)) - ( - sigma_xz.diff(x) + sigma_yz.diff(y) + sigma_zz.diff(z) - ) - - # Traction equations - self.equations["traction_x"] = ( - normal_x * sigma_xx + normal_y * sigma_xy + normal_z * sigma_xz - ) - self.equations["traction_y"] = ( - normal_x * sigma_xy + normal_y * sigma_yy + normal_z * sigma_yz - ) - self.equations["traction_z"] = ( - normal_x * sigma_xz + normal_y * sigma_yz + normal_z * sigma_zz - ) - - if self.dim == 2: - self.equations.pop("stress_disp_zz") - self.equations.pop("stress_disp_xz") - self.equations.pop("stress_disp_yz") - self.equations.pop("equilibrium_z") - self.equations.pop("traction_z") - - -@pytest.mark.parametrize( - "E,nu,lambda_,mu", - ( - (2.0, 3.0, None, None), - (None, None, 2.0, 3.0), - ), -) -@pytest.mark.parametrize("rho", (1,)) -@pytest.mark.parametrize("dim", (2, 3)) -@pytest.mark.parametrize("time", (False, True)) -def test_linearelasticity(E, nu, lambda_, mu, rho, dim, time): - """Test for linearelasticity equation.""" - # define input/output keys - input_keys = ("x", "y", "z")[:dim] - if time: - input_keys = ("t",) + input_keys - - disp_output_keys = ("u", "v") - if dim == 3: - disp_output_keys += ("w",) - disp_output_keys += ("p",) - - stress_output_keys = ("sigma_xx", "sigma_yy") - if dim == 3: - stress_output_keys += ("sigma_zz",) - stress_output_keys += ("sigma_xy",) - if dim == 3: - stress_output_keys += ("sigma_xz", "sigma_yz") - - # prepare input data in dict - batch_size = 13 - input_dict = {} - for var in input_keys: - input_dict[var] = paddle.randn([batch_size, 1]) - input_dict[var].stop_gradient = False - input_dict[f"normal_{var}"] = paddle.randn([batch_size, 1]) - input_dict[f"normal_{var}"].stop_gradient = False - - # prepare model - disp_net = ppsci.arch.MLP( - input_keys, disp_output_keys, 3, 16, "silu", weight_norm=True - ) - stress_net = ppsci.arch.MLP( - input_keys, - stress_output_keys, - 3, - 16, - "silu", - weight_norm=True, - ) - model_list = ppsci.arch.ModelList((disp_net, stress_net)) - - # prepare python function expressions and sympy-expression in dict - functional_expr_dict = equation.LinearElasticity( - E, nu, lambda_, mu, rho, dim, time - ).equations - sympy_expr_dict = LinearElasticity_sympy( - E, nu, lambda_, mu, rho, dim, time - ).equations - for target, expr in sympy_expr_dict.items(): - sympy_expr_dict[target] = sym_to_func.sympy_to_function( - expr, [disp_net, stress_net] - ) - - # compute equation with python function - output_dict_functional = model_list(input_dict) - for name, expr in functional_expr_dict.items(): - if callable(expr): - output_dict_functional[name] = expr( - {**output_dict_functional, **input_dict} - ) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") - clear() - - # compute equation with funciton converted from sympy - output_dict_sympy = {k: v for k, v in input_dict.items()} - for name, _ in sympy_expr_dict.items(): - output_dict_sympy[name] = sympy_expr_dict[name]( - {**output_dict_sympy, **input_dict} - ) - clear() - - # test for result - for key in functional_expr_dict: - assert paddle.allclose( - output_dict_functional[key], output_dict_sympy[key], atol=2e-7 - ) - - -if __name__ == "__main__": - pytest.main() diff --git a/test/utils/test_navier_stokes_sympy.py b/test/utils/test_navier_stokes_sympy.py deleted file mode 100644 index e10592c136..0000000000 --- a/test/utils/test_navier_stokes_sympy.py +++ /dev/null @@ -1,540 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import pytest -import sympy as sp - -import ppsci -from ppsci import equation -from ppsci.autodiff import clear -from ppsci.autodiff import hessian as H -from ppsci.autodiff import jacobian as J -from ppsci.utils import sym_to_func - - -class NavierStokes_sympy: - def __init__(self, nu, rho=1, dim=3, time=True): - # set params - self.dim = dim - self.time = time - - # coordinates - x, y, z = sp.Symbol("x"), sp.Symbol("y"), sp.Symbol("z") - - # time - t = sp.Symbol("t") - - # make input variables - input_variables = {"x": x, "y": y, "z": z, "t": t} - if self.dim == 2: - input_variables.pop("z") - if not self.time: - input_variables.pop("t") - - # velocity componets - u = sp.Function("u")(*input_variables) - v = sp.Function("v")(*input_variables) - if self.dim == 3: - w = sp.Function("w")(*input_variables) - else: - w = sp.Number(0) - - # pressure - p = sp.Function("p")(*input_variables) - - # kinematic viscosity - if isinstance(nu, str): - nu = sp.Function(nu)(*input_variables) - elif isinstance(nu, (float, int)): - nu = sp.Number(nu) - - # density - if isinstance(rho, str): - rho = sp.Function(rho)(*input_variables) - elif isinstance(rho, (float, int)): - rho = sp.Number(rho) - - # dynamic viscosity - mu = rho * nu - - # set equations - self.equations = {} - self.equations["continuity"] = ( - rho.diff(t) + (rho * u).diff(x) + (rho * v).diff(y) + (rho * w).diff(z) - ) - - curl = sp.Number(0) if rho.diff(x) == 0 else u.diff(x) + v.diff(y) + w.diff(z) - self.equations["momentum_x"] = ( - (rho * u).diff(t) - + ( - u * ((rho * u).diff(x)) - + v * ((rho * u).diff(y)) - + w * ((rho * u).diff(z)) - + rho * u * (curl) - ) - + p.diff(x) - - (-2 / 3 * mu * (curl)).diff(x) - - (mu * u.diff(x)).diff(x) - - (mu * u.diff(y)).diff(y) - - (mu * u.diff(z)).diff(z) - - (mu * (curl).diff(x)) - ) - self.equations["momentum_y"] = ( - (rho * v).diff(t) - + ( - u * ((rho * v).diff(x)) - + v * ((rho * v).diff(y)) - + w * ((rho * v).diff(z)) - + rho * v * (curl) - ) - + p.diff(y) - - (-2 / 3 * mu * (curl)).diff(y) - - (mu * v.diff(x)).diff(x) - - (mu * v.diff(y)).diff(y) - - (mu * v.diff(z)).diff(z) - - (mu * (curl).diff(y)) - ) - self.equations["momentum_z"] = ( - (rho * w).diff(t) - + ( - u * ((rho * w).diff(x)) - + v * ((rho * w).diff(y)) - + w * ((rho * w).diff(z)) - + rho * w * (curl) - ) - + p.diff(z) - - (-2 / 3 * mu * (curl)).diff(z) - - (mu * w.diff(x)).diff(x) - - (mu * w.diff(y)).diff(y) - - (mu * w.diff(z)).diff(z) - - (mu * (curl).diff(z)) - ) - - if self.dim == 2: - self.equations.pop("momentum_z") - - -class ZeroEquation_sympy: - def __init__( - self, nu, max_distance, rho=1, dim=3, time=True - ): # TODO add density into model - # set params - self.dim = dim - self.time = time - - # model coefficients - self.max_distance = max_distance - self.karman_constant = 0.419 - self.max_distance_ratio = 0.09 - - # coordinates - x, y, z = sp.Symbol("x"), sp.Symbol("y"), sp.Symbol("z") - - # time - t = sp.Symbol("t") - - # make input variables - input_variables = {"x": x, "y": y, "z": z, "t": t} - if self.dim == 2: - input_variables.pop("z") - if not self.time: - input_variables.pop("t") - - # velocity componets - u = sp.Function("u")(*input_variables) - v = sp.Function("v")(*input_variables) - if self.dim == 3: - w = sp.Function("w")(*input_variables) - else: - w = sp.Number(0) - - # density - if type(rho) is str: - rho = sp.Function(rho)(*input_variables) - elif type(rho) in [float, int]: - rho = sp.Number(rho) - - # wall distance - normal_distance = sp.Function("sdf")(*input_variables) - - # mixing length - mixing_length = sp.Min( - self.karman_constant * normal_distance, - self.max_distance_ratio * self.max_distance, - ) - G = ( - 2 * u.diff(x) ** 2 - + 2 * v.diff(y) ** 2 - + 2 * w.diff(z) ** 2 - + (u.diff(y) + v.diff(x)) ** 2 - + (u.diff(z) + w.diff(x)) ** 2 - + (v.diff(z) + w.diff(y)) ** 2 - ) - - # set equations - self.equations = {} - self.equations["nu"] = nu + rho * mixing_length**2 * sp.sqrt(G) - - -class Test_NavierStokes_sympy: - @pytest.mark.parametrize("nu", (2.0,)) - @pytest.mark.parametrize("rho", (1.0,)) - @pytest.mark.parametrize("dim", (2,)) - @pytest.mark.parametrize("time", (False, True)) - def test_nu_sympy(self, nu, rho, dim, time): - """Test for navier_stokes equation.""" - # define input/output keys - ze = ZeroEquation_sympy(nu=nu, rho=rho, dim=dim, max_distance=3.4, time=time) - nu_sympy = ze.equations["nu"] - - input_keys = ("x", "y", "z")[:dim] - if time: - input_keys = ("t",) + input_keys - - output_keys = ("u", "v") - if dim == 3: - output_keys += ("w",) - output_keys += ("p",) - - # prepare input data in dict - batch_size = 13 - input_dict = {} - for var in input_keys: - input_dict[var] = paddle.randn([batch_size, 1]) - input_dict[var].stop_gradient = False - if var != "t": - input_dict[f"sdf__{var}"] = paddle.randn([batch_size, 1]) - input_dict[f"normal__{var}"] = paddle.randn([batch_size, 1]) - - input_dict[f"sdf__{var}"].stop_gradient = False - input_dict[f"normal__{var}"].stop_gradient = False - - input_dict["sdf"] = paddle.randn([batch_size, 1]) - input_dict["sdf"].stop_gradient = False - - # prepare model - model = ppsci.arch.MLP(input_keys, output_keys, 2, 10) - - # prepare python function expressions and sympy-expression in dict - def nu_f(out): - karman_constant = 0.419 - max_distance_ratio = 0.09 - normal_distance = out["sdf"] - max_distance = ze.max_distance - mixing_length = paddle.minimum( - karman_constant * normal_distance, - max_distance_ratio * max_distance, - ) - x, y = out["x"], out["y"] - u, v = out["u"], out["v"] - G = 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 + (J(u, y) + J(v, x)) ** 2 - if dim == 3: - z, w = out["z"], out["w"] - G += ( - +2 * J(w, z) ** 2 - + (J(u, z) + J(w, x)) ** 2 - + (J(v, z) + J(w, y)) ** 2 - ) - return nu + rho * mixing_length**2 * paddle.sqrt(G) - - functional_expr_dict = equation.NavierStokes(nu_f, rho, dim, time).equations - - def continuity_f(out): - x, y = out["x"], out["y"] - u, v = out["u"], out["v"] - return 1.0 * J(u, x) + 1.0 * J(v, y) - - def momentum_x_f(out): - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - if time: - t = out["t"] - return ( - -( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(u, x) - - ( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(u, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(u, y) + 2 * J(J(v, x), y)) / 2 - + 2 * J(u, x) * J(J(u, x), y) - + 2 * J(v, y) * H(v, y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__y"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(u, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(v, x) + 2 * J(J(u, x), y)) / 2 - + 2 * J(u, x) * H(u, x) - + 2 * J(v, y) * J(J(v, x), y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__x"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(u, x) - + (1.0 * u * J(u, x) + 1.0 * v * J(u, y) + J(p, x)) - + (J(u, t) if time else 0) - ) - - def momentum_y_f(out): - x, y = out["x"], out["y"] - u, v, p = out["u"], out["v"], out["p"] - if time: - t = out["t"] - return ( - -( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(v, x) - - ( - 1.0 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - + 2.0 - ) - * H(v, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(u, y) + 2 * J(J(v, x), y)) / 2 - + 2 * J(u, x) * J(J(u, x), y) - + 2 * J(v, y) * H(v, y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__y"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(v, y) - - ( - 1.0 - * ( - (J(u, y) + J(v, x)) * (2 * H(v, x) + 2 * J(J(u, x), y)) / 2 - + 2 * J(u, x) * H(u, x) - + 2 * J(v, y) * J(J(v, x), y) - ) - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ** 2 - / paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - + 0.838 - * paddle.sqrt( - (J(u, y) + J(v, x)) ** 2 + 2 * J(u, x) ** 2 + 2 * J(v, y) ** 2 - ) - * paddle.heaviside(0.306 - 0.419 * out["sdf"], paddle.zeros([])) - * out["sdf__x"] - * paddle.minimum( - paddle.full_like(out["sdf"], 0.306), 0.419 * out["sdf"] - ) - ) - * J(v, x) - + (1.0 * u * J(v, x) + 1.0 * v * J(v, y) + J(p, y)) - + (J(v, t) if time else 0) - ) - - functional_expr_dict["continuity"] = continuity_f - functional_expr_dict["momentum_x"] = momentum_x_f - functional_expr_dict["momentum_y"] = momentum_y_f - - sympy_expr_dict = NavierStokes_sympy(nu_sympy, rho, dim, time).equations - for target, expr in sympy_expr_dict.items(): - sympy_expr_dict[target] = sym_to_func.sympy_to_function( - expr, - [ - model, - ], - ) - - # compute equation with python function - output_dict_functional = model(input_dict) - for name, expr in functional_expr_dict.items(): - if callable(expr): - output_dict_functional[name] = expr( - {**output_dict_functional, **input_dict} - ) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") - clear() - - # compute equation with funciton converted from sympy - output_dict_sympy = {k: v for k, v in input_dict.items()} - for name, expr in sympy_expr_dict.items(): - tmp = expr(output_dict_sympy) - output_dict_sympy[name] = tmp - clear() - - # test for result - for key in functional_expr_dict: - assert paddle.allclose( - output_dict_functional[key], output_dict_sympy[key], atol=1e-7 - ), f"{key} not equal." - - @pytest.mark.parametrize("nu", (2.0,)) - @pytest.mark.parametrize("rho", (1.0,)) - @pytest.mark.parametrize("dim", (2,)) - @pytest.mark.parametrize("time", (False, True)) - def test_nu_constant(self, nu, rho, dim, time): - """Test for navier_stokes equation.""" - # define input/output keys - nu_sympy = nu - - input_keys = ("x", "y", "z")[:dim] - if time: - input_keys = ("t",) + input_keys - - output_keys = ("u", "v") - if dim == 3: - output_keys += ("w",) - output_keys += ("p",) - - # prepare input data in dict - batch_size = 13 - input_dict = {} - for var in input_keys: - input_dict[var] = paddle.randn([batch_size, 1]) - input_dict[var].stop_gradient = False - if var != "t": - input_dict[f"sdf__{var}"] = paddle.randn([batch_size, 1]) - input_dict[f"normal__{var}"] = paddle.randn([batch_size, 1]) - - input_dict[f"sdf__{var}"].stop_gradient = False - input_dict[f"normal__{var}"].stop_gradient = False - - input_dict["sdf"] = paddle.randn([batch_size, 1]) - input_dict["sdf"].stop_gradient = False - - # prepare model - model = ppsci.arch.MLP(input_keys, output_keys, 2, 10) - - # prepare python function expressions and sympy-expression in dict - functional_expr_dict = equation.NavierStokes(nu, rho, dim, time).equations - - sympy_expr_dict = NavierStokes_sympy(nu_sympy, rho, dim, time).equations - for target, expr in sympy_expr_dict.items(): - sympy_expr_dict[target] = sym_to_func.sympy_to_function( - expr, - [ - model, - ], - ) - - # compute equation with python function - output_dict_functional = model(input_dict) - for name, expr in functional_expr_dict.items(): - if callable(expr): - output_dict_functional[name] = expr( - {**output_dict_functional, **input_dict} - ) - else: - raise TypeError(f"expr type({type(expr)}) is invalid") - clear() - - # compute equation with funciton converted from sympy - output_dict_sympy = {k: v for k, v in input_dict.items()} - tmp = {k: v for k, v in output_dict_sympy.items()} - for name, expr in sympy_expr_dict.items(): - output = expr(tmp) - output_dict_sympy[name] = output - clear() - - # test for result - for key in functional_expr_dict: - assert paddle.allclose( - output_dict_functional[key], output_dict_sympy[key], atol=1e-7 - ), f"{key} not equal." - - -if __name__ == "__main__": - pytest.main()