|
1 | 1 | # Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
|
2 | 2 | # Distributed under the MIT License.
|
3 | 3 | # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
|
4 |
| - |
| 4 | +from glob import glob |
| 5 | +import os |
5 | 6 | import re
|
| 7 | +import tempfile |
6 | 8 | from typing import Optional, List, Tuple, Dict
|
7 | 9 | import torch
|
8 | 10 | from torch.autograd import grad
|
|
13 | 15 | from torchmdnet import priors
|
14 | 16 | from lightning_utilities.core.rank_zero import rank_zero_warn
|
15 | 17 | import warnings
|
| 18 | +import zipfile |
16 | 19 |
|
17 | 20 |
|
18 | 21 | def create_model(args, prior_model=None, mean=None, std=None):
|
@@ -139,26 +142,72 @@ def create_model(args, prior_model=None, mean=None, std=None):
|
139 | 142 | return model
|
140 | 143 |
|
141 | 144 |
|
| 145 | +def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs): |
| 146 | + """Load an ensemble of models from a list of checkpoint files or a zip file. |
| 147 | +
|
| 148 | + Args: |
| 149 | + filepath (str or list): Can be any of the following: |
| 150 | +
|
| 151 | + - Path to a zip file containing multiple checkpoint files. |
| 152 | + - List of paths to checkpoint files. |
| 153 | +
|
| 154 | + args (dict, optional): Arguments for the model. Defaults to None. |
| 155 | + device (str, optional): Device on which the model should be loaded. Defaults to "cpu". |
| 156 | + return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. |
| 157 | + **kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`. |
| 158 | +
|
| 159 | + Returns: |
| 160 | + nn.Module: An instance of :py:mod:`Ensemble`. |
| 161 | + """ |
| 162 | + if isinstance(filepath, (list, tuple)): |
| 163 | + assert all(isinstance(f, str) for f in filepath), "Invalid filepath list." |
| 164 | + model_list = [ |
| 165 | + load_model(f, args=args, device=device, **kwargs) for f in filepath |
| 166 | + ] |
| 167 | + elif filepath.endswith(".zip"): |
| 168 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 169 | + with zipfile.ZipFile(filepath, "r") as z: |
| 170 | + z.extractall(tmpdir) |
| 171 | + ckpt_list = glob(os.path.join(tmpdir, "*.ckpt")) |
| 172 | + assert len(ckpt_list) > 0, "No checkpoint files found in zip file." |
| 173 | + model_list = [ |
| 174 | + load_model(f, args=args, device=device, **kwargs) for f in ckpt_list |
| 175 | + ] |
| 176 | + else: |
| 177 | + raise ValueError( |
| 178 | + "Invalid filepath. Must be a list of paths or a path to a zip file." |
| 179 | + ) |
| 180 | + return Ensemble( |
| 181 | + model_list, |
| 182 | + return_std=return_std, |
| 183 | + ) |
| 184 | + |
| 185 | + |
142 | 186 | def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
|
143 | 187 | """Load a model from a checkpoint file.
|
144 | 188 |
|
145 |
| - If a list of paths is given, an :py:mod:`Ensemble` model is returned. |
| 189 | + If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned. |
146 | 190 | Args:
|
147 |
| - filepath (str or list): Path to the checkpoint file or a list of paths. |
| 191 | + filepath (str or list): Can be any of the following: |
| 192 | +
|
| 193 | + - Path to a checkpoint file. In this case, a :py:mod:`TorchMD_Net` model is returned. |
| 194 | + - Path to a zip file containing multiple checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. |
| 195 | + - List of paths to checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. |
| 196 | +
|
148 | 197 | args (dict, optional): Arguments for the model. Defaults to None.
|
149 | 198 | device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
|
150 | 199 | return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
|
151 | 200 | **kwargs: Extra keyword arguments for the model.
|
152 | 201 |
|
153 | 202 | Returns:
|
154 |
| - nn.Module: An instance of the TorchMD_Net model. |
| 203 | + nn.Module: An instance of the TorchMD_Net model or an Ensemble model. |
155 | 204 | """
|
156 |
| - if isinstance(filepath, (list, tuple)): |
157 |
| - return Ensemble( |
158 |
| - [load_model(f, args=args, device=device, **kwargs) for f in filepath], |
159 |
| - return_std=return_std, |
| 205 | + isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip") |
| 206 | + if isEnsemble: |
| 207 | + return load_ensemble( |
| 208 | + filepath, args=args, device=device, return_std=return_std, **kwargs |
160 | 209 | )
|
161 |
| - |
| 210 | + assert isinstance(filepath, str) |
162 | 211 | ckpt = torch.load(filepath, map_location="cpu")
|
163 | 212 | if args is None:
|
164 | 213 | args = ckpt["hyper_parameters"]
|
|
0 commit comments