Skip to content

feat(jax/array-api): se_e2_r #4257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand All @@ -25,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -144,31 +146,33 @@ def __init__(
self.env_protection = env_protection

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
if not self.type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
for ii in range(self.ntypes):
self.embeddings[(ii,)] = EmbeddingNet(
embeddings[(ii,)] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = np.sum(self.sel).item()
self.davg = np.zeros(
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
)
self.dstd = np.ones(
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
)
self.orig_sel = self.sel
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

def __setitem__(self, key, value):
if key in ("avg", "data_avg", "davg"):
Expand Down Expand Up @@ -279,8 +283,9 @@ def cal_g(
ss,
ll,
):
xp = array_api_compat.array_namespace(ss)
nf, nloc, nnei = ss.shape[0:3]
ss = ss.reshape(nf, nloc, nnei, 1)
ss = xp.reshape(ss, (nf, nloc, nnei, 1))
# nf x nloc x nnei x ng
gg = self.embeddings[(ll,)].call(ss)
return gg
Expand Down Expand Up @@ -321,29 +326,34 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(coord_ext)
del mapping
# nf x nloc x nnei x 1
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd, True
)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))
sec = self.sel_cumsum

ng = self.neuron[-1]
xyz_scatter = np.zeros([nf, nloc, ng], dtype=PRECISION_DICT[self.precision])
xyz_scatter = xp.zeros(
[nf, nloc, ng], dtype=get_xp_precision(xp, self.precision)
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
rr = xp.astype(rr, xyz_scatter.dtype)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
tr = tr * xp.astype(mm[:, :, :, None], tr.dtype)
gg = self.cal_g(tr, tt)
gg = np.mean(gg, axis=2)
gg = xp.mean(gg, axis=2)
# nf x nloc x ng x 1
xyz_scatter += gg * (self.sel[tt] / self.nnei)

res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand All @@ -369,8 +379,8 @@ def serialize(self) -> dict:
"env_mat": self.env_mat.serialize(),
"embeddings": self.embeddings.serialize(),
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
}
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.jax.descriptor.se_e2_r import (
DescrptSeR,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
]
41 changes: 41 additions & 0 deletions deepmd/jax/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
@flax_module
class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
55 changes: 54 additions & 1 deletion source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -33,14 +35,25 @@
descrpt_se_r_args,
)

if INSTALLED_JAX:
from deepmd.jax.descriptor.se_e2_r import DescrptSeR as DescrptSeRJAX
else:
DescrptSeRJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_e2_r import (
DescrptSeR as DescrptSeRArrayAPIStrict,
)
else:
DescrptSeRArrayAPIStrict = None


@parameterized(
(True, False), # resnet_dt
(True, False), # type_one_side
([], [[0, 1]]), # excluded_types
("float32", "float64"), # precision
)
class TestSeA(CommonTest, DescriptorTest, unittest.TestCase):
class TestSeR(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
(
Expand Down Expand Up @@ -81,9 +94,31 @@ def skip_dp(self) -> bool:
) = self.param
return not type_one_side or CommonTest.skip_dp

@property
def skip_jax(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeRTF
dp_class = DescrptSeRDP
pt_class = DescrptSeRPT
jax_class = DescrptSeRJAX
array_api_strict_class = DescrptSeRArrayAPIStrict
args = descrpt_se_r_args()

def setUp(self):
Expand Down Expand Up @@ -148,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down