Skip to content

Commit 11eebc0

Browse files
committed
dataset filtering
1 parent d091f43 commit 11eebc0

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

torchmdnet/datasets/ace.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,23 @@ def sample_iter(self, mol_ids=False):
294294

295295

296296
class AceHF(Dataset):
297-
def __init__(self, root="parquet", paths=None, split="train") -> None:
297+
def __init__(
298+
self, root="parquet", paths=None, split="train", max_gradient=None
299+
) -> None:
298300
from datasets import load_dataset
301+
import numpy as np
299302

300303
self.dataset = load_dataset(root, data_files=paths, split=split)
304+
if max_gradient is not None:
305+
306+
def _filter(x):
307+
if np.isnan(x["forces"]).any() or np.isnan(x["formation_energy"]).any():
308+
return False
309+
return np.max(np.linalg.norm(x["forces"], axis=1)) < max_gradient
310+
311+
self.dataset = self.dataset.filter(
312+
_filter, desc="Filtering", num_proc=os.cpu_count() // 2
313+
)
301314
self.dataset = self.dataset.with_format("torch")
302315

303316
def __len__(self):

0 commit comments

Comments
 (0)