CUSTOM MDCATH #341

Merged (3 commits, Jan 28, 2025)
torchmdnet/datasets/mdcath.py: 31 changes (22 additions, 9 deletions)
@@ -33,7 +33,10 @@ def __init__(
         transform=None,
         pre_transform=None,
         pre_filter=None,
+        source_file="mdcath_source.h5",
+        file_basename="mdcath_dataset",
         numAtoms=5000,
+        numNoHAtoms=None,
         numResidues=1000,
         temperatures=["348"],
         skip_frames=1,
@@ -52,16 +55,21 @@ def __init__(
             Root directory where the dataset should be stored. Data will be downloaded to 'root/'.
         numAtoms: int
             Max number of atoms in the protein structure.
+        source_file: str
+            Name of the source file with the information about the protein structures. Default is "mdcath_source.h5".
+        file_basename: str
+            Base name of the HDF5 files. Default is "mdcath_dataset".
         numNoHAtoms: int
-            Max number of non-hydrogen atoms in the protein structure.
+            Max number of non-hydrogen atoms in the protein structure; not available for the original mdCATH dataset. Default is None.
+            Be sure to have the attribute 'numNoHAtoms' in the source file.
         numResidues: int
             Max number of residues in the protein structure.
         temperatures: list
             List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450']
         skip_frames: int
             Number of frames to skip in the trajectory. Default is 1.
         pdb_list: list or str
-            List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded.
+            List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file' will be downloaded.
             The filters will be applied to the PDB IDs in this list in any case. Default is None.
         min_gyration_radius: float
             Minimum gyration radius (in nm) of the protein structure. Default is None.
@@ -76,7 +84,9 @@ def __init__(
         """

         self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/"
-        self.source_file = "mdcath_source.h5"
+        self.source_file = source_file
+        self.file_basename = file_basename
+        self.numNoHAtoms = numNoHAtoms
         self.root = root
         os.makedirs(root, exist_ok=True)
         self.numAtoms = numAtoms
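
As a usage note, here is a minimal sketch of how the two setups might be constructed. The custom names below are hypothetical, and, as the asserts in download() further down show, only the original mdCATH files can be fetched automatically, so custom files must already be on disk:

from torchmdnet.datasets import MDCATH

# Original dataset: defaults unchanged, files are downloaded from the Hugging Face repo.
dataset = MDCATH(root="data/mdcath", temperatures=["320", "348"])

# Custom dataset: locally generated files named <file_basename>_<pdb_id>.h5.
custom = MDCATH(
    root="data/my_mdcath",          # hypothetical; must already contain the .h5 files
    source_file="my_source.h5",     # hypothetical custom source file
    file_basename="my_dataset",     # files are looked up as my_dataset_<pdb_id>.h5
    numNoHAtoms=3000,               # requires a 'numNoHAtoms' attribute in the source file
)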
@@ -103,33 +113,35 @@ def __init__(

     @property
     def raw_file_names(self):
-        return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.processed.keys()]
+        return [f"{self.file_basename}_{pdb_id}.h5" for pdb_id in self.processed.keys()]

     @property
     def raw_dir(self):
         # Override the raw_dir property to return the root directory
-        # The files will be downloaded to the root directory
+        # The files will be downloaded to the root directory; this is only compatible with the original mdCATH dataset
         return self.root

     def _ensure_source_file(self):
         """Ensure the source file is downloaded before processing."""
         source_path = os.path.join(self.root, self.source_file)
         if not os.path.exists(source_path):
+            assert self.source_file == "mdcath_source.h5", "Only 'mdcath_source.h5' is supported as source file for download."
             logger.info(f"Downloading source file {self.source_file}")
             urllib.request.urlretrieve(opj(self.url, self.source_file), source_path)

     def download(self):
         for pdb_id in self.processed.keys():
-            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            file_name = f"{self.file_basename}_{pdb_id}.h5"
             file_path = opj(self.raw_dir, file_name)
             if not os.path.exists(file_path):
+                assert self.file_basename == "mdcath_dataset", "Only 'mdcath_dataset' is supported as file_basename for download."
                 # Download the file if it does not exist
                 urllib.request.urlretrieve(opj(self.url, 'data', file_name), file_path)

     def calculate_dataset_size(self):
         total_size_bytes = 0
         for pdb_id in self.processed.keys():
-            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            file_name = f"{self.file_basename}_{pdb_id}.h5"
             total_size_bytes += os.path.getsize(opj(self.root, file_name))
         total_size_mb = round(total_size_bytes / (1024 * 1024), 4)
         return total_size_mb
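
These asserts make download() a guard rather than a general fetcher: any non-default source_file or file_basename must already exist under root. A small sketch of pre-staging a custom dataset so that download() finds every file and never reaches the assert (all paths and IDs hypothetical):

import os
import shutil

root = "data/my_mdcath"
os.makedirs(root, exist_ok=True)

# Copy previously generated files into root before constructing the dataset;
# download() only checks for existence and skips files that are present.
shutil.copy("/scratch/gen/my_source.h5", os.path.join(root, "my_source.h5"))
for pdb_id in ["1balA00", "1a2bB00"]:  # hypothetical CATH domain IDs
    shutil.copy(f"/scratch/gen/my_dataset_{pdb_id}.h5",
                os.path.join(root, f"my_dataset_{pdb_id}.h5"))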
@@ -161,7 +173,8 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group):
             self.numFrames is not None and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames,
             self.min_gyration_radius is not None and pdb_group[temp][replica].attrs["min_gyration_radius"] < self.min_gyration_radius,
             self.max_gyration_radius is not None and pdb_group[temp][replica].attrs["max_gyration_radius"] > self.max_gyration_radius,
-            self._evaluate_structure(pdb_group, temp, replica)
+            self._evaluate_structure(pdb_group, temp, replica),
+            self.numNoHAtoms is not None and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms,
         ]
         if any(conditions):
             return
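
The new condition reads numNoHAtoms from the attributes of each PDB group in the source file, which is why the docstring warns that the attribute must be present. A sketch of annotating a custom source file with h5py, assuming it stores one group per PDB ID and a per-atom element dataset (adapt the lookup to your actual layout):

import h5py

with h5py.File("data/my_mdcath/my_source.h5", "r+") as f:  # hypothetical path
    for pdb_id, group in f.items():
        # Count non-hydrogen atoms from an assumed per-atom element dataset.
        elements = group["element"][:].astype(str)
        group.attrs["numNoHAtoms"] = int((elements != "H").sum())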
@@ -180,7 +193,7 @@ def len(self):
         return self.num_conformers

     def _setup_idx(self):
-        files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.processed.keys()]
+        files = [opj(self.root, f"{self.file_basename}_{pdb_id}.h5") for pdb_id in self.processed.keys()]
         self.idx = []
         for i, (pdb, group_info) in enumerate(self.processed.items()):
             for temp, replica, num_frames in group_info: