diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 62868fc99..4ea072672 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -33,7 +33,10 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, + source_file="mdcath_source.h5", + file_basename="mdcath_dataset", numAtoms=5000, + numNoHAtoms=None, numResidues=1000, temperatures=["348"], skip_frames=1, @@ -52,8 +55,13 @@ def __init__( Root directory where the dataset should be stored. Data will be downloaded to 'root/'. numAtoms: int Max number of atoms in the protein structure. + source_file: str + Name of the source file with the information about the protein structures. Default is "mdcath_source.h5". + file_basename: str + Base name of the hdf5 files. Default is "mdcath_dataset". numNoHAtoms: int - Max number of non-hydrogen atoms in the protein structure. + Max number of non-hydrogen atoms in the protein structure, not available for original mdcath dataset. Default is None. + Be sure to have the attribute 'numNoHAtoms' in the source file. numResidues: int Max number of residues in the protein structure. temperatures: list @@ -61,7 +69,7 @@ def __init__( skip_frames: int Number of frames to skip in the trajectory. Default is 1. pdb_list: list or str - List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. + List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file' will be downloaded. The filters will be applied to the PDB IDs in this list in any case. Default is None. min_gyration_radius: float Minimum gyration radius (in nm) of the protein structure. Default is None. @@ -76,7 +84,9 @@ def __init__( """ self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/" - self.source_file = "mdcath_source.h5" + self.source_file = source_file + self.file_basename = file_basename + self.numNoHAtoms = numNoHAtoms self.root = root os.makedirs(root, exist_ok=True) self.numAtoms = numAtoms @@ -103,33 +113,35 @@ def __init__( @property def raw_file_names(self): - return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.processed.keys()] + return [f"{self.file_basename}_{pdb_id}.h5" for pdb_id in self.processed.keys()] @property def raw_dir(self): # Override the raw_dir property to return the root directory - # The files will be downloaded to the root directory + # The files will be downloaded to the root directory, compatible only with original mdcath dataset return self.root def _ensure_source_file(self): """Ensure the source file is downloaded before processing.""" source_path = os.path.join(self.root, self.source_file) if not os.path.exists(source_path): + assert self.source_file == "mdcath_source.h5", "Only 'mdcath_source.h5' is supported as source file for download." logger.info(f"Downloading source file {self.source_file}") urllib.request.urlretrieve(opj(self.url, self.source_file), source_path) def download(self): for pdb_id in self.processed.keys(): - file_name = f"mdcath_dataset_{pdb_id}.h5" + file_name = f"{self.file_basename}_{pdb_id}.h5" file_path = opj(self.raw_dir, file_name) if not os.path.exists(file_path): + assert self.file_basename == "mdcath_dataset", "Only 'mdcath_dataset' is supported as file_basename for download." # Download the file if it does not exist urllib.request.urlretrieve(opj(self.url, 'data', file_name), file_path) def calculate_dataset_size(self): total_size_bytes = 0 for pdb_id in self.processed.keys(): - file_name = f"mdcath_dataset_{pdb_id}.h5" + file_name = f"{self.file_basename}_{pdb_id}.h5" total_size_bytes += os.path.getsize(opj(self.root, file_name)) total_size_mb = round(total_size_bytes / (1024 * 1024), 4) return total_size_mb @@ -161,7 +173,8 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group): self.numFrames is not None and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames, self.min_gyration_radius is not None and pdb_group[temp][replica].attrs["min_gyration_radius"] < self.min_gyration_radius, self.max_gyration_radius is not None and pdb_group[temp][replica].attrs["max_gyration_radius"] > self.max_gyration_radius, - self._evaluate_structure(pdb_group, temp, replica) + self._evaluate_structure(pdb_group, temp, replica), + self.numNoHAtoms is not None and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms, ] if any(conditions): return @@ -180,7 +193,7 @@ def len(self): return self.num_conformers def _setup_idx(self): - files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.processed.keys()] + files = [opj(self.root, f"{self.file_basename}_{pdb_id}.h5") for pdb_id in self.processed.keys()] self.idx = [] for i, (pdb, group_info) in enumerate(self.processed.items()): for temp, replica, num_frames in group_info: