23
23
from deepmd .pt .utils import (
24
24
AtomExcludeMask ,
25
25
PairExcludeMask ,
26
+ env ,
26
27
)
27
28
from deepmd .pt .utils .nlist import (
28
29
extend_input_and_build_neighbor_list ,
35
36
)
36
37
37
38
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Default floating-point precision and device used for tensors created in this module.
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

# Backend-specific base class generated for torch.Tensor.
BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)
40
43
41
44
42
- class BaseAtomicModel (BaseAtomicModel_ ):
45
+ class BaseAtomicModel (torch .nn .Module , BaseAtomicModel_ ):
46
+ """The base of atomic model.
47
+
48
+ Parameters
49
+ ----------
50
+ type_map
51
+ Mapping atom type to the name (str) of the type.
52
+ For example `type_map[1]` gives the name of the type 1.
53
+ atom_exclude_types
54
+ Exclude the atomic contribution of the given types
55
+ pair_exclude_types
56
+ Exclude the pair of atoms of the given types from computing the output
57
+ of the atomic model. Implemented by removing the pairs from the nlist.
58
+ rcond : float, optional
59
+ The condition number for the regression of atomic energy.
60
+ preset_out_bias : Dict[str, List[Optional[torch.Tensor]]], optional
61
+ Specifying atomic energy contribution in vacuum. Given by key:value pairs.
62
+ The value is a list specifying the bias. The elements can be None or np.array of output shape.
63
+ For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.]
64
+ The `set_davg_zero` key in the descriptor should be set.
65
+
66
+ """
67
+
43
68
def __init__(
    self,
    type_map: List[str],
    atom_exclude_types: Optional[List[int]] = None,
    pair_exclude_types: Optional[List[Tuple[int, int]]] = None,
    rcond: Optional[float] = None,
    preset_out_bias: Optional[Dict[str, torch.Tensor]] = None,
):
    """Set up the type map, exclusion masks and output-statistics options.

    Parameters
    ----------
    type_map
        Mapping atom type to the name (str) of the type.
    atom_exclude_types
        Atom types whose atomic contribution is excluded.
        ``None`` (the default) is treated as an empty list.
    pair_exclude_types
        Pairs of atom types excluded from the computation.
        ``None`` (the default) is treated as an empty list.
    rcond
        Stored as-is and later forwarded to the output-statistics computation.
    preset_out_bias
        Stored as-is and later forwarded to the output-statistics computation
        as the preset bias.
        NOTE(review): the class docstring describes the values as lists of
        optional arrays while the annotation says tensors -- confirm which is
        intended.
    """
    # Both bases are initialised explicitly instead of via super().
    torch.nn.Module.__init__(self)
    BaseAtomicModel_.__init__(self)
    self.type_map = type_map
    # A fresh list is created per call so that instances never share one
    # mutable default object; explicitly passed lists are forwarded untouched.
    self.reinit_atom_exclude(
        [] if atom_exclude_types is None else atom_exclude_types
    )
    self.reinit_pair_exclude(
        [] if pair_exclude_types is None else pair_exclude_types
    )
    self.rcond = rcond
    self.preset_out_bias = preset_out_bias
83
+
84
def init_out_stat(self):
    """Create the ``out_bias`` and ``out_std`` buffers.

    Both buffers have shape ``[n_out, ntypes, max_out_size]`` where ``n_out``
    is the number of keys of the fitting output definition and
    ``max_out_size`` is the largest ``size`` among those outputs.
    ``out_bias`` starts as zeros and ``out_std`` as ones.
    """
    n_types = self.get_ntypes()
    self.bias_keys: List[str] = list(self.fitting_output_def().keys())
    # Every output shares one buffer, so the last axis must fit the largest one.
    self.max_out_size = max(
        self.atomic_output_def()[key].size for key in self.bias_keys
    )
    self.n_out = len(self.bias_keys)
    stat_shape = [self.n_out, n_types, self.max_out_size]
    zero_bias = torch.zeros(stat_shape, dtype=dtype, device=device)
    unit_std = torch.ones(stat_shape, dtype=dtype, device=device)
    self.register_buffer("out_bias", zero_bias)
    self.register_buffer("out_std", unit_std)
100
+
101
def __setitem__(self, key, value):
    """Assign ``value`` to the ``out_bias`` or ``out_std`` attribute.

    Raises
    ------
    KeyError
        If ``key`` is neither ``"out_bias"`` nor ``"out_std"``.
    """
    if key not in ("out_bias", "out_std"):
        raise KeyError(key)
    # The key doubles as the attribute name.
    setattr(self, key, value)
108
+
109
def __getitem__(self, key):
    """Return the ``out_bias`` or ``out_std`` attribute selected by ``key``.

    Raises
    ------
    KeyError
        If ``key`` is neither ``"out_bias"`` nor ``"out_std"``.
    """
    if key not in ("out_bias", "out_std"):
        raise KeyError(key)
    # The key doubles as the attribute name.
    return getattr(self, key)
116
+
117
@torch.jit.export
def get_type_map(self) -> List[str]:
    """Return the stored ``type_map`` (list of type names indexed by atom type)."""
    return self.type_map
51
121
52
122
def reinit_atom_exclude (
53
123
self ,
@@ -165,6 +235,7 @@ def forward_common_atomic(
165
235
fparam = fparam ,
166
236
aparam = aparam ,
167
237
)
238
+ ret_dict = self .apply_out_stat (ret_dict , atype )
168
239
169
240
# nf x nloc
170
241
atom_mask = ext_atom_mask [:, :nloc ].to (torch .int32 )
@@ -210,9 +281,60 @@ def compute_or_load_stat(
210
281
"""
211
282
raise NotImplementedError
212
283
284
def compute_or_load_out_stat(
    self,
    merged: Union[Callable[[], List[dict]], List[dict]],
    stat_file_path: Optional[DPPath] = None,
):
    """Set the output statistics (e.g. energy bias) from packed data.

    This is a thin wrapper that calls ``change_out_bias`` in
    ``"set-by-statistic"`` mode.

    Parameters
    ----------
    merged : Union[Callable[[], List[dict]], List[dict]]
        - List[dict]: data samples from various data systems; ``merged[i]``
          is a dictionary of ``torch.Tensor`` values from the ``i``-th system.
        - Callable[[], List[dict]]: a lazy function returning samples in the
          format above only when needed, so that the slow and
          memory-intensive sampling happens once.
    stat_file_path : Optional[DPPath]
        The path to the stat file.

    """
    self.change_out_bias(
        merged,
        bias_adjust_mode="set-by-statistic",
        stat_file_path=stat_file_path,
    )
310
+
311
def apply_out_stat(
    self,
    ret: Dict[str, torch.Tensor],
    atype: torch.Tensor,
):
    """Add the stored per-type bias to every atomic output.

    Subclasses may override this method to change how the statistics are
    applied to the atomic output of the model.

    Parameters
    ----------
    ret
        The dict returned by the forward_atomic method. It is updated in
        place and also returned.
    atype
        The atom types. nf x nloc

    """
    # The stored std is fetched together with the bias but is not used here.
    bias_by_key, _ = self._fetch_out_stat(self.bias_keys)
    for key in self.bias_keys:
        # ret[key]: nf x nloc x odims; bias_by_key[key]: ntypes x odims.
        # Indexing with atype selects the bias of each atom's type.
        per_atom_bias = bias_by_key[key][atype]
        ret[key] = ret[key] + per_atom_bias
    return ret
333
+
213
334
def change_out_bias (
214
335
self ,
215
336
sample_merged ,
337
+ stat_file_path : Optional [DPPath ] = None ,
216
338
bias_adjust_mode = "change-by-statistic" ,
217
339
) -> None :
218
340
"""Change the output bias according to the input data and the pretrained model.
@@ -231,22 +353,32 @@ def change_out_bias(
231
353
'change-by-statistic' : perform predictions on labels of target dataset,
232
354
and do least square on the errors to obtain the target shift as bias.
233
355
'set-by-statistic' : directly use the statistic output bias in the target dataset.
356
+ stat_file_path : Optional[DPPath]
357
+ The path to the stat file.
234
358
"""
235
359
if bias_adjust_mode == "change-by-statistic" :
236
- delta_bias = compute_output_stats (
360
+ delta_bias , out_std = compute_output_stats (
237
361
sample_merged ,
238
362
self .get_ntypes (),
239
- keys = self .get_output_keys (),
363
+ keys = list (self .atomic_output_def ().keys ()),
364
+ stat_file_path = stat_file_path ,
240
365
model_forward = self ._get_forward_wrapper_func (),
241
- )["energy" ]
242
- self .set_out_bias (delta_bias , add = True )
366
+ rcond = self .rcond ,
367
+ preset_bias = self .preset_out_bias ,
368
+ )
369
+ # self.set_out_bias(delta_bias, add=True)
370
+ self ._store_out_stat (delta_bias , out_std , add = True )
243
371
elif bias_adjust_mode == "set-by-statistic" :
244
- bias_atom = compute_output_stats (
372
+ bias_out , std_out = compute_output_stats (
245
373
sample_merged ,
246
374
self .get_ntypes (),
247
- keys = self .get_output_keys (),
248
- )["energy" ]
249
- self .set_out_bias (bias_atom )
375
+ keys = list (self .atomic_output_def ().keys ()),
376
+ stat_file_path = stat_file_path ,
377
+ rcond = self .rcond ,
378
+ preset_bias = self .preset_out_bias ,
379
+ )
380
+ # self.set_out_bias(bias_out)
381
+ self ._store_out_stat (bias_out , std_out )
250
382
else :
251
383
raise RuntimeError ("Unknown bias_adjust_mode mode: " + bias_adjust_mode )
252
384
@@ -279,3 +411,63 @@ def model_forward(coord, atype, box, fparam=None, aparam=None):
279
411
return {kk : vv .detach () for kk , vv in atomic_ret .items ()}
280
412
281
413
return model_forward
414
+
415
def _varsize(
    self,
    shape: List[int],
) -> int:
    """Return the product of the entries of ``shape`` (1 for an empty shape)."""
    # NOTE(review): presumably kept as an explicit loop so that the method
    # stays TorchScript-compatible -- confirm before using math.prod.
    total = 1
    for extent in shape:
        total = total * extent
    return total
424
+
425
def _get_bias_index(
    self,
    kk: str,
) -> int:
    """Return the position of ``kk`` in ``self.bias_keys``.

    The key must occur exactly once; this is checked with an assertion.
    """
    positions: List[int] = []
    for pos, name in enumerate(self.bias_keys):
        if name != kk:
            continue
        positions.append(pos)
    assert len(positions) == 1
    return positions[0]
435
+
436
def _store_out_stat(
    self,
    out_bias: Dict[str, torch.Tensor],
    out_std: Dict[str, torch.Tensor],
    add: bool = False,
):
    """Write per-key bias and std into the ``out_bias`` / ``out_std`` buffers.

    Parameters
    ----------
    out_bias
        Bias per output key; each value is viewed as ``ntypes x size`` where
        ``size`` is the flattened size of that output's shape.
    out_std
        Standard deviation per output key; must contain every key of
        ``out_bias``.
    add
        If True the bias is added to the stored bias, otherwise it replaces
        it. The std is always replaced.
    """
    n_types = self.get_ntypes()
    # Work on copies and write them back in one step at the end.
    bias_buf = torch.clone(self.out_bias)
    std_buf = torch.clone(self.out_std)
    for key, bias_val in out_bias.items():
        assert key in out_std
        row = self._get_bias_index(key)
        # Only the first `width` slots of the shared last axis belong to this key.
        width = self._varsize(self.atomic_output_def()[key].shape)
        new_bias = bias_val.view(n_types, width)
        if add:
            new_bias = bias_buf[row, :, :width] + new_bias
        bias_buf[row, :, :width] = new_bias
        std_buf[row, :, :width] = out_std[key].view(n_types, width)
    self.out_bias.copy_(bias_buf)
    self.out_std.copy_(std_buf)
456
+
457
def _fetch_out_stat(
    self,
    keys: List[str],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Read bias and std for ``keys`` out of the shared buffers.

    Returns
    -------
    Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
        ``(bias, std)`` dictionaries; each value is viewed as
        ``[ntypes] + shape`` where ``shape`` is that output's shape in the
        atomic output definition.
    """
    bias_by_key: Dict[str, torch.Tensor] = {}
    std_by_key: Dict[str, torch.Tensor] = {}
    n_types = self.get_ntypes()
    for key in keys:
        row = self._get_bias_index(key)
        out_shape = list(self.atomic_output_def()[key].shape)
        width = self._varsize(out_shape)
        # List concatenation (not unpacking) is kept as in the original, which
        # silenced RUF005 for it.
        full_shape = [n_types] + out_shape  # noqa: RUF005
        bias_by_key[key] = self.out_bias[row, :, :width].view(full_shape)
        std_by_key[key] = self.out_std[row, :, :width].view(full_shape)
    return bias_by_key, std_by_key
0 commit comments