
Feat: Refactor dipole fitting pytorch #3281

Merged: 19 commits merged on Feb 18, 2024
160 changes: 110 additions & 50 deletions deepmd/pt/model/task/dipole.py
@@ -1,67 +1,127 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
)

import torch

from deepmd.pt.model.network.network import (
ResidualDeep,
)
from deepmd.pt.model.task.fitting import (
Fitting,
GeneralFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)

log = logging.getLogger(__name__)


class DipoleFittingNet(Fitting):
def __init__(
self, ntypes, embedding_width, neuron, out_dim, resnet_dt=True, **kwargs
):
"""Construct a fitting net for dipole.
class DipoleFittingNet(GeneralFitting):
"""Construct a general fitting net.

Args:
- ntypes: Element count.
- embedding_width: Embedding width per atom.
- neuron: Number of neurons in each hidden layers of the fitting net.
- bias_atom_e: Average enery per atom for each element.
- resnet_dt: Using time-step in the ResNet construction.
"""
super().__init__()
self.ntypes = ntypes
self.embedding_width = embedding_width
self.out_dim = out_dim
Parameters
----------
var_name : str
The atomic property to fit, 'dipole'.
ntypes : int
Element count.
dim_descrpt : int
Embedding width per atom.
dim_out : int
The output dimension of the fitting net.
dim_rot_mat : int
The dimension of rotation matrix, m1.
neuron : List[int]
Number of neurons in each hidden layers of the fitting net.
resnet_dt : bool
Using time-step in the ResNet construction.
numb_fparam : int
Number of frame parameters.
numb_aparam : int
Number of atomic parameters.
activation_function : str
Activation function.
precision : str
Numerical precision.
distinguish_types : bool
Neighbor list that distinguish different atomic types or not.
rcond : float, optional
The condition number for the regression of atomic energy.
seed : int, optional
Random seed.
"""

filter_layers = []
one = ResidualDeep(
0, embedding_width, neuron, 0.0, out_dim=self.out_dim, resnet_dt=resnet_dt
def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
dim_out: int,
dim_rot_mat: int,
neuron: List[int] = [128, 128, 128],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
distinguish_types: bool = False,
rcond: Optional[float] = None,
seed: Optional[int] = None,
**kwargs,
):
self.dim_rot_mat = dim_rot_mat
super().__init__(
var_name=var_name,
ntypes=ntypes,
dim_descrpt=dim_descrpt,
dim_out=dim_out,
neuron=neuron,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
activation_function=activation_function,
precision=precision,
distinguish_types=distinguish_types,
rcond=rcond,
seed=seed,
**kwargs,
)
filter_layers.append(one)
self.filter_layers = torch.nn.ModuleList(filter_layers)
self.old_impl = False # this only supports the new implementation.

Check warning (Code scanning / CodeQL): Overwriting attribute in super-class or sub-class. Assignment overwrites attribute old_impl, which was previously defined in superclass GeneralFitting.
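
For context, the warning refers to the assignment of self.old_impl just above: GeneralFitting.__init__ already sets that attribute from kwargs, and the subclass then overwrites it unconditionally after the super().__init__() call. A minimal sketch of the flagged pattern, with hypothetical Base/Sub classes standing in for GeneralFitting/DipoleFittingNet:

```python
# Minimal sketch (hypothetical classes) of the pattern CodeQL flags here:
# the base class defines an attribute in __init__, and the subclass
# unconditionally overwrites it after calling super().__init__().
class Base:
    def __init__(self, **kwargs):
        # Base honors a keyword argument when setting the attribute.
        self.old_impl = kwargs.get("old_impl", False)


class Sub(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # This assignment silently discards whatever Base computed,
        # which is exactly the "overwriting attribute" warning.
        self.old_impl = False


assert Sub(old_impl=True).old_impl is False
```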

if "seed" in kwargs:
log.info("Set seed to %d in fitting net.", kwargs["seed"])
torch.manual_seed(kwargs["seed"])
def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.dim_rot_mat

def forward(self, inputs, atype, atype_tebd, rot_mat):
"""Based on embedding net output, alculate total energy.
def serialize(self) -> dict:
data = super().serialize()
data["dim_rot_mat"] = self.dim_rot_mat
data["old_impl"] = self.old_impl
return data

Args:
- inputs: Descriptor. Its shape is [nframes, nloc, self.embedding_width].
- atype: Atom type. Its shape is [nframes, nloc].
- atype_tebd: Atom type embedding. Its shape is [nframes, nloc, tebd_dim]
- rot_mat: GR during descriptor calculation. Its shape is [nframes * nloc, m1, 3].

Returns
-------
- vec_out: output vector. Its shape is [nframes, nloc, 3].
"""
nframes, nloc, _ = inputs.size()
if atype_tebd is not None:
inputs = torch.concat([inputs, atype_tebd], dim=-1)
vec_out = self.filter_layers[0](inputs) # Shape is [nframes, nloc, m1]
assert list(vec_out.size()) == [nframes, nloc, self.out_dim]
vec_out = vec_out.view(-1, 1, self.out_dim)
vec_out = (
torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3)
) # Shape is [nframes, nloc, 3]
return vec_out
def forward(
self,
descriptor: torch.Tensor,
atype: torch.Tensor,
gr: Optional[torch.Tensor] = None,
g2: Optional[torch.Tensor] = None,
h2: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.dim_rot_mat)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, -1, 3)
# (nframes, nloc, 3)
out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
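
As a reading aid for the refactored forward: the fitting net now emits m1 coefficients per atom, and the dipole is obtained by contracting them with the rotation matrix gr coming from the descriptor, so each atom ends up with a 3-vector. A minimal, self-contained sketch of just that shape logic, using random tensors and hypothetical sizes:

```python
import torch

# Minimal sketch of the contraction in DipoleFittingNet.forward,
# using random tensors and hypothetical sizes (nframes=2, nloc=5, m1=8).
nframes, nloc, m1 = 2, 5, 8

# Per-atom output of the fitting net: one coefficient per row of gr.
out = torch.randn(nframes, nloc, m1)
# Rotation matrix from the descriptor, reshaped to (nframes * nloc, m1, 3).
gr = torch.randn(nframes * nloc, m1, 3)

# (nframes * nloc, 1, m1) x (nframes * nloc, m1, 3) -> (nframes * nloc, 1, 3)
vec = torch.bmm(out.view(-1, 1, m1), gr)
# Drop the singleton axis and restore the frame/atom layout.
vec = vec.squeeze(-2).view(nframes, nloc, 3)
assert vec.shape == (nframes, nloc, 3)
```
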
10 changes: 5 additions & 5 deletions deepmd/pt/model/task/fitting.py
@@ -414,11 +414,12 @@ def __init__(
self.prec = PRECISION_DICT[self.precision]
self.rcond = rcond

net_dim_out = self._net_out_dim()

Check warning (Code scanning / CodeQL): __init__ method calls overridden method. Call to self._net_out_dim in the __init__ method, which is overridden by DipoleFittingNet._net_out_dim and by InvarFitting._net_out_dim.
# init constants
if bias_atom_e is None:
bias_atom_e = np.zeros([self.ntypes, self.dim_out])
bias_atom_e = np.zeros([self.ntypes, net_dim_out])
bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device)
bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out])
bias_atom_e = bias_atom_e.view([self.ntypes, net_dim_out])
if not self.use_tebd:
assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!"
self.register_buffer("bias_atom_e", bias_atom_e)
@@ -449,7 +450,6 @@ def __init__(
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam

self.old_impl = kwargs.get("old_impl", False)
net_dim_out = self._net_out_dim()
if self.old_impl:
filter_layers = []
for type_i in range(self.ntypes):
@@ -591,6 +591,7 @@ def _forward_common(
):
xx = descriptor
nf, nloc, nd = xx.shape
net_dim_out = self._net_out_dim()

if nd != self.dim_descrpt:
raise ValueError(
@@ -638,7 +639,7 @@
)

outs = torch.zeros(
(nf, nloc, self.dim_out),
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
) # jit assertion
@@ -665,7 +666,6 @@
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
net_dim_out = self._net_out_dim()
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
mask = torch.tile(mask, (1, 1, net_dim_out))
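
Regarding the CodeQL warning above about __init__ calling an overridden method: the call is safe here only because DipoleFittingNet assigns self.dim_rot_mat before invoking super().__init__(), so _net_out_dim can already resolve while the base constructor builds bias_atom_e and the filter networks. A minimal sketch of why that ordering matters, with hypothetical Base/Sub classes:

```python
# Minimal sketch (hypothetical classes) of why the CodeQL warning matters:
# the base __init__ calls a method the subclass overrides, so any state
# the override relies on must exist *before* super().__init__() runs.
class Base:
    def __init__(self):
        self.net_dim_out = self._net_out_dim()

    def _net_out_dim(self) -> int:
        return 1


class BrokenSub(Base):
    def __init__(self, dim_rot_mat: int):
        super().__init__()             # _net_out_dim runs before dim_rot_mat exists
        self.dim_rot_mat = dim_rot_mat

    def _net_out_dim(self) -> int:
        return self.dim_rot_mat        # would raise AttributeError inside Base.__init__


class WorkingSub(Base):
    def __init__(self, dim_rot_mat: int):
        self.dim_rot_mat = dim_rot_mat  # set first, as DipoleFittingNet does
        super().__init__()

    def _net_out_dim(self) -> int:
        return self.dim_rot_mat


assert WorkingSub(8).net_dim_out == 8
```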