
Commit 72d6e8e

torch.compile neighbors without graph breaks (#305)
* Add test for full graph neighbor torch.compile
* blacken
* Add missing sqrt default implementation
* Update compile test
* Import torch.Tensor
* Expose only the extensions
* Make CUDA backwards also the operation used in CPU
  Make fwd and bkwd independent operators
  Add meta registrations for forwards and backwards
  Define meta registrations only in pytorch>=2.2.0
* Fix CPU extentension potentially allocating with a negative size
* Fix incorrect type
1 parent 6694816 commit 72d6e8e

7 files changed: +374 -190 lines changed
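
The heart of this change is in torchmdnet/extensions/__init__.py below: get_neighbor_pairs is split into independently registered fwd/bkwd operators, and each gets a Python-side meta ("abstract") implementation so torch.compile can trace the neighbor list as a single graph on PyTorch >= 2.2. As a quick orientation before the per-file diffs, here is a minimal sketch of the torch.library pattern the commit relies on; the operator my_ext::square is a hypothetical stand-in, not part of this repository.

    import torch
    from torch import Tensor
    # impl_abstract requires PyTorch >= 2.2
    from torch.library import define, impl, impl_abstract

    # Toy stand-in for a kernel that would normally be registered from C++ via TORCH_LIBRARY.
    define("my_ext::square", "(Tensor x) -> Tensor")

    @impl("my_ext::square", "cpu")
    def square(x: Tensor) -> Tensor:
        return x * x

    # The meta ("abstract") implementation never touches data: it only reports the
    # output shape/dtype/device, which lets torch.compile trace the op with fake
    # tensors instead of hitting a graph break at the custom kernel.
    def square_meta(x: Tensor) -> Tensor:
        return torch.empty_like(x)

    impl_abstract("my_ext::square", square_meta)

    # With the abstract impl in place, fullgraph capture becomes possible.
    compiled = torch.compile(lambda x: torch.ops.my_ext.square(x), fullgraph=True)
    print(compiled(torch.randn(4)))

The diffs below do exactly this for torchmdnet_extensions::get_neighbor_pairs_fwd and _bkwd, plus a C++-side impl_abstract_pystub call that tells the dispatcher where the Python meta functions live.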

tests/test_neighbors.py

+79 -14
@@ -9,6 +9,7 @@
 import numpy as np
 from torchmdnet.models.utils import OptimizedDistance
 
+
 def sort_neighbors(neighbors, deltas, distances):
     i_sorted = np.lexsort(neighbors)
     return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted]
@@ -69,7 +70,10 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto
     return ref_neighbors, ref_distance_vecs, ref_distances
 
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128])
 @pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9])
 @pytest.mark.parametrize("loop", [True, False])
@@ -92,7 +96,7 @@ def test_neighbors(
     ).to(device)
     cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
     lbox = 10.0
-    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
+    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
@@ -141,7 +145,11 @@ def test_neighbors(
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("loop", [True, False])
 @pytest.mark.parametrize("include_transpose", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@@ -249,10 +257,14 @@ def test_neighbor_grads(
     else:
         assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5)
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("loop", [True, False])
 @pytest.mark.parametrize("include_transpose", [True, False])
-@pytest.mark.parametrize("num_atoms", [1,2,10])
+@pytest.mark.parametrize("num_atoms", [1, 2, 10])
 @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
 def test_neighbor_autograds(
     device, strategy, loop, include_transpose, num_atoms, box_type
@@ -293,8 +305,12 @@ def test_neighbor_autograds(
     neighbors, distances, deltas = nl(positions, batch)
     # Lambda that returns only the distances and deltas
     lambda_dist = lambda x, y: nl(x, y)[1:]
-    torch.autograd.gradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
-    torch.autograd.gradgradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
+    torch.autograd.gradcheck(
+        lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4
+    )
+    torch.autograd.gradgradcheck(
+        lambda_dist, (positions, batch), eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
+    )
 
 
 @pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
@@ -353,7 +369,11 @@ def test_large_size(strategy, n_batches):
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("n_batches", [1, 128])
 @pytest.mark.parametrize("cutoff", [1.0])
 @pytest.mark.parametrize("loop", [True, False])
@@ -504,6 +524,7 @@ def test_cuda_graph_compatible_forward(
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
 
+
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
 @pytest.mark.parametrize("n_batches", [1, 128])
@@ -578,12 +599,12 @@ def test_cuda_graph_compatible_backward(
     torch.cuda.synchronize()
 
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")])
+@pytest.mark.parametrize(
+    ("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")]
+)
 @pytest.mark.parametrize("n_batches", [1, 128])
 @pytest.mark.parametrize("use_forward", [True, False])
-def test_per_batch_box(
-    device, strategy, n_batches, use_forward
-):
+def test_per_batch_box(device, strategy, n_batches, use_forward):
     dtype = torch.float32
     cutoff = 1.0
     include_transpose = True
@@ -599,7 +620,7 @@ def test_per_batch_box(
     ).to(device)
     cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
     lbox = 10.0
-    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
+    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
@@ -625,7 +646,9 @@ def test_per_batch_box(
         include_transpose=include_transpose,
    )
     batch.to(device)
-    neighbors, distances, distance_vecs = nl(pos, batch, box=box if use_forward else None)
+    neighbors, distances, distance_vecs = nl(
+        pos, batch, box=box if use_forward else None
+    )
     neighbors = neighbors.cpu().detach().numpy()
     distance_vecs = distance_vecs.cpu().detach().numpy()
     distances = distances.cpu().detach().numpy()
@@ -639,3 +662,45 @@ def test_per_batch_box(
     assert np.allclose(neighbors, ref_neighbors)
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("dtype", [torch.float64])
+@pytest.mark.parametrize("loop", [True, False])
+@pytest.mark.parametrize("include_transpose", [True, False])
+def test_torch_compile(device, dtype, loop, include_transpose):
+    if torch.__version__ < "2.0.0":
+        pytest.skip("Not available in this version")
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    np.random.seed(123456)
+    example_pos = 10 * torch.rand(50, 3, requires_grad=True, dtype=dtype, device=device)
+    model = OptimizedDistance(
+        cutoff_lower=0.1,  # I do this to avoid non-finite-differentiable points
+        cutoff_upper=10,
+        return_vecs=True,
+        loop=loop,
+        max_num_pairs=-example_pos.shape[0],
+        include_transpose=include_transpose,
+        resize_to_fit=False,
+        check_errors=False,
+    ).to(device)
+    for _ in range(50):
+        model(example_pos)
+    example_pos = example_pos.detach().requires_grad_(True)
+    edge_index, edge_vec, edge_distance = model(example_pos)
+    edge_vec.sum().backward()
+    example_pos.grad.zero_()
+    fullgraph = torch.__version__ >= "2.2.0"
+    model = torch.compile(
+        model,
+        fullgraph=fullgraph,
+        backend="inductor",
+        mode="reduce-overhead",
+    )
+    edge_index, edge_vec, edge_distance = model(example_pos)
+    edge_vec.sum().backward()
+    lambda_dist = lambda x: model(x)[1:]
+    torch.autograd.gradcheck(
+        lambda_dist, example_pos, eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
+    )

torchmdnet/extensions/__init__.py

+60 -18
@@ -5,9 +5,12 @@
 # Place here any short extensions to torch that you want to use in your code.
 # The extensions present in extensions.cpp will be automatically compiled in setup.py and loaded here.
 # The extensions will be available under torch.ops.torchmdnet_extensions, but you can add wrappers here to make them more convenient to use.
+# Place here too any meta registrations for your extensions if required.
+
 import os.path as osp
 import torch
 import importlib.machinery
+from torch import Tensor
 from typing import Tuple
 
 
@@ -29,6 +32,8 @@ def _load_library(library):
 
 _load_library("torchmdnet_extensions")
 
+__all__ = ["is_current_stream_capturing", "get_neighbor_pairs_kernel"]
+
 
 def is_current_stream_capturing():
     """Returns True if the current CUDA stream is capturing.
@@ -45,30 +50,29 @@ def is_current_stream_capturing():
 
 def get_neighbor_pairs_kernel(
     strategy: str,
-    positions: torch.Tensor,
-    batch: torch.Tensor,
-    box_vectors: torch.Tensor,
+    positions: Tensor,
+    batch: Tensor,
+    box_vectors: Tensor,
     use_periodic: bool,
     cutoff_lower: float,
     cutoff_upper: float,
     max_num_pairs: int,
     loop: bool,
     include_transpose: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Computes the neighbor pairs for a given set of atomic positions.
-
     The list is generated as a list of pairs (i,j) without any enforced ordering.
     The list is padded with -1 to the maximum number of pairs.
 
     Parameters
     ----------
     strategy : str
        Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
-    positions : torch.Tensor
+    positions : Tensor
        A tensor with shape (N, 3) representing the atomic positions.
-    batch : torch.Tensor
+    batch : Tensor
       A tensor with shape (N,). Specifies the batch for each atom.
-    box_vectors : torch.Tensor
+    box_vectors : Tensor
       The vectors defining the periodic box with shape `(3, 3)` or `(max(batch)+1, 3, 3)` if a different box is used for each sample.
     use_periodic : bool
       Whether to apply periodic boundary conditions.
@@ -85,18 +89,14 @@ def get_neighbor_pairs_kernel(
 
     Returns
     -------
-    neighbors : torch.Tensor
+    neighbors : Tensor
       List of neighbors for each atom. Shape (2, max_num_pairs).
-    distances : torch.Tensor
+    distances : Tensor
      List of distances for each atom. Shape (max_num_pairs,).
-    distance_vecs : torch.Tensor
+    distance_vecs : Tensor
      List of distance vectors for each atom. Shape (max_num_pairs, 3).
-    num_pairs : torch.Tensor
+    num_pairs : Tensor
      The number of pairs found.
-
-    Notes
-    -----
-    This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
     """
     return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
         strategy,
@@ -112,7 +112,49 @@
     )
 
 
-# For some unknown reason torch.compile is not able to compile this function
-if int(torch.__version__.split(".")[0]) >= 2:
+def get_neighbor_pairs_bkwd_meta(
+    grad_edge_vec: Tensor,
+    grad_edge_weight: Tensor,
+    edge_index: Tensor,
+    edge_vec: Tensor,
+    edge_weight: Tensor,
+    num_atoms: int,
+):
+    return torch.zeros((num_atoms, 3), dtype=edge_vec.dtype, device=edge_vec.device)
+
+
+def get_neighbor_pairs_fwd_meta(
+    strategy: str,
+    positions: Tensor,
+    batch: Tensor,
+    box_vectors: Tensor,
+    use_periodic: bool,
+    cutoff_lower: float,
+    cutoff_upper: float,
+    max_num_pairs: int,
+    loop: bool,
+    include_transpose: bool,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel."""
+    size = max_num_pairs
+    edge_index = torch.empty((2, size), dtype=torch.int, device=positions.device)
+    edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device)
+    edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device)
+    num_pairs = torch.empty((1,), dtype=torch.int, device=positions.device)
+    return edge_index, edge_vec, edge_distance, num_pairs
+
+
+if torch.__version__ >= "2.2.0":
+    from torch.library import impl_abstract
+
+    impl_abstract(
+        "torchmdnet_extensions::get_neighbor_pairs_bkwd", get_neighbor_pairs_bkwd_meta
+    )
+    impl_abstract(
+        "torchmdnet_extensions::get_neighbor_pairs_fwd", get_neighbor_pairs_fwd_meta
+    )
+elif torch.__version__ < "2.2.0" and torch.__version__ >= "2.0.0":
+    # torch.compile is not able to compile this function in old versions
     import torch._dynamo as dynamo
+
     dynamo.disallow_in_graph(torch.ops.torchmdnet_extensions.get_neighbor_pairs)
torchmdnet/extensions/extensions.cpp

+20 -1
@@ -33,8 +33,27 @@ bool is_current_stream_capturing() {
 #endif
 }
 
+#define TORCH_VERSION_CODE(MAJOR, MINOR, PATCH) ((MAJOR)*10000 + (MINOR)*100 + (PATCH))
+#define TORCH_VERSION_COMPARE_LE(MAJOR, MINOR, PATCH) \
+    (TORCH_VERSION_CODE(TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR, TORCH_VERSION_PATCH) >= \
+     TORCH_VERSION_CODE(MAJOR, MINOR, PATCH))
 
 TORCH_LIBRARY(torchmdnet_extensions, m) {
     m.def("is_current_stream_capturing", is_current_stream_capturing);
-    m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
+#if TORCH_VERSION_COMPARE_LE(2, 2, 0)
+    //This line is required to signal to torch that the meta registration is implemented in python.
+    // Specifically, it will look for them in the torchmdnet.extensions module.
+    m.impl_abstract_pystub("torchmdnet.extensions");
+#endif
+    m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
+          "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
+          "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
+          "distance_vecs, Tensor num_pairs)");
+    //The individual fwd and bkwd functions must be exposed in order to register their meta implementations python side.
+    m.def("get_neighbor_pairs_fwd(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
+          "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
+          "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
+          "distance_vecs, Tensor num_pairs)");
+    m.def("get_neighbor_pairs_bkwd(Tensor grad_edge_vec, Tensor grad_edge_weight, Tensor edge_index, "
+          "Tensor edge_vec, Tensor edge_weight, int num_atoms) -> Tensor");
 }
torchmdnet/extensions/neighbors/common.cuh

+1 -1
@@ -34,7 +34,7 @@ inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
     return tensor.packed_accessor32<scalar_t, num_dims, torch::RestrictPtrTraits>();
 };
 
-template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){};
+template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){return ::sqrt(x);};
 template <> __device__ __forceinline__ float sqrt_(float x) {
     return ::sqrtf(x);
 };
