Skip to content

Commit d091f43

Browse files
committed
cleanup
1 parent db91b60 commit d091f43

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

torchmdnet/datasets/ace.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,6 @@ class AceHF(Dataset):
297297
def __init__(self, root="parquet", paths=None, split="train") -> None:
298298
from datasets import load_dataset
299299

300-
self.properties = ("y", "neg_dy", "q", "pq", "dp")
301-
self.split = split
302-
303300
self.dataset = load_dataset(root, data_files=paths, split=split)
304301
self.dataset = self.dataset.with_format("torch")
305302

@@ -326,16 +323,12 @@ def __getitem__(self, idx):
326323
:obj:`torch_geometric.data.Data`: The data object.
327324
"""
328325
data = self.dataset[int(idx)]
329-
330-
props = {}
331-
if "y" in self.properties:
332-
props["y"] = data["formation_energy"].view(1, 1)
333-
if "neg_dy" in self.properties:
334-
props["neg_dy"] = data["forces"]
335-
if "q" in self.properties:
336-
props["q"] = sum(data["formal_charges"])
337-
if "pq" in self.properties:
338-
props["pq"] = data["partial_charges"]
339-
if "dp" in self.properties:
340-
props["dp"] = data["dipole_moment"]
341-
return Data(z=data["atomic_numbers"], pos=data["positions"], **props)
326+
return Data(
327+
z=data["atomic_numbers"],
328+
pos=data["positions"],
329+
y=data["formation_energy"].view(1, 1),
330+
neg_dy=data["forces"],
331+
q=sum(data["formal_charges"]),
332+
pq=data["partial_charges"],
333+
dp=data["dipole_moment"],
334+
)

0 commit comments

Comments
 (0)