Skip to content

Commit fdd86d6

Browse files
authored
remove the mol_idx property as it was unused and also added extra requirement in pandas (#361)
1 parent f960354 commit fdd86d6

File tree

3 files changed

+2
-29
lines changed

3 files changed

+2
-29
lines changed

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ dependencies = [
1616
"torch_geometric",
1717
"lightning",
1818
"tqdm",
19-
"pandas",
2019
]
2120

2221
[project.urls]

torchmdnet/datasets/ace.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from torchmdnet.datasets.memdataset import MemmappedDataset
1010
from torch_geometric.data import Data
1111
from tqdm import tqdm
12-
import pandas as pd
1312

1413

1514
class Ace(MemmappedDataset):
@@ -133,7 +132,6 @@ def __init__(
133132
paths=None,
134133
max_gradient=None,
135134
subsample_molecules=1,
136-
index_csv=None,
137135
):
138136
assert isinstance(paths, (str, list))
139137

@@ -143,13 +141,8 @@ def __init__(
143141
self.paths = paths
144142
self.max_gradient = max_gradient
145143
self.subsample_molecules = int(subsample_molecules)
146-
if index_csv is not None:
147-
df = pd.read_csv(index_csv, dtype=int, converters={"name": str})
148-
self.mol_indexes = {mol_id: i for i, mol_id in enumerate(df.name)}
149144

150145
props = ["y", "neg_dy", "q", "pq", "dp"]
151-
if index_csv is not None:
152-
props += ["mol_idx"]
153146
super().__init__(
154147
root,
155148
transform,
@@ -239,7 +232,7 @@ def _load_confs_2_0(mol, n_atoms):
239232
def sample_iter(self, mol_ids=False):
240233
assert self.subsample_molecules > 0
241234

242-
for path in tqdm(self.raw_paths, desc="Files"):
235+
for i_path, path in tqdm(enumerate(self.raw_paths), desc="Files"):
243236
h5 = h5py.File(path)
244237
assert h5.attrs["layout"] == "Ace"
245238
version = h5.attrs["layout_version"]
@@ -285,10 +278,9 @@ def sample_iter(self, mol_ids=False):
285278
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
286279
)
287280
if mol_ids:
281+
args["i_path"] = i_path
288282
args["mol_id"] = mol_id
289283
args["i_conf"] = i_conf
290-
if "mol_idx" in self.properties:
291-
args["mol_idx"] = self.mol_indexes[mol_id]
292284

293285
data = Data(**args)
294286

torchmdnet/datasets/memdataset.py

-18
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,6 @@ def __init__(
8686
self.mmaps["dp"] = np.memmap(
8787
fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
8888
)
89-
if "mol_idx" in self.properties:
90-
self.mmaps["mol_idx"] = np.memmap(
91-
fnames["mol_idx"], mode="r", dtype=np.uint64
92-
)
9389

9490
assert self.mmaps["idx"][0] == 0
9591
assert self.mmaps["idx"][-1] == len(self.mmaps["z"])
@@ -178,13 +174,6 @@ def process(self):
178174
dtype=np.float32,
179175
shape=(num_all_confs, 3),
180176
)
181-
if "mol_idx" in self.properties:
182-
mmaps["mol_idx"] = np.memmap(
183-
fnames["mol_idx"] + ".tmp",
184-
mode="w+",
185-
dtype=np.uint64,
186-
shape=(num_all_confs,),
187-
)
188177

189178
print("Storing data...")
190179
i_atom = 0
@@ -204,8 +193,6 @@ def process(self):
204193
mmaps["pq"][i_atom:i_next_atom] = data.pq
205194
if "dp" in self.properties:
206195
mmaps["dp"][i_conf] = data.dp
207-
if "mol_idx" in self.properties:
208-
mmaps["mol_idx"][i_conf] = data.mol_idx
209196
i_atom = i_next_atom
210197

211198
mmaps["idx"][-1] = num_all_atoms
@@ -231,8 +218,6 @@ def process(self):
231218
os.rename(fnames["pq"] + ".tmp", fnames["pq"])
232219
if "dp" in self.properties:
233220
os.rename(fnames["dp"] + ".tmp", fnames["dp"])
234-
if "mol_idx" in self.properties:
235-
os.rename(fnames["mol_idx"] + ".tmp", fnames["mol_idx"])
236221

237222
def len(self):
238223
return len(self.mmaps["idx"]) - 1
@@ -249,7 +234,6 @@ def get(self, idx):
249234
- :obj:`q`: Total charge of the molecule.
250235
- :obj:`pq`: Partial charges of the atoms.
251236
- :obj:`dp`: Dipole moment of the molecule.
252-
- :obj:`mol_idx`: The index of the molecule of the conformer.
253237
254238
Args:
255239
idx (int): Index of the data object.
@@ -272,8 +256,6 @@ def get(self, idx):
272256
props["pq"] = pt.tensor(self.mmaps["pq"][atoms])
273257
if "dp" in self.properties:
274258
props["dp"] = pt.tensor(self.mmaps["dp"][idx])
275-
# if "mol_idx" in self.properties:
276-
# props["mol_idx"] = pt.tensor(self.mmaps["mol_idx"][idx], dtype=pt.int64).view(1, 1)
277259
return Data(z=z, pos=pos, **props)
278260

279261
def __del__(self):

0 commit comments

Comments
 (0)