Skip to content

Commit 74702da

Browse files
authored
Fix HDF5 not understanding some files (#313)
* Fix bug in HDF5 that would cause an error during training when the dataset provides energies with shape (Nsamples,) instead of (Nsamples, 1) * Accommodate for previous behavior
1 parent 8b47246 commit 74702da

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchmdnet/datasets/hdf.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ def _preload_data(self):
8989
# Watchout for the 1D case, embed can be shared for all samples
9090
tmp = torch.tensor(np.array(data), dtype=dtype)
9191
if tmp.ndim == 1:
92-
tmp = tmp.unsqueeze(0).expand(size, -1)
92+
if len(tmp) == size:
93+
tmp = tmp.unsqueeze(-1)
94+
else:
95+
tmp = tmp.unsqueeze(0).expand(size, -1)
9396
self.stored_data[field].append(tmp)
9497
self.index.extend(list(zip([i] * size, range(size))))
9598
i += 1

0 commit comments

Comments
 (0)