Skip to content

Commit 0ed2e7c

Browse files
authored
Filter non used sample values in QM9 (#316)
1 parent 72d6e8e commit 0ed2e7c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

torchmdnet/datasets/qm9.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch_geometric.transforms import Compose
77
from torch_geometric.datasets import QM9 as QM9_geometric
88
from torch_geometric.nn.models.schnet import qm9_target_dict
9+
from torch_geometric.data import Data
910

1011

1112
class QM9(QM9_geometric):
@@ -25,17 +26,27 @@ def __init__(self, root, transform=None, label=None):
2526
else:
2627
transform = Compose([transform, self._filter_label])
2728

28-
super(QM9, self).__init__(root, transform=transform)
29+
# Keep only pos, z and y in each sample
30+
def pre_transform(x):
31+
return Data(
32+
pos=x.pos,
33+
z=x.z,
34+
y=x.y,
35+
)
36+
37+
super(QM9, self).__init__(
38+
root, transform=transform, pre_transform=pre_transform
39+
)
2940

3041
def get_atomref(self, max_z=100):
3142
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
3243
33-
Args:
34-
max_z (int): Maximum atomic number
44+
Args:
45+
max_z (int): Maximum atomic number
3546
36-
Returns:
37-
torch.Tensor: Atomic energy reference values for each element in the dataset.
38-
"""
47+
Returns:
48+
torch.Tensor: Atomic energy reference values for each element in the dataset.
49+
"""
3950
atomref = self.atomref(self.label_idx)
4051
if atomref is None:
4152
return None

0 commit comments

Comments
 (0)