Skip to content

Commit 352da8a

Browse files
committed
Merge remote-tracking branch 'origin/main' into swiglu
2 parents 3572c32 + 6dea4b6 commit 352da8a

File tree

4 files changed

+417
-1
lines changed

4 files changed

+417
-1
lines changed

tests/test_mdcath.py

+184
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,184 @@
1+
import h5py
2+
import psutil
3+
import numpy as np
4+
from pytest import mark
5+
from os.path import join
6+
from torchmdnet.datasets.mdcath import MDCATH
7+
from torch_geometric.loader import DataLoader
8+
from tqdm import tqdm
9+
10+
11+
def test_mdcath(tmpdir):
    """Smoke-test iterating an ``MDCATH`` dataset built from synthetic files.

    Writes one ``mdcath_source.h5`` index file plus one
    ``mdcath_dataset_A<num_atoms>.h5`` file per synthetic domain into
    *tmpdir*, then walks a ``DataLoader`` over ``MDCATH(root=tmpdir)``.
    The test passes when the whole iteration completes without raising.

    Args:
        tmpdir: pytest-provided temporary directory used as dataset root.
    """
    num_frames = 100
    # Integer atom counts keep the group and file names free of float
    # formatting (e.g. "A69.38775510204081"), matching test_mdcath_args.
    num_atoms_list = np.linspace(50, 1000, 50).astype(int)

    # The context manager guarantees the source file is flushed and closed
    # before the dataset object is created below.
    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        for num_atoms in num_atoms_list:
            num_atoms = int(num_atoms)
            z = np.zeros(num_atoms)
            pos = np.zeros((num_frames, num_atoms, 3))
            forces = np.zeros((num_frames, num_atoms, 3))

            # per-domain metadata in the source file
            s_group = source_file.create_group(f"A{num_atoms}")
            s_group.attrs["numChains"] = 1
            s_group.attrs["numNoHAtoms"] = num_atoms / 2
            s_group.attrs["numProteinAtoms"] = num_atoms
            s_group.attrs["numResidues"] = num_atoms / 10

            # per-replica metadata, nested as <domain>/<temperature>/<replica>
            s_replica_group = s_group.create_group("348").create_group("0")
            s_replica_group.attrs["numFrames"] = num_frames
            s_replica_group.attrs["alpha"] = 0.30
            s_replica_group.attrs["beta"] = 0.25
            s_replica_group.attrs["coil"] = 0.45
            s_replica_group.attrs["max_gyration_radius"] = 2
            s_replica_group.attrs["max_num_neighbors_5A"] = 55
            s_replica_group.attrs["max_num_neighbors_9A"] = 200
            s_replica_group.attrs["min_gyration_radius"] = 1

            # write the dataset file for this domain
            dataset_path = join(tmpdir, f"mdcath_dataset_A{num_atoms}.h5")
            with h5py.File(dataset_path, mode="w") as data:
                group = data.create_group(f"A{num_atoms}")
                group.create_dataset("z", data=z)
                replicagroup = group.create_group("348").create_group("0")
                replicagroup.create_dataset("coords", data=pos)
                replicagroup.create_dataset("forces", data=forces)
                # add some attributes
                replicagroup.attrs["numFrames"] = num_frames
                replicagroup["coords"].attrs["unit"] = "Angstrom"
                replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    dataset = MDCATH(root=tmpdir)
    dl = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )
    # Only successful iteration is checked; batch contents are not inspected.
    for _ in tqdm(dl):
        pass
63+
64+
65+
def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
    """Check that constructing ``MDCATH`` leaves no extra file open.

    Writes a single synthetic domain (``A00``) into *tmpdir*, records how
    many files the current process has open, builds the dataset object and
    asserts that the open-file count is unchanged.

    Args:
        tmpdir: pytest-provided temporary directory used as dataset root.
        num_entries: number of atoms in the synthetic domain.
        numFrames: number of frames in the synthetic trajectory.
    """
    # generate sample data
    z = np.zeros(num_entries)
    pos = np.zeros((numFrames, num_entries, 3))
    forces = np.zeros((numFrames, num_entries, 3))

    # Close the source file before counting open files so the fixture itself
    # does not hold a handle open for the rest of the test.
    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        s_group = source_file.create_group("A00")
        s_group.attrs["numChains"] = 1
        s_group.attrs["numNoHAtoms"] = num_entries / 2
        s_group.attrs["numProteinAtoms"] = num_entries
        s_group.attrs["numResidues"] = num_entries / 10

        s_replica_group = s_group.create_group("348").create_group("0")
        s_replica_group.attrs["numFrames"] = numFrames
        s_replica_group.attrs["alpha"] = 0.30
        s_replica_group.attrs["beta"] = 0.25
        s_replica_group.attrs["coil"] = 0.45
        s_replica_group.attrs["max_gyration_radius"] = 2
        s_replica_group.attrs["max_num_neighbors_5A"] = 55
        s_replica_group.attrs["max_num_neighbors_9A"] = 200
        s_replica_group.attrs["min_gyration_radius"] = 1

    # write the dataset
    with h5py.File(join(tmpdir, "mdcath_dataset_A00.h5"), mode="w") as data:
        group = data.create_group("A00")
        group.create_dataset("z", data=z)
        replicagroup = group.create_group("348").create_group("0")
        replicagroup.create_dataset("coords", data=pos)
        replicagroup.create_dataset("forces", data=forces)
        # add some attributes
        replicagroup.attrs["numFrames"] = numFrames
        replicagroup["coords"].attrs["unit"] = "Angstrom"
        replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    # make sure creating the dataset doesn't open any files on the main process
    proc = psutil.Process()
    n_open = len(proc.open_files())

    # The binding is intentionally kept: the dataset object must still be
    # alive when the open files are counted again, otherwise a handle it
    # opened could already have been released.
    dset = MDCATH(root=tmpdir)
    assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"
    del dset
113+
114+
115+
def replacer(arr, skipframes):
    """Return a copy of *arr* with every ``skipframes``-th frame set to 1.

    Frames are indexed along the first axis; frames 0, ``skipframes``,
    ``2 * skipframes``, ... are overwritten with ones and all other frames
    keep their original values. The input array is not modified.

    Args:
        arr: array exposing ``.copy()`` and three-axis indexing.
        skipframes: stride between the frames that are overwritten.

    Returns:
        The modified copy of *arr*.

    Raises:
        ValueError: if ``skipframes`` is zero.
    """
    tmp_arr = arr.copy()
    # One strided slice assignment covers the same indices as
    # range(0, len(tmp_arr), skipframes) without a Python-level loop.
    tmp_arr[0 : len(tmp_arr) : skipframes, :, :] = 1
    return tmp_arr
121+
122+
123+
@mark.parametrize("skipframes", [1, 2, 5])
@mark.parametrize("batch_size", [1, 10])
@mark.parametrize("pdb_list", [["A50", "A612", "A1000"], None])
def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
    """Check the ``skip_frames`` and ``pdb_list`` arguments of ``MDCATH``.

    Synthetic trajectories are written so that only every ``skipframes``-th
    frame is non-zero (see ``replacer``). The dataset is then loaded with
    ``skip_frames=skipframes`` and every returned batch must contain no zero
    entry in ``pos`` or ``neg_dy``.

    Args:
        tmpdir: pytest-provided temporary directory used as dataset root.
        skipframes: frame stride passed to ``MDCATH`` as ``skip_frames``.
        batch_size: batch size passed to the ``DataLoader``.
        pdb_list: domain names passed to ``MDCATH`` as ``pdb_list``, or None.
    """
    num_atoms = 100
    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        num_frames_list = np.linspace(50, 1000, 50).astype(int)
        for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
            z = np.zeros(num_atoms)
            # mark the frames the dataset is expected to return with ones
            pos = replacer(np.zeros((num_frame, num_atoms, 3)), skipframes)
            forces = replacer(np.zeros((num_frame, num_atoms, 3)), skipframes)

            s_group = source_file.create_group(f"A{num_frame}")
            s_group.attrs["numChains"] = 1
            s_group.attrs["numNoHAtoms"] = num_atoms / 2
            s_group.attrs["numProteinAtoms"] = num_atoms
            s_group.attrs["numResidues"] = num_atoms / 10

            s_replica_group = s_group.create_group("348").create_group("0")
            s_replica_group.attrs["numFrames"] = num_frame
            s_replica_group.attrs["alpha"] = 0.30
            s_replica_group.attrs["beta"] = 0.25
            s_replica_group.attrs["coil"] = 0.45
            s_replica_group.attrs["max_gyration_radius"] = 2
            s_replica_group.attrs["max_num_neighbors_5A"] = 55
            s_replica_group.attrs["max_num_neighbors_9A"] = 200
            s_replica_group.attrs["min_gyration_radius"] = 1

            # write the dataset; the context manager closes the file even if
            # one of the writes raises
            dataset_path = join(tmpdir, f"mdcath_dataset_A{num_frame}.h5")
            with h5py.File(dataset_path, mode="w") as data:
                group = data.create_group(f"A{num_frame}")
                group.create_dataset("z", data=z)
                replicagroup = group.create_group("348").create_group("0")
                replicagroup.create_dataset("coords", data=pos)
                replicagroup.create_dataset("forces", data=forces)
                # add some attributes
                replicagroup.attrs["numFrames"] = num_frame
                replicagroup["coords"].attrs["unit"] = "Angstrom"
                replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    dataset = MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list)
    dl = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )
    for batch in tqdm(dl):
        # If skipframes works correctly, only the frames filled with ones are
        # returned, so no entry may be zero. `.all()` is a truthiness check:
        # it verifies "no zeros", not equality with 1.
        assert bool(batch.pos.all()), "skipframes not working correctly for positions"
        assert bool(batch.neg_dy.all()), "skipframes not working correctly for forces"

torchmdnet/datasets/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
COMP6v1,
1515
COMP6v2,
1616
)
17+
from .mdcath import MDCATH
1718
from .custom import Custom
1819
from .water import WaterBox
1920
from .hdf import HDF5
@@ -40,6 +41,7 @@
4041
"GDB10to13",
4142
"GenentechTorsions",
4243
"HDF5",
44+
"MDCATH",
4345
"MD17",
4446
"MD22",
4547
"QM9",

0 commit comments

Comments
 (0)