Skip to content

Commit 39d027e

Browse files
wanghan-iapcmHan WangiProzdpre-commit-ci[bot]
authored
compute output stat for atomic model (deepmodeling#3642)
This PR: - breaking change: the base atomic model is now a module. - reason: the out stat is a data attribute of the base atomic model. - implement the `compute_or_load_output_stat` for the base atomic model. the method computes both bias and std. - the derived atomic models call the `compute_or_load_output_stat` method for computing output stat. - atomic model provides the `apply_out_stat`, the derived class may override the method to define how the statistics is applied to an atomic model's output. @anyangml may need. - `out_stat` support statistics of output tensor of any shape. @iProzd please check if i took it correctly in [ce7ec1f](deepmodeling@ce7ec1f) To be done: - atomic statistics of the bias and std. @anyangml - erialization and deserialization. --------- Signed-off-by: Han Wang <[email protected]> Co-authored-by: Han Wang <[email protected]> Co-authored-by: Duo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent edb9da8 commit 39d027e

18 files changed

+823
-121
lines changed

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,19 @@
2727
class BaseAtomicModel(BaseAtomicModel_):
2828
def __init__(
2929
self,
30+
type_map: List[str],
3031
atom_exclude_types: List[int] = [],
3132
pair_exclude_types: List[Tuple[int, int]] = [],
3233
):
3334
super().__init__()
35+
self.type_map = type_map
3436
self.reinit_atom_exclude(atom_exclude_types)
3537
self.reinit_pair_exclude(pair_exclude_types)
3638

39+
def get_type_map(self) -> List[str]:
40+
"""Get the type map."""
41+
return self.type_map
42+
3743
def reinit_atom_exclude(
3844
self,
3945
exclude_types: List[int] = [],

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
self.descriptor = descriptor
5454
self.fitting = fitting
5555
self.type_map = type_map
56-
super().__init__(**kwargs)
56+
super().__init__(type_map, **kwargs)
5757

5858
def fitting_output_def(self) -> FittingOutputDef:
5959
"""Get the output def of the fitting net."""
@@ -67,10 +67,6 @@ def get_sel(self) -> List[int]:
6767
"""Get the neighbor selection."""
6868
return self.descriptor.get_sel()
6969

70-
def get_type_map(self) -> List[str]:
71-
"""Get the type map."""
72-
return self.type_map
73-
7470
def mixed_types(self) -> bool:
7571
"""If true, the model
7672
1. assumes total number of atoms aligned across frames;

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
6767
assert len(err_msg) == 0, "\n".join(err_msg)
6868
self.mixed_types_list = [model.mixed_types() for model in self.models]
69-
super().__init__(**kwargs)
69+
super().__init__(type_map, **kwargs)
7070

7171
def mixed_types(self) -> bool:
7272
"""If true, the model

deepmd/dpmodel/atomic_model/make_base_atomic_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def atomic_output_def(self) -> FittingOutputDef:
5151
"""
5252
return self.fitting_output_def()
5353

54-
def get_output_keys(self) -> List[str]:
55-
return list(self.atomic_output_def().keys())
56-
5754
@abstractmethod
5855
def get_rcut(self) -> float:
5956
"""Get the cut-off radius."""

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,20 @@ def __init__(
5959
rcut: float,
6060
sel: Union[int, List[int]],
6161
type_map: List[str],
62+
rcond: Optional[float] = None,
63+
atom_ener: Optional[List[float]] = None,
6264
**kwargs,
6365
):
64-
super().__init__()
66+
super().__init__(type_map, **kwargs)
6567
self.tab_file = tab_file
6668
self.rcut = rcut
6769
self.type_map = type_map
6870

6971
self.tab = PairTab(self.tab_file, rcut=rcut)
7072
self.type_map = type_map
7173
self.ntypes = len(type_map)
74+
self.rcond = rcond
75+
self.atom_ener = atom_ener
7276

7377
if self.tab_file is not None:
7478
self.tab_info, self.tab_data = self.tab.get()

deepmd/dpmodel/output_def.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ def __init__(
224224
if not self.r_differentiable:
225225
raise ValueError("only r_differentiable variable can calculate hessian")
226226

227+
@property
228+
def size(self):
229+
return self.output_size
230+
227231

228232
class FittingOutputDef:
229233
"""Defines the shapes and other properties of the fitting network outputs.

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 202 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from deepmd.pt.utils import (
2424
AtomExcludeMask,
2525
PairExcludeMask,
26+
env,
2627
)
2728
from deepmd.pt.utils.nlist import (
2829
extend_input_and_build_neighbor_list,
@@ -35,19 +36,88 @@
3536
)
3637

3738
log = logging.getLogger(__name__)
39+
dtype = env.GLOBAL_PT_FLOAT_PRECISION
40+
device = env.DEVICE
3841

3942
BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)
4043

4144

42-
class BaseAtomicModel(BaseAtomicModel_):
45+
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_):
46+
"""The base of atomic model.
47+
48+
Parameters
49+
----------
50+
type_map
51+
Mapping atom type to the name (str) of the type.
52+
For example `type_map[1]` gives the name of the type 1.
53+
atom_exclude_types
54+
Exclude the atomic contribution of the given types
55+
pair_exclude_types
56+
Exclude the pair of atoms of the given types from computing the output
57+
of the atomic model. Implemented by removing the pairs from the nlist.
58+
rcond : float, optional
59+
The condition number for the regression of atomic energy.
60+
preset_out_bias : Dict[str, List[Optional[torch.Tensor]]], optional
61+
Specifying atomic energy contribution in vacuum. Given by key:value pairs.
62+
The value is a list specifying the bias. the elements can be None or np.array of output shape.
63+
For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.]
64+
The `set_davg_zero` key in the descrptor should be set.
65+
66+
"""
67+
4368
def __init__(
4469
self,
70+
type_map: List[str],
4571
atom_exclude_types: List[int] = [],
4672
pair_exclude_types: List[Tuple[int, int]] = [],
73+
rcond: Optional[float] = None,
74+
preset_out_bias: Optional[Dict[str, torch.Tensor]] = None,
4775
):
48-
super().__init__()
76+
torch.nn.Module.__init__(self)
77+
BaseAtomicModel_.__init__(self)
78+
self.type_map = type_map
4979
self.reinit_atom_exclude(atom_exclude_types)
5080
self.reinit_pair_exclude(pair_exclude_types)
81+
self.rcond = rcond
82+
self.preset_out_bias = preset_out_bias
83+
84+
def init_out_stat(self):
85+
"""Initialize the output bias."""
86+
ntypes = self.get_ntypes()
87+
self.bias_keys: List[str] = list(self.fitting_output_def().keys())
88+
self.max_out_size = max(
89+
[self.atomic_output_def()[kk].size for kk in self.bias_keys]
90+
)
91+
self.n_out = len(self.bias_keys)
92+
out_bias_data = torch.zeros(
93+
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
94+
)
95+
out_std_data = torch.ones(
96+
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
97+
)
98+
self.register_buffer("out_bias", out_bias_data)
99+
self.register_buffer("out_std", out_std_data)
100+
101+
def __setitem__(self, key, value):
102+
if key in ["out_bias"]:
103+
self.out_bias = value
104+
elif key in ["out_std"]:
105+
self.out_std = value
106+
else:
107+
raise KeyError(key)
108+
109+
def __getitem__(self, key):
110+
if key in ["out_bias"]:
111+
return self.out_bias
112+
elif key in ["out_std"]:
113+
return self.out_std
114+
else:
115+
raise KeyError(key)
116+
117+
@torch.jit.export
118+
def get_type_map(self) -> List[str]:
119+
"""Get the type map."""
120+
return self.type_map
51121

52122
def reinit_atom_exclude(
53123
self,
@@ -165,6 +235,7 @@ def forward_common_atomic(
165235
fparam=fparam,
166236
aparam=aparam,
167237
)
238+
ret_dict = self.apply_out_stat(ret_dict, atype)
168239

169240
# nf x nloc
170241
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
@@ -210,9 +281,60 @@ def compute_or_load_stat(
210281
"""
211282
raise NotImplementedError
212283

284+
def compute_or_load_out_stat(
285+
self,
286+
merged: Union[Callable[[], List[dict]], List[dict]],
287+
stat_file_path: Optional[DPPath] = None,
288+
):
289+
"""
290+
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
291+
292+
Parameters
293+
----------
294+
merged : Union[Callable[[], List[dict]], List[dict]]
295+
- List[dict]: A list of data samples from various data systems.
296+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
297+
originating from the `i`-th data system.
298+
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
299+
only when needed. Since the sampling process can be slow and memory-intensive,
300+
the lazy function helps by only sampling once.
301+
stat_file_path : Optional[DPPath]
302+
The path to the stat file.
303+
304+
"""
305+
self.change_out_bias(
306+
merged,
307+
stat_file_path=stat_file_path,
308+
bias_adjust_mode="set-by-statistic",
309+
)
310+
311+
def apply_out_stat(
312+
self,
313+
ret: Dict[str, torch.Tensor],
314+
atype: torch.Tensor,
315+
):
316+
"""Apply the stat to each atomic output.
317+
The developer may override the method to define how the bias is applied
318+
to the atomic output of the model.
319+
320+
Parameters
321+
----------
322+
ret
323+
The returned dict by the forward_atomic method
324+
atype
325+
The atom types. nf x nloc
326+
327+
"""
328+
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
329+
for kk in self.bias_keys:
330+
# nf x nloc x odims, out_bias: ntypes x odims
331+
ret[kk] = ret[kk] + out_bias[kk][atype]
332+
return ret
333+
213334
def change_out_bias(
214335
self,
215336
sample_merged,
337+
stat_file_path: Optional[DPPath] = None,
216338
bias_adjust_mode="change-by-statistic",
217339
) -> None:
218340
"""Change the output bias according to the input data and the pretrained model.
@@ -231,22 +353,32 @@ def change_out_bias(
231353
'change-by-statistic' : perform predictions on labels of target dataset,
232354
and do least square on the errors to obtain the target shift as bias.
233355
'set-by-statistic' : directly use the statistic output bias in the target dataset.
356+
stat_file_path : Optional[DPPath]
357+
The path to the stat file.
234358
"""
235359
if bias_adjust_mode == "change-by-statistic":
236-
delta_bias = compute_output_stats(
360+
delta_bias, out_std = compute_output_stats(
237361
sample_merged,
238362
self.get_ntypes(),
239-
keys=self.get_output_keys(),
363+
keys=list(self.atomic_output_def().keys()),
364+
stat_file_path=stat_file_path,
240365
model_forward=self._get_forward_wrapper_func(),
241-
)["energy"]
242-
self.set_out_bias(delta_bias, add=True)
366+
rcond=self.rcond,
367+
preset_bias=self.preset_out_bias,
368+
)
369+
# self.set_out_bias(delta_bias, add=True)
370+
self._store_out_stat(delta_bias, out_std, add=True)
243371
elif bias_adjust_mode == "set-by-statistic":
244-
bias_atom = compute_output_stats(
372+
bias_out, std_out = compute_output_stats(
245373
sample_merged,
246374
self.get_ntypes(),
247-
keys=self.get_output_keys(),
248-
)["energy"]
249-
self.set_out_bias(bias_atom)
375+
keys=list(self.atomic_output_def().keys()),
376+
stat_file_path=stat_file_path,
377+
rcond=self.rcond,
378+
preset_bias=self.preset_out_bias,
379+
)
380+
# self.set_out_bias(bias_out)
381+
self._store_out_stat(bias_out, std_out)
250382
else:
251383
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
252384

@@ -279,3 +411,63 @@ def model_forward(coord, atype, box, fparam=None, aparam=None):
279411
return {kk: vv.detach() for kk, vv in atomic_ret.items()}
280412

281413
return model_forward
414+
415+
def _varsize(
416+
self,
417+
shape: List[int],
418+
) -> int:
419+
output_size = 1
420+
len_shape = len(shape)
421+
for i in range(len_shape):
422+
output_size *= shape[i]
423+
return output_size
424+
425+
def _get_bias_index(
426+
self,
427+
kk: str,
428+
) -> int:
429+
res: List[int] = []
430+
for i, e in enumerate(self.bias_keys):
431+
if e == kk:
432+
res.append(i)
433+
assert len(res) == 1
434+
return res[0]
435+
436+
def _store_out_stat(
437+
self,
438+
out_bias: Dict[str, torch.Tensor],
439+
out_std: Dict[str, torch.Tensor],
440+
add: bool = False,
441+
):
442+
ntypes = self.get_ntypes()
443+
out_bias_data = torch.clone(self.out_bias)
444+
out_std_data = torch.clone(self.out_std)
445+
for kk in out_bias.keys():
446+
assert kk in out_std.keys()
447+
idx = self._get_bias_index(kk)
448+
size = self._varsize(self.atomic_output_def()[kk].shape)
449+
if not add:
450+
out_bias_data[idx, :, :size] = out_bias[kk].view(ntypes, size)
451+
else:
452+
out_bias_data[idx, :, :size] += out_bias[kk].view(ntypes, size)
453+
out_std_data[idx, :, :size] = out_std[kk].view(ntypes, size)
454+
self.out_bias.copy_(out_bias_data)
455+
self.out_std.copy_(out_std_data)
456+
457+
def _fetch_out_stat(
458+
self,
459+
keys: List[str],
460+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
461+
ret_bias = {}
462+
ret_std = {}
463+
ntypes = self.get_ntypes()
464+
for kk in keys:
465+
idx = self._get_bias_index(kk)
466+
isize = self._varsize(self.atomic_output_def()[kk].shape)
467+
ret_bias[kk] = self.out_bias[idx, :, :isize].view(
468+
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
469+
)
470+
ret_std[kk] = self.out_std[idx, :, :isize].view(
471+
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
472+
)
473+
return ret_bias, ret_std

0 commit comments

Comments
 (0)