|
| 1 | +import h5py |
| 2 | +import psutil |
| 3 | +import numpy as np |
| 4 | +from pytest import mark |
| 5 | +from os.path import join |
| 6 | +from torchmdnet.datasets.mdcath import MDCATH |
| 7 | +from torch_geometric.loader import DataLoader |
| 8 | +from tqdm import tqdm |
| 9 | + |
| 10 | + |
def test_mdcath(tmpdir):
    """Write a synthetic mdCATH file tree into ``tmpdir`` and iterate over it once.

    One ``mdcath_source.h5`` index file is written with a group per entry,
    plus one ``mdcath_dataset_<name>.h5`` file per entry holding all-zero
    ``z``, ``coords`` and ``forces`` arrays for temperature ``"348"`` and
    replica ``"0"``. The test makes no assertions on the loaded values: it
    passes when ``MDCATH`` can be constructed from ``tmpdir`` and a
    ``DataLoader`` over it can be exhausted without raising.
    """
    num_frames = 100
    # Integer atom counts keep group and file names free of float formatting
    # (e.g. "A69" rather than "A69.38775510204081").
    num_atoms_list = np.linspace(50, 1000, 50).astype(int)

    # The context manager closes the source file before the dataset object is
    # built, so nothing is left buffered in an open handle.
    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        for num_atoms in num_atoms_list:
            num_atoms = int(num_atoms)
            name = f"A{num_atoms}"

            z = np.zeros(num_atoms)
            pos = np.zeros((num_frames, num_atoms, 3))
            forces = np.zeros((num_frames, num_atoms, 3))

            # per-entry metadata in the source file
            s_group = source_file.create_group(name)
            s_group.attrs["numChains"] = 1
            s_group.attrs["numNoHAtoms"] = num_atoms / 2
            s_group.attrs["numProteinAtoms"] = num_atoms
            s_group.attrs["numResidues"] = num_atoms / 10
            s_temp_group = s_group.create_group("348")
            s_replica_group = s_temp_group.create_group("0")
            s_replica_group.attrs["numFrames"] = num_frames
            s_replica_group.attrs["alpha"] = 0.30
            s_replica_group.attrs["beta"] = 0.25
            s_replica_group.attrs["coil"] = 0.45
            s_replica_group.attrs["max_gyration_radius"] = 2
            s_replica_group.attrs["max_num_neighbors_5A"] = 55
            s_replica_group.attrs["max_num_neighbors_9A"] = 200
            s_replica_group.attrs["min_gyration_radius"] = 1

            # write the dataset
            with h5py.File(join(tmpdir, f"mdcath_dataset_{name}.h5"), mode="w") as data:
                group = data.create_group(name)
                group.create_dataset("z", data=z)
                tempgroup = group.create_group("348")
                replicagroup = tempgroup.create_group("0")
                replicagroup.create_dataset("coords", data=pos)
                replicagroup.create_dataset("forces", data=forces)
                # add some attributes
                replicagroup.attrs["numFrames"] = num_frames
                replicagroup["coords"].attrs["unit"] = "Angstrom"
                replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    # NOTE(review): presumably MDCATH locates the files written above by name
    # inside ``root`` -- confirm against torchmdnet.datasets.mdcath.
    dataset = MDCATH(root=tmpdir)
    dl = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )
    # Exhausting the loader is the whole check: any loading error fails the test.
    for _batch in tqdm(dl):
        pass
| 63 | + |
| 64 | + |
def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
    """Check that constructing ``MDCATH`` leaves no extra file open in this process.

    A single synthetic entry ``A00`` is written to ``tmpdir`` (source index
    file plus one dataset file, all arrays zero). The number of files held
    open by the current process is recorded with ``psutil`` before and after
    building the dataset object and must be unchanged.

    Args:
        tmpdir: pytest temporary directory used as the dataset root.
        num_entries: length of ``z`` and size of the second axis of the
            ``coords``/``forces`` arrays.
        numFrames: size of the first axis of ``coords``/``forces`` and the
            value stored in the ``numFrames`` attributes.
    """
    # generate sample data
    z = np.zeros(num_entries)
    pos = np.zeros((numFrames, num_entries, 3))
    forces = np.zeros((numFrames, num_entries, 3))

    # Context managers guarantee both fixture files are closed before the
    # open-file count below is taken.
    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        s_group = source_file.create_group("A00")

        s_group.attrs["numChains"] = 1
        s_group.attrs["numNoHAtoms"] = num_entries / 2
        s_group.attrs["numProteinAtoms"] = num_entries
        s_group.attrs["numResidues"] = num_entries / 10
        s_temp_group = s_group.create_group("348")
        s_replica_group = s_temp_group.create_group("0")
        s_replica_group.attrs["numFrames"] = numFrames
        s_replica_group.attrs["alpha"] = 0.30
        s_replica_group.attrs["beta"] = 0.25
        s_replica_group.attrs["coil"] = 0.45
        s_replica_group.attrs["max_gyration_radius"] = 2
        s_replica_group.attrs["max_num_neighbors_5A"] = 55
        s_replica_group.attrs["max_num_neighbors_9A"] = 200
        s_replica_group.attrs["min_gyration_radius"] = 1

    # write the dataset
    with h5py.File(join(tmpdir, "mdcath_dataset_A00.h5"), mode="w") as data:
        group = data.create_group("A00")
        group.create_dataset("z", data=z)
        tempgroup = group.create_group("348")
        replicagroup = tempgroup.create_group("0")
        replicagroup.create_dataset("coords", data=pos)
        replicagroup.create_dataset("forces", data=forces)
        # add some attributes
        replicagroup.attrs["numFrames"] = numFrames
        replicagroup["coords"].attrs["unit"] = "Angstrom"
        replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    # make sure creating the dataset doesn't open any files on the main process
    proc = psutil.Process()
    n_open = len(proc.open_files())

    # Keep a reference so the dataset object is still alive when the open-file
    # count is compared.
    dset = MDCATH(
        root=tmpdir,
    )
    assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"
| 113 | + |
| 114 | + |
def replacer(arr, skipframes):
    """Return a copy of ``arr`` with every ``skipframes``-th frame set to 1.

    Frames are indexed along the first axis, starting at index 0. The input
    array is not modified.
    """
    marked = arr.copy()
    # Explicit start/stop bounds mirror range(0, len(marked), skipframes).
    marked[0 : len(marked) : skipframes, :, :] = 1
    return marked
| 121 | + |
| 122 | + |
@mark.parametrize("skipframes", [1, 2, 5])
@mark.parametrize("batch_size", [1, 10])
@mark.parametrize("pdb_list", [["A50", "A612", "A1000"], None])
def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
    """Check that ``skip_frames`` selects exactly the frames marked with ones.

    Fifty synthetic entries with different frame counts are written to
    ``tmpdir``. In each entry ``coords`` and ``forces`` are zero everywhere
    except on every ``skipframes``-th frame, which ``replacer`` fills with
    ones. Every batch produced by a ``DataLoader`` over
    ``MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list)`` must
    therefore contain only ones in ``pos`` and ``neg_dy``.
    """
    num_atoms = 100
    num_frames_list = np.linspace(50, 1000, 50).astype(int)

    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
        for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
            name = f"A{num_frame}"

            z = np.zeros(num_atoms)
            # Mark the frames that should survive frame skipping with ones.
            pos = replacer(np.zeros((num_frame, num_atoms, 3)), skipframes)
            forces = replacer(np.zeros((num_frame, num_atoms, 3)), skipframes)

            s_group = source_file.create_group(name)

            s_group.attrs["numChains"] = 1
            s_group.attrs["numNoHAtoms"] = num_atoms / 2
            s_group.attrs["numProteinAtoms"] = num_atoms
            s_group.attrs["numResidues"] = num_atoms / 10
            s_temp_group = s_group.create_group("348")
            s_replica_group = s_temp_group.create_group("0")
            s_replica_group.attrs["numFrames"] = num_frame
            s_replica_group.attrs["alpha"] = 0.30
            s_replica_group.attrs["beta"] = 0.25
            s_replica_group.attrs["coil"] = 0.45
            s_replica_group.attrs["max_gyration_radius"] = 2
            s_replica_group.attrs["max_num_neighbors_5A"] = 55
            s_replica_group.attrs["max_num_neighbors_9A"] = 200
            s_replica_group.attrs["min_gyration_radius"] = 1

            # write the dataset; the context manager closes the file even if
            # one of the writes below raises
            with h5py.File(join(tmpdir, f"mdcath_dataset_{name}.h5"), mode="w") as data:
                group = data.create_group(name)
                group.create_dataset("z", data=z)
                tempgroup = group.create_group("348")
                replicagroup = tempgroup.create_group("0")
                replicagroup.create_dataset("coords", data=pos)
                replicagroup.create_dataset("forces", data=forces)
                # add some attributes
                replicagroup.attrs["numFrames"] = num_frame
                replicagroup["coords"].attrs["unit"] = "Angstrom"
                replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

    dataset = MDCATH(
        root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list
    )
    dl = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )
    for batch in tqdm(dl):
        # if skipframes works correctly, the returned data should contain only 1s;
        # compare element-wise against 1 rather than relying on truthiness
        assert (batch.pos == 1).all(), "skipframes not working correctly for positions"
        assert (batch.neg_dy == 1).all(), "skipframes not working correctly for forces"
0 commit comments