Skip to content

add universal Python inference interface DeepPot #3164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from deepmd.utils.sess import (
run_sess,
)
from deepmd_utils.infer.deep_pot import DeepPot as DeepPotBase

Check notice

Code scanning / CodeQL

Cyclic import

Import of module [deepmd_utils.infer.deep_pot](1) begins an import cycle.

if TYPE_CHECKING:
from pathlib import (
Expand All @@ -35,7 +36,7 @@
log = logging.getLogger(__name__)


class DeepPot(DeepEval):
class DeepPot(DeepEval, DeepPotBase):
"""Constructor.

Parameters
Expand Down
6 changes: 6 additions & 0 deletions deepmd_utils/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .deep_pot import (
DeepPot,
)

__all__ = ["DeepPot"]
33 changes: 33 additions & 0 deletions deepmd_utils/infer/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
Enum,
)


class DPBackend(Enum):
"""DeePMD-kit backend."""

TensorFlow = 1
PyTorch = 2
Paddle = 3
Unknown = 4


def detect_backend(filename: str) -> DPBackend:
"""Detect the backend of the given model file.

Parameters
----------
filename : str
The model file name
"""
if filename.endswith(".pb"):
return DPBackend.TensorFlow
elif filename.endswith(".pth") or filename.endswith(".pt"):
return DPBackend.PyTorch
elif filename.endswith(".pdmodel"):
return DPBackend.Paddle
return DPBackend.Unknown


__all__ = ["DPBackend", "detect_backend"]
88 changes: 88 additions & 0 deletions deepmd_utils/infer/deep_pot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
List,
Optional,
Tuple,
Union,
)

import numpy as np

from deepmd_utils.utils import (
AutoBatchSize,
)

from .backend import (
Backend,
detect_backend,
)


class DeepPot(ABC):
"""Potential energy model.

Parameters
----------
model_file : Path
The name of the frozen model file.
auto_batch_size : bool or int or AutoBatchSize, default: True
If True, automatic batch size will be used. If int, it will be used
as the initial batch size.
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
The ASE neighbor list class to produce the neighbor list. If None, the
neighbor list will be built natively in the model.
"""

@abstractmethod
def __init__(
self,
model_file,
*args,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list=None,
**kwargs,
) -> None:
...

def __new__(cls, model_file: str, *args, **kwargs):
if cls is DeepPot:
backend = detect_backend(model_file)
if backend == Backend.TensorFlow:
from deepmd.infer.deep_pot import DeepPot as DeepPotTF

Check notice

Code scanning / CodeQL

Cyclic import

Import of module [deepmd.infer.deep_pot](1) begins an import cycle.

return super().__new__(DeepPotTF)
elif backend == Backend.PyTorch:
from deepmd_pt.infer.deep_eval import DeepPot as DeepPotPT

return super().__new__(DeepPotPT)
else:
raise NotImplementedError("Unsupported backend: " + str(backend))
return super().__new__(cls)

@abstractmethod
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
) -> Tuple[np.ndarray, ...]:
"""Evaluate the model."""
# This method has been used by:
# documentation python.md
# dp model_devi: +fparam, +aparam, +mixed_type
# dp test: +atomic, +fparam, +aparam, +efield, +mixed_type
# finetune: +mixed_type
# dpdata
# ase


__all__ = ["DeepPot"]