Commit f013b80

Merge pull request #341 from AntonioMirarchi/mdcath_noh
CUSTOM MDCATH
2 parents 30fb141 + e1b0818

File tree: 1 file changed (+22 -9 lines)

torchmdnet/datasets/mdcath.py (+22 -9)
```diff
@@ -33,7 +33,10 @@ def __init__(
         transform=None,
         pre_transform=None,
         pre_filter=None,
+        source_file="mdcath_source.h5",
+        file_basename="mdcath_dataset",
         numAtoms=5000,
+        numNoHAtoms=None,
         numResidues=1000,
         temperatures=["348"],
         skip_frames=1,
@@ -52,16 +55,21 @@ def __init__(
             Root directory where the dataset should be stored. Data will be downloaded to 'root/'.
         numAtoms: int
             Max number of atoms in the protein structure.
+        source_file: str
+            Name of the source file with the information about the protein structures. Default is "mdcath_source.h5".
+        file_basename: str
+            Base name of the hdf5 files. Default is "mdcath_dataset".
         numNoHAtoms: int
-            Max number of non-hydrogen atoms in the protein structure.
+            Max number of non-hydrogen atoms in the protein structure, not available for original mdcath dataset. Default is None.
+            Be sure to have the attribute 'numNoHAtoms' in the source file.
         numResidues: int
             Max number of residues in the protein structure.
         temperatures: list
             List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450']
         skip_frames: int
             Number of frames to skip in the trajectory. Default is 1.
         pdb_list: list or str
-            List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded.
+            List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file' will be downloaded.
             The filters will be applied to the PDB IDs in this list in any case. Default is None.
         min_gyration_radius: float
             Minimum gyration radius (in nm) of the protein structure. Default is None.
@@ -76,7 +84,9 @@ def __init__(
         """
 
         self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/"
-        self.source_file = "mdcath_source.h5"
+        self.source_file = source_file
+        self.file_basename = file_basename
+        self.numNoHAtoms = numNoHAtoms
         self.root = root
         os.makedirs(root, exist_ok=True)
         self.numAtoms = numAtoms
@@ -103,33 +113,35 @@ def __init__(
 
     @property
     def raw_file_names(self):
-        return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.processed.keys()]
+        return [f"{self.file_basename}_{pdb_id}.h5" for pdb_id in self.processed.keys()]
 
     @property
     def raw_dir(self):
         # Override the raw_dir property to return the root directory
-        # The files will be downloaded to the root directory
+        # The files will be downloaded to the root directory, compatible only with original mdcath dataset
         return self.root
 
     def _ensure_source_file(self):
         """Ensure the source file is downloaded before processing."""
         source_path = os.path.join(self.root, self.source_file)
         if not os.path.exists(source_path):
+            assert self.source_file == "mdcath_source.h5", "Only 'mdcath_source.h5' is supported as source file for download."
             logger.info(f"Downloading source file {self.source_file}")
             urllib.request.urlretrieve(opj(self.url, self.source_file), source_path)
 
     def download(self):
         for pdb_id in self.processed.keys():
-            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            file_name = f"{self.file_basename}_{pdb_id}.h5"
             file_path = opj(self.raw_dir, file_name)
             if not os.path.exists(file_path):
+                assert self.file_basename == "mdcath_dataset", "Only 'mdcath_dataset' is supported as file_basename for download."
                 # Download the file if it does not exist
                 urllib.request.urlretrieve(opj(self.url, 'data', file_name), file_path)
 
     def calculate_dataset_size(self):
         total_size_bytes = 0
         for pdb_id in self.processed.keys():
-            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            file_name = f"{self.file_basename}_{pdb_id}.h5"
             total_size_bytes += os.path.getsize(opj(self.root, file_name))
         total_size_mb = round(total_size_bytes / (1024 * 1024), 4)
         return total_size_mb
@@ -161,7 +173,8 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group):
             self.numFrames is not None and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames,
             self.min_gyration_radius is not None and pdb_group[temp][replica].attrs["min_gyration_radius"] < self.min_gyration_radius,
             self.max_gyration_radius is not None and pdb_group[temp][replica].attrs["max_gyration_radius"] > self.max_gyration_radius,
-            self._evaluate_structure(pdb_group, temp, replica)
+            self._evaluate_structure(pdb_group, temp, replica),
+            self.numNoHAtoms is not None and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms,
         ]
         if any(conditions):
             return
@@ -180,7 +193,7 @@ def len(self):
         return self.num_conformers
 
     def _setup_idx(self):
-        files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.processed.keys()]
+        files = [opj(self.root, f"{self.file_basename}_{pdb_id}.h5") for pdb_id in self.processed.keys()]
         self.idx = []
         for i, (pdb, group_info) in enumerate(self.processed.items()):
             for temp, replica, num_frames in group_info:
```
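For orientation, here is a minimal usage sketch of the new constructor arguments, assuming the class is the `MDCATH` dataset exported from `torchmdnet.datasets`; the custom file names and the heavy-atom threshold below are illustrative, not part of this commit:

```python
from torchmdnet.datasets import MDCATH  # assumed import path

# Original mdCATH data: defaults are kept, and missing files are fetched
# from the Hugging Face repository configured in self.url.
dataset = MDCATH(root="data/mdcath", numAtoms=5000, temperatures=["348"])

# Custom variant: non-default 'source_file'/'file_basename' values are not
# downloadable (the new asserts only permit downloads for the original
# names), so these files must already exist under 'root'.
custom = MDCATH(
    root="data/mdcath_custom",
    source_file="my_source.h5",    # hypothetical custom source file
    file_basename="my_dataset",    # hypothetical custom basename
    numNoHAtoms=3000,              # skip domains with more non-H atoms
)
```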

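The updated docstring notes that the `numNoHAtoms` filter requires a matching attribute in the source file. A sketch of how a custom source file could be annotated, assuming the per-domain group layout implied by the `pdb_group.attrs["numNoHAtoms"]` lookup in `_evaluate_replica`; the counts dictionary is a hypothetical input computed elsewhere:

```python
import h5py

def annotate_heavy_atom_counts(source_path: str, counts: dict) -> None:
    """Attach a 'numNoHAtoms' attribute to each per-domain group.

    counts maps pdb_id -> number of non-hydrogen atoms (hypothetical input).
    """
    with h5py.File(source_path, "r+") as f:
        for pdb_id, group in f.items():
            if pdb_id in counts:
                group.attrs["numNoHAtoms"] = counts[pdb_id]
```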