 import numpy as np
 from torchmdnet.models.utils import OptimizedDistance

+
 def sort_neighbors(neighbors, deltas, distances):
     i_sorted = np.lexsort(neighbors)
     return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted]
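An aside on the helper above: kernel and reference implementations may enumerate the same pairs in different orders, so `sort_neighbors` puts pair lists into a canonical column order before the element-wise `np.allclose` comparisons used throughout these tests. A minimal standalone illustration (toy data, not from the test file):

```python
import numpy as np

# Two pair lists containing the same (i, j) pairs in different column orders.
a = np.array([[0, 1], [1, 0]])
b = np.array([[1, 0], [0, 1]])

# np.lexsort treats each row as a sort key (last row is the primary key) and
# returns a column permutation; after sorting, equal pair sets match exactly.
a_sorted = a[:, np.lexsort(a)]
b_sorted = b[:, np.lexsort(b)]
assert (a_sorted == b_sorted).all()
```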
@@ -69,7 +70,10 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto
     return ref_neighbors, ref_distance_vecs, ref_distances


-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128])
 @pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9])
 @pytest.mark.parametrize("loop", [True, False])
@@ -92,7 +96,7 @@ def test_neighbors(
     ).to(device)
     cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
     lbox = 10.0
-    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
+    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
@@ -141,7 +145,11 @@ def test_neighbors(
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)

-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("loop", [True, False])
 @pytest.mark.parametrize("include_transpose", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@@ -249,10 +257,14 @@ def test_neighbor_grads(
     else:
         assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5)

-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("loop", [True, False])
 @pytest.mark.parametrize("include_transpose", [True, False])
-@pytest.mark.parametrize("num_atoms", [1,2, 10])
+@pytest.mark.parametrize("num_atoms", [1, 2, 10])
 @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
 def test_neighbor_autograds(
     device, strategy, loop, include_transpose, num_atoms, box_type
@@ -293,8 +305,12 @@ def test_neighbor_autograds(
     neighbors, distances, deltas = nl(positions, batch)
     # Lambda that returns only the distances and deltas
     lambda_dist = lambda x, y: nl(x, y)[1:]
-    torch.autograd.gradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
-    torch.autograd.gradgradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
+    torch.autograd.gradcheck(
+        lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4
+    )
+    torch.autograd.gradgradcheck(
+        lambda_dist, (positions, batch), eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
+    )


 @pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
@@ -353,7 +369,11 @@ def test_large_size(strategy, n_batches):
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)

-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
+
+@pytest.mark.parametrize(
+    ("device", "strategy"),
+    [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+)
 @pytest.mark.parametrize("n_batches", [1, 128])
 @pytest.mark.parametrize("cutoff", [1.0])
 @pytest.mark.parametrize("loop", [True, False])
@@ -504,6 +524,7 @@ def test_cuda_graph_compatible_forward(
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)

+
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
 @pytest.mark.parametrize("n_batches", [1, 128])
@@ -578,12 +599,12 @@ def test_cuda_graph_compatible_backward(
     torch.cuda.synchronize()


-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")])
+@pytest.mark.parametrize(
+    ("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")]
+)
 @pytest.mark.parametrize("n_batches", [1, 128])
 @pytest.mark.parametrize("use_forward", [True, False])
-def test_per_batch_box(
-    device, strategy, n_batches, use_forward
-):
+def test_per_batch_box(device, strategy, n_batches, use_forward):
     dtype = torch.float32
     cutoff = 1.0
     include_transpose = True
@@ -599,7 +620,7 @@ def test_per_batch_box(
     ).to(device)
     cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
     lbox = 10.0
-    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
+    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
@@ -625,7 +646,9 @@ def test_per_batch_box(
         include_transpose=include_transpose,
     )
     batch.to(device)
-    neighbors, distances, distance_vecs = nl(pos, batch, box=box if use_forward else None)
+    neighbors, distances, distance_vecs = nl(
+        pos, batch, box=box if use_forward else None
+    )
     neighbors = neighbors.cpu().detach().numpy()
     distance_vecs = distance_vecs.cpu().detach().numpy()
     distances = distances.cpu().detach().numpy()
@@ -639,3 +662,45 @@ def test_per_batch_box(
     assert np.allclose(neighbors, ref_neighbors)
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("dtype", [torch.float64])
+@pytest.mark.parametrize("loop", [True, False])
+@pytest.mark.parametrize("include_transpose", [True, False])
+def test_torch_compile(device, dtype, loop, include_transpose):
+    if torch.__version__ < "2.0.0":
+        pytest.skip("Not available in this version")
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    np.random.seed(123456)
+    example_pos = 10 * torch.rand(50, 3, requires_grad=True, dtype=dtype, device=device)
+    model = OptimizedDistance(
+        cutoff_lower=0.1,  # I do this to avoid non-finite-differentiable points
+        cutoff_upper=10,
+        return_vecs=True,
+        loop=loop,
+        max_num_pairs=-example_pos.shape[0],
+        include_transpose=include_transpose,
+        resize_to_fit=False,
+        check_errors=False,
+    ).to(device)
+    for _ in range(50):
+        model(example_pos)
+    example_pos = example_pos.detach().requires_grad_(True)
+    edge_index, edge_vec, edge_distance = model(example_pos)
+    edge_vec.sum().backward()
+    example_pos.grad.zero_()
+    fullgraph = torch.__version__ >= "2.2.0"
+    model = torch.compile(
+        model,
+        fullgraph=fullgraph,
+        backend="inductor",
+        mode="reduce-overhead",
+    )
+    edge_index, edge_vec, edge_distance = model(example_pos)
+    edge_vec.sum().backward()
+    lambda_dist = lambda x: model(x)[1:]
+    torch.autograd.gradcheck(
+        lambda_dist, example_pos, eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
+    )
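Aside, pieced together from the tests above: a minimal sketch of how `OptimizedDistance` is driven in this file. The keyword arguments are the ones exercised in the diff; the comment on `max_num_pairs` is inferred from its usage here rather than from documentation:

```python
import torch

from torchmdnet.models.utils import OptimizedDistance

pos = 10.0 * torch.rand(50, 3)  # 50 atoms scattered in a cubic region
nl = OptimizedDistance(
    cutoff_upper=10.0,
    return_vecs=True,  # also return per-pair displacement vectors
    loop=False,  # exclude self-pairs (i, i)
    max_num_pairs=-pos.shape[0],  # negative values appear to set a per-atom pair budget
    include_transpose=True,  # report both (i, j) and (j, i)
)
# Unpacking order as in test_per_batch_box above:
neighbors, distances, distance_vecs = nl(pos)
```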