Skip to content

consistent energy model #3306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DescrptSeA,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
)
from deepmd.dpmodel.output_def import (
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .dipole_fitting import (
DipoleFitting,
)
from .ener_fitting import (
EnergyFittingNet,
)
from .invar_fitting import (
InvarFitting,
)
Expand All @@ -16,5 +19,6 @@
"InvarFitting",
"make_base_fitting",
"DipoleFitting",
"EnergyFittingNet",
"PolarFitting",
]
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
make_model,
)

DPModel = make_model(DPAtomicModel)

# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
pass
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
Expand Down Expand Up @@ -45,7 +48,7 @@ def make_model(T_AtomicModel):

"""

class CM(T_AtomicModel):
class CM(T_AtomicModel, NativeOP):
def __init__(
self,
*args,
Expand Down
41 changes: 41 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.dp_model import (
DPModel,
)


def get_model(data: dict) -> DPModel:
"""Get a DPModel from a dictionary.

Parameters
----------
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
fitting_type = data["fitting_net"].pop("type")
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return DPModel(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
)
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DescrptSeA,
)
from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
)
from deepmd.pt.utils.utils import (
Expand Down
29 changes: 25 additions & 4 deletions deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import (
Callable,
Optional,
overload,
)

import numpy as np
Expand Down Expand Up @@ -51,9 +52,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise RuntimeError(f"activation function {self.activation} not supported")


@overload
def to_numpy_array(xx: torch.Tensor) -> np.ndarray:
...

Check notice

Code scanning / CodeQL

Statement has no effect

This statement has no effect.


@overload
def to_numpy_array(xx: None) -> None:
...

Check notice

Code scanning / CodeQL

Statement has no effect

This statement has no effect.


def to_numpy_array(
xx: torch.Tensor,
) -> np.ndarray:
xx,
):
if xx is None:
return None
assert xx is not None
Expand All @@ -67,9 +78,19 @@ def to_numpy_array(
return xx.detach().cpu().numpy().astype(prec)


@overload
def to_torch_tensor(xx: np.ndarray) -> torch.Tensor:
...

Check notice

Code scanning / CodeQL

Statement has no effect

This statement has no effect.


@overload
def to_torch_tensor(xx: None) -> None:
...

Check notice

Code scanning / CodeQL

Statement has no effect

This statement has no effect.


def to_torch_tensor(
xx: np.ndarray,
) -> torch.Tensor:
xx,
):
if xx is None:
return None
assert xx is not None
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
The deserialized descriptor
"""
if cls is Descriptor:
return Descriptor.get_class_by_input(data).deserialize(data)
return Descriptor.get_class_by_input(data).deserialize(data, suffix=suffix)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

def serialize(self, suffix: str = "") -> dict:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
raise RuntimeError("unknown loss type")

@classmethod
def deserialize(cls, data: dict, suffix: str):
def deserialize(cls, data: dict, suffix: str = ""):
"""Deserialize the model.

Parameters
Expand All @@ -956,7 +956,7 @@ def deserialize(cls, data: dict, suffix: str):
fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
return fitting

def serialize(self, suffix: str) -> dict:
def serialize(self, suffix: str = "") -> dict:
"""Serialize the model.

Returns
Expand Down
72 changes: 64 additions & 8 deletions deepmd/tf/fit/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import (
Callable,
List,
Type,
)

from deepmd.dpmodel.utils.network import (
Expand Down Expand Up @@ -50,16 +51,33 @@ class SomeFitting(Fitting):
"""
return Fitting.__plugins.register(key)

@classmethod
def get_class_by_input(cls, data: dict) -> Type["Fitting"]:
"""Get the fitting class by the input data.

Parameters
----------
data : dict
The input data

Returns
-------
Fitting
The fitting class
"""
try:
fitting_type = data["type"]
except KeyError:
raise KeyError("the type of fitting should be set by `type`")
if fitting_type in Fitting.__plugins.plugins:
cls = Fitting.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown descriptor type: " + fitting_type)
return cls

def __new__(cls, *args, **kwargs):
if cls is Fitting:
try:
fitting_type = kwargs["type"]
except KeyError:
raise KeyError("the type of fitting should be set by `type`")
if fitting_type in Fitting.__plugins.plugins:
cls = Fitting.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown descriptor type: " + fitting_type)
cls = cls.get_class_by_input(kwargs)
return super().__new__(cls)

@property
Expand Down Expand Up @@ -110,6 +128,44 @@ def get_loss(self, loss: dict, lr) -> Loss:
the loss function
"""

@classmethod
def deserialize(cls, data: dict, suffix: str = "") -> "Fitting":
"""Deserialize the fitting.

There is no suffix in a native DP model, but it is important
for the TF backend.

Parameters
----------
data : dict
The serialized data
suffix : str, optional
Name suffix to identify this fitting

Returns
-------
Fitting
The deserialized fitting
"""
if cls is Fitting:
return Fitting.get_class_by_input(data).deserialize(data, suffix=suffix)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

def serialize(self, suffix: str = "") -> dict:
"""Serialize the fitting.

There is no suffix in a native DP model, but it is important
for the TF backend.

Returns
-------
dict
The serialized data
suffix : str, optional
Name suffix to identify this fitting
"""
raise NotImplementedError("Not implemented in class %s" % self.__name__)

def serialize_network(
self,
ntypes: int,
Expand Down
Loading