Skip to content

Commit 8b47246

Browse files
authored
Merge pull request #311 from torchmd/ensemble_zips
Support for ensemble model zip files
2 parents 8a1be71 + 1a7b274 commit 8b47246

File tree

2 files changed

+98
-29
lines changed

2 files changed

+98
-29
lines changed

tests/test_model.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -116,26 +116,28 @@ def test_cuda_graph_compatible(model_name):
116116
if not torch.cuda.is_available():
117117
pytest.skip("CUDA not available")
118118
z, pos, batch = create_example_batch()
119-
args = {"model": model_name,
120-
"embedding_dimension": 128,
121-
"num_layers": 2,
122-
"num_rbf": 32,
123-
"rbf_type": "expnorm",
124-
"trainable_rbf": False,
125-
"activation": "silu",
126-
"cutoff_lower": 0.0,
127-
"cutoff_upper": 5.0,
128-
"max_z": 100,
129-
"max_num_neighbors": 128,
130-
"equivariance_invariance_group": "O(3)",
131-
"prior_model": None,
132-
"atom_filter": -1,
133-
"derivative": True,
134-
"check_errors": False,
135-
"static_shapes": True,
136-
"output_model": "Scalar",
137-
"reduce_op": "sum",
138-
"precision": 32 }
119+
args = {
120+
"model": model_name,
121+
"embedding_dimension": 128,
122+
"num_layers": 2,
123+
"num_rbf": 32,
124+
"rbf_type": "expnorm",
125+
"trainable_rbf": False,
126+
"activation": "silu",
127+
"cutoff_lower": 0.0,
128+
"cutoff_upper": 5.0,
129+
"max_z": 100,
130+
"max_num_neighbors": 128,
131+
"equivariance_invariance_group": "O(3)",
132+
"prior_model": None,
133+
"atom_filter": -1,
134+
"derivative": True,
135+
"check_errors": False,
136+
"static_shapes": True,
137+
"output_model": "Scalar",
138+
"reduce_op": "sum",
139+
"precision": 32,
140+
}
139141
model = create_model(args).to(device="cuda")
140142
model.eval()
141143
z = z.to("cuda")
@@ -260,3 +262,21 @@ def test_ensemble():
260262
assert neg_dy_std.shape == deriv.shape
261263
assert (y_std == 0).all()
262264
assert (neg_dy_std == 0).all()
265+
266+
import zipfile
267+
import tempfile
268+
269+
with tempfile.TemporaryDirectory() as tmpdir:
270+
ensemble_zip = join(tmpdir, "ensemble.zip")
271+
with zipfile.ZipFile(ensemble_zip, "w") as zipf:
272+
for i, ckpt in enumerate(ckpts):
273+
zipf.write(ckpt, f"model_{i}.ckpt")
274+
ensemble_model = load_model(ensemble_zip, return_std=True)
275+
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
276+
277+
torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
278+
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
279+
assert y_std.shape == pred.shape
280+
assert neg_dy_std.shape == deriv.shape
281+
assert (y_std == 0).all()
282+
assert (neg_dy_std == 0).all()

torchmdnet/models/model.py

+58-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
22
# Distributed under the MIT License.
33
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
4-
4+
from glob import glob
5+
import os
56
import re
7+
import tempfile
68
from typing import Optional, List, Tuple, Dict
79
import torch
810
from torch.autograd import grad
@@ -13,6 +15,7 @@
1315
from torchmdnet import priors
1416
from lightning_utilities.core.rank_zero import rank_zero_warn
1517
import warnings
18+
import zipfile
1619

1720

1821
def create_model(args, prior_model=None, mean=None, std=None):
@@ -139,26 +142,72 @@ def create_model(args, prior_model=None, mean=None, std=None):
139142
return model
140143

141144

145+
def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs):
    """Load an ensemble of models from a list of checkpoint files or a zip file.

    Args:
        filepath (str, os.PathLike or list): Can be any of the following:

            - Path to a zip file containing multiple checkpoint files.
            - List of paths to checkpoint files.

        args (dict, optional): Arguments for the model. Defaults to None.
        device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
        return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False.
        **kwargs: Extra keyword arguments for the model, will be passed to :py:func:`load_model`.

    Returns:
        nn.Module: An instance of :py:class:`Ensemble`.

    Raises:
        ValueError: If ``filepath`` is neither a list/tuple of string paths nor
            a path ending in ``.zip``, or if the zip file has no ``*.ckpt``
            file at its top level.
    """
    if isinstance(filepath, (list, tuple)):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not all(isinstance(f, str) for f in filepath):
            raise ValueError("Invalid filepath list.")
        model_list = [
            load_model(f, args=args, device=device, **kwargs) for f in filepath
        ]
    elif isinstance(filepath, (str, os.PathLike)) and os.fsdecode(filepath).endswith(
        ".zip"
    ):
        # The checkpoints only exist while the temporary directory is alive,
        # so every model has to be loaded before leaving this ``with`` block.
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(filepath, "r") as z:
                z.extractall(tmpdir)
            # glob() returns entries in arbitrary order; sort so the order of
            # the ensemble members is reproducible across runs and platforms.
            ckpt_list = sorted(glob(os.path.join(tmpdir, "*.ckpt")))
            if not ckpt_list:
                raise ValueError("No checkpoint files found in zip file.")
            model_list = [
                load_model(f, args=args, device=device, **kwargs) for f in ckpt_list
            ]
    else:
        raise ValueError(
            "Invalid filepath. Must be a list of paths or a path to a zip file."
        )
    return Ensemble(
        model_list,
        return_std=return_std,
    )
184+
185+
142186
def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
143187
"""Load a model from a checkpoint file.
144188
145-
If a list of paths is given, an :py:mod:`Ensemble` model is returned.
189+
If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned.
146190
Args:
147-
filepath (str or list): Path to the checkpoint file or a list of paths.
191+
filepath (str or list): Can be any of the following:
192+
193+
- Path to a checkpoint file. In this case, a :py:mod:`TorchMD_Net` model is returned.
194+
- Path to a zip file containing multiple checkpoint files. In this case, an :py:mod:`Ensemble` model is returned.
195+
- List of paths to checkpoint files. In this case, an :py:mod:`Ensemble` model is returned.
196+
148197
args (dict, optional): Arguments for the model. Defaults to None.
149198
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
150199
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
151200
**kwargs: Extra keyword arguments for the model.
152201
153202
Returns:
154-
nn.Module: An instance of the TorchMD_Net model.
203+
nn.Module: An instance of the TorchMD_Net model or an Ensemble model.
155204
"""
156-
if isinstance(filepath, (list, tuple)):
157-
return Ensemble(
158-
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
159-
return_std=return_std,
205+
isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip")
206+
if isEnsemble:
207+
return load_ensemble(
208+
filepath, args=args, device=device, return_std=return_std, **kwargs
160209
)
161-
210+
assert isinstance(filepath, str)
162211
ckpt = torch.load(filepath, map_location="cpu")
163212
if args is None:
164213
args = ckpt["hyper_parameters"]

0 commit comments

Comments
 (0)