Skip to content

Commit d616c8a

Browse files
authored
Merge pull request #356 from AntonioMirarchi/hdf5
Fix HDF5 dataset consistency
2 parents 1deecd1 + cfd2700 commit d616c8a

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

tests/test_datasets.py

+34
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import psutil
1111
from torchmdnet.datasets import Custom, HDF5, Ace
1212
from torchmdnet.utils import write_as_hdf5
13+
from torch_geometric.loader import DataLoader
1314
import h5py
1415
import glob
1516

@@ -297,3 +298,36 @@ def test_ace(tmpdir):
297298
assert len(dataset_v2) == 8
298299
f2.flush()
299300
f2.close()
301+
302+
303+
@mark.parametrize("num_files", [1, 3])
304+
@mark.parametrize("tile_embed", [True, False])
305+
@mark.parametrize("batch_size", [1, 5])
306+
def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir):
307+
"""This test ensures that the output from the get of the HDF5 dataset is the same
308+
when the dataset is loaded with and without caching."""
309+
310+
# set up necessary files
311+
_ = write_sample_npy_files(True, True, tmpdir, num_files)
312+
files = {}
313+
files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
314+
files["z"] = sorted(glob.glob(join(tmpdir, "embed*")))
315+
files["y"] = sorted(glob.glob(join(tmpdir, "energy*")))
316+
files["neg_dy"] = sorted(glob.glob(join(tmpdir, "forces*")))
317+
318+
write_as_hdf5(files, join(tmpdir, "test.hdf5"), tile_embed)
319+
# Assert file is present in the disk
320+
assert os.path.isfile(join(tmpdir, "test.hdf5")), "HDF5 file was not created"
321+
322+
data = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=0) # no caching
323+
data_cached = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=256) # caching
324+
assert len(data) == len(data_cached), "Number of samples does not match"
325+
326+
dl = DataLoader(data, batch_size)
327+
dl_cached = DataLoader(data_cached, batch_size)
328+
329+
for sample_cached, sample in zip(dl_cached, dl):
330+
assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords"
331+
assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers"
332+
assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy"
333+
assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces"

torchmdnet/datasets/hdf.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,13 @@ def get(self, idx):
125125
if self.index is None:
126126
self._setup_index()
127127
*fields_data, i = self.index[idx]
128+
# Assuming the first element of fields_data is 'pos' based on the definition of self.fields
129+
size = len(fields_data[0])
128130
for (name, _, dtype), d in zip(self.fields, fields_data):
129-
tensor_input = [[d[i]]] if d.ndim == 1 else d[i]
131+
if d.ndim == 1:
132+
tensor_input = [d[i]] if len(d) == size else d[:]
133+
else:
134+
tensor_input = d[i]
130135
data[name] = torch.tensor(tensor_input, dtype=dtype)
131136
return data
132137

torchmdnet/utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,13 @@ class MissingEnergyException(Exception):
346346
pass
347347

348348

349-
def write_as_hdf5(files, hdf5_dataset):
349+
def write_as_hdf5(files, hdf5_dataset, tile_embed=True):
350350
"""Transform the input numpy files to hdf5 format compatible with the HDF5 Dataset class.
351351
The input files to this function are the same as the ones required by the Custom dataset.
352352
Args:
353353
files (dict): Dictionary of numpy input files. Must contain "pos", "z" and at least one of "y" or "neg_dy".
354354
hdf5_dataset (string): Path to the output HDF5 dataset.
355+
tile_embed (bool): Whether to tile the embeddings to match the number of samples. Default: True
355356
Example:
356357
>>> files = {}
357358
>>> files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
@@ -370,7 +371,10 @@ def write_as_hdf5(files, hdf5_dataset):
370371
group = f.create_group(str(i))
371372
num_samples = coord_data.shape[0]
372373
group.create_dataset("pos", data=coord_data)
373-
group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
374+
if tile_embed:
375+
group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
376+
else:
377+
group.create_dataset("types", data=embed_data)
374378
if "y" in files:
375379
energy_data = np.load(files["y"][i], mmap_mode="r")
376380
group.create_dataset("energy", data=energy_data)

0 commit comments

Comments
 (0)