@@ -297,9 +297,6 @@ class AceHF(Dataset):
297
297
def __init__ (self , root = "parquet" , paths = None , split = "train" ) -> None :
298
298
from datasets import load_dataset
299
299
300
- self .properties = ("y" , "neg_dy" , "q" , "pq" , "dp" )
301
- self .split = split
302
-
303
300
self .dataset = load_dataset (root , data_files = paths , split = split )
304
301
self .dataset = self .dataset .with_format ("torch" )
305
302
@@ -326,16 +323,12 @@ def __getitem__(self, idx):
326
323
:obj:`torch_geometric.data.Data`: The data object.
327
324
"""
328
325
data = self .dataset [int (idx )]
329
-
330
- props = {}
331
- if "y" in self .properties :
332
- props ["y" ] = data ["formation_energy" ].view (1 , 1 )
333
- if "neg_dy" in self .properties :
334
- props ["neg_dy" ] = data ["forces" ]
335
- if "q" in self .properties :
336
- props ["q" ] = sum (data ["formal_charges" ])
337
- if "pq" in self .properties :
338
- props ["pq" ] = data ["partial_charges" ]
339
- if "dp" in self .properties :
340
- props ["dp" ] = data ["dipole_moment" ]
341
- return Data (z = data ["atomic_numbers" ], pos = data ["positions" ], ** props )
326
+ return Data (
327
+ z = data ["atomic_numbers" ],
328
+ pos = data ["positions" ],
329
+ y = data ["formation_energy" ].view (1 , 1 ),
330
+ neg_dy = data ["forces" ],
331
+ q = sum (data ["formal_charges" ]),
332
+ pq = data ["partial_charges" ],
333
+ dp = data ["dipole_moment" ],
334
+ )
0 commit comments