Skip to content

Commit 0b94d88

Browse files
authored
Merge pull request #348 from torchmd/fix_torch_warnings
Fixed pytorch deprecations warnings
2 parents f6c0c16 + 535c5ae commit 0b94d88

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

tests/test_optimize.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,23 @@
44

55
import pytest
66
from pytest import mark
7-
import torch as pt
8-
from torchmdnet.models.model import create_model
9-
from torchmdnet.optimize import optimize
10-
from torchmdnet.models.utils import dtype_mapping
117

8+
try:
9+
import NNPOps
10+
11+
nnpops_available = True
12+
except ImportError:
13+
nnpops_available = False
14+
15+
16+
@pytest.mark.skipif(not nnpops_available, reason="NNPOps not available")
1217
@mark.parametrize("device", ["cpu", "cuda"])
1318
@mark.parametrize("num_atoms", [10, 100])
1419
def test_gn(device, num_atoms):
20+
import torch as pt
21+
from torchmdnet.models.model import create_model
22+
from torchmdnet.optimize import optimize
23+
from torchmdnet.models.utils import dtype_mapping
1524

1625
if not pt.cuda.is_available() and device == "cuda":
1726
pytest.skip("No GPU")

torchmdnet/extensions/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,12 @@ def get_neighbor_pairs_fwd_meta(
145145

146146

147147
if torch.__version__ >= "2.2.0":
148-
from torch.library import impl_abstract
148+
from torch.library import register_fake
149149

150-
impl_abstract(
150+
register_fake(
151151
"torchmdnet_extensions::get_neighbor_pairs_bkwd", get_neighbor_pairs_bkwd_meta
152152
)
153-
impl_abstract(
153+
register_fake(
154154
"torchmdnet_extensions::get_neighbor_pairs_fwd", get_neighbor_pairs_fwd_meta
155155
)
156156
elif torch.__version__ < "2.2.0" and torch.__version__ >= "2.0.0":

torchmdnet/extensions/neighbors/neighbors_cpu.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
using std::tuple;
99
using torch::arange;
1010
using torch::div;
11-
using torch::frobenius_norm;
11+
using torch::linalg_vector_norm;
1212
using torch::full;
1313
using torch::hstack;
1414
using torch::index_select;
@@ -99,7 +99,7 @@ forward_impl(const std::string& strategy, const Tensor& positions, const Tensor&
9999
deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
100100
scale1 * box_vectors.index({pair_batch, 0, 0}));
101101
}
102-
distances = frobenius_norm(deltas, 1);
102+
distances = linalg_vector_norm(deltas, 2, 1);
103103
mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
104104
neighbors = neighbors.index({Slice(), mask});
105105
deltas = deltas.index({mask, Slice()});

torchmdnet/models/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
209209
filepath, args=args, device=device, return_std=return_std, **kwargs
210210
)
211211
assert isinstance(filepath, str)
212-
ckpt = torch.load(filepath, map_location="cpu")
212+
ckpt = torch.load(filepath, map_location="cpu", weights_only=False)
213213
if args is None:
214214
args = ckpt["hyper_parameters"]
215215

0 commit comments

Comments
 (0)