From bd27ad1f2ef978fa99ffb1640ff48b710f3cb3bb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 27 Jun 2024 21:36:23 -0400 Subject: [PATCH 1/3] feat: support array API Fix #3430. This PR sets up the basic support for the array API, and make an example function (`compute_smooth_weight`) to support the array API. I believe NumPy and JAX have supported it (or through `array-api-compat`), so we don't need to write things twice for NumPy and JAX (although we can write them using the ChatGPT, it's still better to maintain only one thing). There are some challeging to use it in the TorchScript, so I give it up. Supporting more function can be implemented in the following PRs. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 29 +++++++++++++++++++ deepmd/dpmodel/utils/env_mat.py | 12 ++++++-- doc/backend.md | 2 ++ pyproject.toml | 2 ++ .../common/dpmodel/array_api/__init__.py | 2 ++ .../common/dpmodel/array_api/test_env_mat.py | 26 +++++++++++++++++ .../tests/common/dpmodel/array_api/utils.py | 27 +++++++++++++++++ 7 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 deepmd/dpmodel/array_api.py create mode 100644 source/tests/common/dpmodel/array_api/__init__.py create mode 100644 source/tests/common/dpmodel/array_api/test_env_mat.py create mode 100644 source/tests/common/dpmodel/array_api/utils.py diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py new file mode 100644 index 0000000000..e4af2ad627 --- /dev/null +++ b/deepmd/dpmodel/array_api.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Utilities for the array API.""" + + +def support_array_api(version: str) -> callable: + """Mark a function as supporting the specific version of the array API. + + Parameters + ---------- + version : str + The version of the array API + + Returns + ------- + callable + The decorated function + + Examples + -------- + >>> @support_array_api(version="2022.12") + ... def f(x): + ... pass + """ + + def set_version(func: callable) -> callable: + func.array_api_version = version + return func + + return set_version diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 94cf3a7c21..41f2591279 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -4,13 +4,18 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.array_api import ( + support_array_api, +) +@support_array_api(version="2022.12") def compute_smooth_weight( distance: np.ndarray, rmin: float, @@ -19,12 +24,15 @@ def compute_smooth_weight( """Compute smooth weight for descriptor elements.""" if rmin >= rmax: raise ValueError("rmin should be less than rmax.") + xp = array_api_compat.array_namespace(distance) min_mask = distance <= rmin max_mask = distance >= rmax - mid_mask = np.logical_not(np.logical_or(min_mask, max_mask)) + mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask)) uu = (distance - rmin) / (rmax - rmin) vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0 - return vv * mid_mask + min_mask + return vv * xp.astype(mid_mask, distance.dtype) + xp.astype( + min_mask, distance.dtype + ) def _make_env_mat( diff --git a/doc/backend.md b/doc/backend.md index 2f0bc7ed20..e164cd8405 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -37,6 +37,8 @@ As a reference backend, it is not aimed at the best performance, but only the co The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent. Only Python inference interface can load this format. +NumPy 1.21 or above is required. + ## Switch the backend ### Training diff --git a/pyproject.toml b/pyproject.toml index d9cbeb44e4..1e68d1287a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ 'packaging', 'ml_dtypes', 'mendeleev', + 'array-api-compat', ] requires-python = ">=3.8" keywords = ["deepmd"] @@ -79,6 +80,7 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", + "array-api-strict>=2", ] docs = [ "sphinx>=3.1.1", diff --git a/source/tests/common/dpmodel/array_api/__init__.py b/source/tests/common/dpmodel/array_api/__init__.py new file mode 100644 index 0000000000..e02301188e --- /dev/null +++ b/source/tests/common/dpmodel/array_api/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test array API compatibility to be completely sure their usage of the array API is portable.""" diff --git a/source/tests/common/dpmodel/array_api/test_env_mat.py b/source/tests/common/dpmodel/array_api/test_env_mat.py new file mode 100644 index 0000000000..8dfa199d53 --- /dev/null +++ b/source/tests/common/dpmodel/array_api/test_env_mat.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import array_api_strict as xp + +from deepmd.dpmodel.utils.env_mat import ( + compute_smooth_weight, +) + +from .utils import ( + ArrayAPITest, +) + + +class TestEnvMat(unittest.TestCase, ArrayAPITest): + def test_compute_smooth_weight(self): + self.set_array_api_version(compute_smooth_weight) + d = xp.arange(10, dtype=xp.float64) + w = compute_smooth_weight( + d, + 4.0, + 6.0, + ) + self.assert_namespace_equal(w, d) + self.assert_device_equal(w, d) + self.assert_dtype_equal(w, d) diff --git a/source/tests/common/dpmodel/array_api/utils.py b/source/tests/common/dpmodel/array_api/utils.py new file mode 100644 index 0000000000..7e422c2ead --- /dev/null +++ b/source/tests/common/dpmodel/array_api/utils.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat +from array_api_strict import ( + set_array_api_strict_flags, +) + + +class ArrayAPITest: + """Utils for array API tests.""" + + def set_array_api_version(self, func): + """Set the array API version for a function.""" + set_array_api_strict_flags(api_version=func.array_api_version) + + def assert_namespace_equal(self, a, b): + """Assert two array has the same namespace.""" + self.assertEqual( + array_api_compat.array_namespace(a), array_api_compat.array_namespace(b) + ) + + def assert_dtype_equal(self, a, b): + """Assert two array has the same dtype.""" + self.assertEqual(a.dtype, b.dtype) + + def assert_device_equal(self, a, b): + """Assert two array has the same device.""" + self.assertEqual(array_api_compat.device(a), array_api_compat.device(b)) From 40ad1457a2fbd93a7dfe078a6c5b3afae1d83f3a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 27 Jun 2024 21:41:28 -0400 Subject: [PATCH 2/3] skip tests in py38 Signed-off-by: Jinzhe Zeng --- pyproject.toml | 2 +- source/tests/common/dpmodel/array_api/test_env_mat.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e68d1287a..6f79c4fcbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", - "array-api-strict>=2", + "array-api-strict>=2;python_version>=3.9", ] docs = [ "sphinx>=3.1.1", diff --git a/source/tests/common/dpmodel/array_api/test_env_mat.py b/source/tests/common/dpmodel/array_api/test_env_mat.py index 8dfa199d53..d5bc7b6c18 100644 --- a/source/tests/common/dpmodel/array_api/test_env_mat.py +++ b/source/tests/common/dpmodel/array_api/test_env_mat.py @@ -1,7 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import sys import unittest -import array_api_strict as xp +if sys.version_info >= (3, 9): + import array_api_strict as xp +else: + raise unittest.SkipTest("array_api_strict doesn't support Python<=3.8") from deepmd.dpmodel.utils.env_mat import ( compute_smooth_weight, From c9616a00e089f2f7ced4c3d51b32776420ea328f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 27 Jun 2024 22:12:54 -0400 Subject: [PATCH 3/3] fix require --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6f79c4fcbf..861fea6399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", - "array-api-strict>=2;python_version>=3.9", + 'array-api-strict>=2;python_version>="3.9"', ] docs = [ "sphinx>=3.1.1",