10
10
import psutil
11
11
from torchmdnet .datasets import Custom , HDF5 , Ace
12
12
from torchmdnet .utils import write_as_hdf5
13
+ from torch_geometric .loader import DataLoader
13
14
import h5py
14
15
import glob
15
16
@@ -297,3 +298,36 @@ def test_ace(tmpdir):
297
298
assert len (dataset_v2 ) == 8
298
299
f2 .flush ()
299
300
f2 .close ()
301
+
302
+
303
+ @mark .parametrize ("num_files" , [1 , 3 ])
304
+ @mark .parametrize ("tile_embed" , [True , False ])
305
+ @mark .parametrize ("batch_size" , [1 , 5 ])
306
+ def test_hdf5_with_and_without_caching (num_files , tile_embed , batch_size , tmpdir ):
307
+ """This test ensures that the output from the get of the HDF5 dataset is the same
308
+ when the dataset is loaded with and without caching."""
309
+
310
+ # set up necessary files
311
+ _ = write_sample_npy_files (True , True , tmpdir , num_files )
312
+ files = {}
313
+ files ["pos" ] = sorted (glob .glob (join (tmpdir , "coords*" )))
314
+ files ["z" ] = sorted (glob .glob (join (tmpdir , "embed*" )))
315
+ files ["y" ] = sorted (glob .glob (join (tmpdir , "energy*" )))
316
+ files ["neg_dy" ] = sorted (glob .glob (join (tmpdir , "forces*" )))
317
+
318
+ write_as_hdf5 (files , join (tmpdir , "test.hdf5" ), tile_embed )
319
+ # Assert file is present in the disk
320
+ assert os .path .isfile (join (tmpdir , "test.hdf5" )), "HDF5 file was not created"
321
+
322
+ data = HDF5 (join (tmpdir , "test.hdf5" ), dataset_preload_limit = 0 ) # no caching
323
+ data_cached = HDF5 (join (tmpdir , "test.hdf5" ), dataset_preload_limit = 256 ) # caching
324
+ assert len (data ) == len (data_cached ), "Number of samples does not match"
325
+
326
+ dl = DataLoader (data , batch_size )
327
+ dl_cached = DataLoader (data_cached , batch_size )
328
+
329
+ for sample_cached , sample in zip (dl_cached , dl ):
330
+ assert np .allclose (sample_cached .pos , sample .pos ), "Sample has incorrect coords"
331
+ assert np .allclose (sample_cached .z , sample .z ), "Sample has incorrect atom numbers"
332
+ assert np .allclose (sample_cached .y , sample .y ), "Sample has incorrect energy"
333
+ assert np .allclose (sample_cached .neg_dy , sample .neg_dy ), "Sample has incorrect forces"
0 commit comments