@@ -337,6 +337,44 @@ def tf_compute_input_stats(self):
337
337
)
338
338
339
339
340
+ class TestExcludeTypes (DatasetTest , unittest .TestCase ):
341
+ def setup_data (self ):
342
+ original_data = str (Path (__file__ ).parent / "water/data/data_0" )
343
+ picked_data = str (Path (__file__ ).parent / "picked_data_for_test_stat" )
344
+ dpdata .LabeledSystem (original_data , fmt = "deepmd/npy" )[:2 ].to_deepmd_npy (
345
+ picked_data
346
+ )
347
+ self .mixed_type = False
348
+ return picked_data
349
+
350
+ def setup_tf (self ):
351
+ return DescrptSeA_tf (
352
+ rcut = self .rcut ,
353
+ rcut_smth = self .rcut_smth ,
354
+ sel = self .sel ,
355
+ neuron = self .filter_neuron ,
356
+ axis_neuron = self .axis_neuron ,
357
+ exclude_types = [[0 , 0 ], [1 , 1 ]],
358
+ )
359
+
360
+ def setup_pt (self ):
361
+ return DescrptSeA (
362
+ self .rcut ,
363
+ self .rcut_smth ,
364
+ self .sel ,
365
+ self .filter_neuron ,
366
+ self .axis_neuron ,
367
+ exclude_types = [[0 , 0 ], [1 , 1 ]],
368
+ ).sea # get the block who has stat as private vars
369
+
370
+ def tf_compute_input_stats (self ):
371
+ coord = self .dp_merged ["coord" ]
372
+ atype = self .dp_merged ["type" ]
373
+ natoms = self .dp_merged ["natoms_vec" ]
374
+ box = self .dp_merged ["box" ]
375
+ self .dp_d .compute_input_stats (coord , box , atype , natoms , self .dp_mesh , {})
376
+
377
+
340
378
class TestOutputStat (unittest .TestCase ):
341
379
def setUp (self ):
342
380
self .data_file = [str (Path (__file__ ).parent / "water/data/data_0" )]
0 commit comments