@@ -33,7 +33,10 @@ def __init__(
33
33
transform = None ,
34
34
pre_transform = None ,
35
35
pre_filter = None ,
36
+ source_file = "mdcath_source.h5" ,
37
+ file_basename = "mdcath_dataset" ,
36
38
numAtoms = 5000 ,
39
+ numNoHAtoms = None ,
37
40
numResidues = 1000 ,
38
41
temperatures = ["348" ],
39
42
skip_frames = 1 ,
@@ -52,16 +55,21 @@ def __init__(
52
55
Root directory where the dataset should be stored. Data will be downloaded to 'root/'.
53
56
numAtoms: int
54
57
Max number of atoms in the protein structure.
58
+ source_file: str
59
+ Name of the source file with the information about the protein structures. Default is "mdcath_source.h5".
60
+ file_basename: str
61
+ Base name of the hdf5 files. Default is "mdcath_dataset".
55
62
numNoHAtoms: int
56
- Max number of non-hydrogen atoms in the protein structure.
63
+ Max number of non-hydrogen atoms in the protein structure, not available for original mdcath dataset. Default is None.
64
+ Be sure to have the attribute 'numNoHAtoms' in the source file.
57
65
numResidues: int
58
66
Max number of residues in the protein structure.
59
67
temperatures: list
60
68
List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450']
61
69
skip_frames: int
62
70
Number of frames to skip in the trajectory. Default is 1.
63
71
pdb_list: list or str
64
- List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5 ' will be downloaded.
72
+ List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file ' will be downloaded.
65
73
The filters will be applied to the PDB IDs in this list in any case. Default is None.
66
74
min_gyration_radius: float
67
75
Minimum gyration radius (in nm) of the protein structure. Default is None.
@@ -76,7 +84,9 @@ def __init__(
76
84
"""
77
85
78
86
self .url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/"
79
- self .source_file = "mdcath_source.h5"
87
+ self .source_file = source_file
88
+ self .file_basename = file_basename
89
+ self .numNoHAtoms = numNoHAtoms
80
90
self .root = root
81
91
os .makedirs (root , exist_ok = True )
82
92
self .numAtoms = numAtoms
@@ -103,33 +113,35 @@ def __init__(
103
113
104
114
@property
105
115
def raw_file_names (self ):
106
- return [f"mdcath_dataset_ { pdb_id } .h5" for pdb_id in self .processed .keys ()]
116
+ return [f"{ self . file_basename } _ { pdb_id } .h5" for pdb_id in self .processed .keys ()]
107
117
108
118
@property
109
119
def raw_dir (self ):
110
120
# Override the raw_dir property to return the root directory
111
- # The files will be downloaded to the root directory
121
+ # The files will be downloaded to the root directory, compatible only with original mdcath dataset
112
122
return self .root
113
123
114
124
def _ensure_source_file (self ):
115
125
"""Ensure the source file is downloaded before processing."""
116
126
source_path = os .path .join (self .root , self .source_file )
117
127
if not os .path .exists (source_path ):
128
+ assert self .source_file == "mdcath_source.h5" , "Only 'mdcath_source.h5' is supported as source file for download."
118
129
logger .info (f"Downloading source file { self .source_file } " )
119
130
urllib .request .urlretrieve (opj (self .url , self .source_file ), source_path )
120
131
121
132
def download (self ):
122
133
for pdb_id in self .processed .keys ():
123
- file_name = f"mdcath_dataset_ { pdb_id } .h5"
134
+ file_name = f"{ self . file_basename } _ { pdb_id } .h5"
124
135
file_path = opj (self .raw_dir , file_name )
125
136
if not os .path .exists (file_path ):
137
+ assert self .file_basename == "mdcath_dataset" , "Only 'mdcath_dataset' is supported as file_basename for download."
126
138
# Download the file if it does not exist
127
139
urllib .request .urlretrieve (opj (self .url , 'data' , file_name ), file_path )
128
140
129
141
def calculate_dataset_size (self ):
130
142
total_size_bytes = 0
131
143
for pdb_id in self .processed .keys ():
132
- file_name = f"mdcath_dataset_ { pdb_id } .h5"
144
+ file_name = f"{ self . file_basename } _ { pdb_id } .h5"
133
145
total_size_bytes += os .path .getsize (opj (self .root , file_name ))
134
146
total_size_mb = round (total_size_bytes / (1024 * 1024 ), 4 )
135
147
return total_size_mb
@@ -161,7 +173,8 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group):
161
173
self .numFrames is not None and pdb_group [temp ][replica ].attrs ["numFrames" ] < self .numFrames ,
162
174
self .min_gyration_radius is not None and pdb_group [temp ][replica ].attrs ["min_gyration_radius" ] < self .min_gyration_radius ,
163
175
self .max_gyration_radius is not None and pdb_group [temp ][replica ].attrs ["max_gyration_radius" ] > self .max_gyration_radius ,
164
- self ._evaluate_structure (pdb_group , temp , replica )
176
+ self ._evaluate_structure (pdb_group , temp , replica ),
177
+ self .numNoHAtoms is not None and pdb_group .attrs ["numNoHAtoms" ] > self .numNoHAtoms ,
165
178
]
166
179
if any (conditions ):
167
180
return
@@ -180,7 +193,7 @@ def len(self):
180
193
return self .num_conformers
181
194
182
195
def _setup_idx (self ):
183
- files = [opj (self .root , f"mdcath_dataset_ { pdb_id } .h5" ) for pdb_id in self .processed .keys ()]
196
+ files = [opj (self .root , f"{ self . file_basename } _ { pdb_id } .h5" ) for pdb_id in self .processed .keys ()]
184
197
self .idx = []
185
198
for i , (pdb , group_info ) in enumerate (self .processed .items ()):
186
199
for temp , replica , num_frames in group_info :
0 commit comments