Skip to content

add category property to OutputVariableDef #3228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 6, 2024
98 changes: 96 additions & 2 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from enum import (
IntEnum,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -107,6 +110,38 @@ def __call__(
return wrapper


class OutputVariableOperation(IntEnum):
"""Defines the operation of the output variable."""

NONE = 0
"""No operation."""
REDU = 1
"""Reduce the output variable."""
DERV_R = 2
"""Derivative w.r.t. coordinates."""
DERV_C = 4
"""Derivative w.r.t. cell."""
SEC_DERV_R = 8
"""Second derivative w.r.t. coordinates."""


class OutputVariableCategory(IntEnum):
"""Defines the category of the output variable."""

OUT = OutputVariableOperation.NONE
"""Output variable. (e.g. atom energy)"""
REDU = OutputVariableOperation.REDU
"""Reduced output variable. (e.g. system energy)"""
DERV_R = OutputVariableOperation.DERV_R
"""Negative derivative w.r.t. coordinates. (e.g. force)"""
DERV_C = OutputVariableOperation.DERV_C
"""Atomic component of the virial, see PRB 104, 224202 (2021) """
DERV_C_REDU = OutputVariableOperation.DERV_C | OutputVariableOperation.REDU
"""Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023). """
DERV_R_DERV_R = OutputVariableOperation.DERV_R | OutputVariableOperation.SEC_DERV_R
"""Hession matrix, the second derivative w.r.t. coordinates."""


class OutputVariableDef:
"""Defines the shape and other properties of the one output variable.

Expand All @@ -129,7 +164,8 @@ class OutputVariableDef:
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
are differentiable.

category : int
The category of the output variable.
"""

def __init__(
Expand All @@ -139,6 +175,7 @@ def __init__(
reduciable: bool = False,
differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
):
self.name = name
self.shape = list(shape)
Expand All @@ -149,6 +186,7 @@ def __init__(
raise ValueError("only reduciable variable are differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")
self.category = category


class FittingOutputDef:
Expand Down Expand Up @@ -255,6 +293,55 @@ def get_deriv_name(name: str) -> Tuple[str, str]:
return name + "_derv_r", name + "_derv_c"


def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
"""Apply a operation to the category of a variable definition.

Parameters
----------
var_def : OutputVariableDef
The variable definition.
op : OutputVariableOperation
The operation to be applied.

Returns
-------
int
The new category of the variable definition.
"""
return var_def.category | op.value


def check_operation_applied(
var_def: OutputVariableDef, op: OutputVariableOperation
) -> bool:
"""Check if a operation has been applied to a variable definition.

Parameters
----------
var_def : OutputVariableDef
The variable definition.
op : OutputVariableOperation
The operation to be checked.

Returns
-------
bool
True if the operation has been applied, False otherwise.
"""
if op in (OutputVariableOperation.DERV_REDU, OutputVariableOperation.DERV_C):
assert not check_operation_applied(var_def, op)
elif op == OutputVariableOperation.DERV_R:
if check_operation_applied(var_def, OutputVariableOperation.DERV_R):
op = OutputVariableOperation.SEC_DERV_R
else:
assert not check_operation_applied(
var_def, OutputVariableOperation.SEC_DERV_R
)
else:
raise ValueError(f"operation {op} not supported")
return var_def.category & op.value == op.value


def do_reduce(
def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
Expand All @@ -263,7 +350,12 @@ def do_reduce(
if vv.reduciable:
rk = get_reduce_name(kk)
def_redu[rk] = OutputVariableDef(
rk, vv.shape, reduciable=False, differentiable=False, atomic=False
rk,
vv.shape,
reduciable=False,
differentiable=False,
atomic=False,
category=apply_operation(vv, OutputVariableOperation.REDU),
)
return def_redu

Expand All @@ -282,12 +374,14 @@ def do_derivative(
reduciable=False,
differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_R),
)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
reduciable=True,
differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_C),
)
return def_derv_r, def_derv_c
55 changes: 55 additions & 0 deletions source/tests/common/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
model_check_output,
)
from deepmd.dpmodel.output_def import (
OutputVariableCategory,
OutputVariableOperation,
check_var,
)

Expand Down Expand Up @@ -103,6 +105,59 @@ def test_model_output_def(self):
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, True)
self.assertEqual(md["energy_derv_c_redu"].atomic, False)
# category
self.assertEqual(md["energy"].category, OutputVariableCategory.OUT)
self.assertEqual(md["dos"].category, OutputVariableCategory.OUT)
self.assertEqual(md["foo"].category, OutputVariableCategory.OUT)
self.assertEqual(md["energy_redu"].category, OutputVariableCategory.REDU)
self.assertEqual(md["energy_derv_r"].category, OutputVariableCategory.DERV_R)
self.assertEqual(md["energy_derv_c"].category, OutputVariableCategory.DERV_C)
self.assertEqual(
md["energy_derv_c_redu"].category, OutputVariableCategory.DERV_C_REDU
)
# flag
self.assertEqual(md["energy"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(
md["energy_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["energy_derv_r"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_R,
OutputVariableOperation.DERV_R,
)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_C, 0
)
self.assertEqual(md["energy_derv_c"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
Expand Down