import pytest

-import torch
-from torch import nn
-from e3nn import o3
-
try:
    from cugraph_equivariant.nn import FullyConnectedTensorProductConv
except RuntimeError:
    pytest.skip(
        allow_module_level=True,
    )

-device = torch.device("cuda:0")
+import torch
+from torch import nn
+from e3nn import o3
+from cugraph_equivariant.nn.tensor_product_conv import Graph
+
+device = torch.device("cuda")


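+# Helper: sample a random bipartite graph and wrap it in the library's
+# Graph(edge_index, size) container imported above.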
+def create_random_graph(
+    num_src_nodes,
+    num_dst_nodes,
+    num_edges,
+    dtype=None,
+    device=None,
+):
+    row = torch.randint(num_src_nodes, (num_edges,), dtype=dtype, device=device)
+    col = torch.randint(num_dst_nodes, (num_edges,), dtype=dtype, device=device)
+    edge_index = torch.stack([row, col], dim=0)
+
+    return Graph(edge_index, (num_src_nodes, num_dst_nodes))
+
+
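+# New test axis: run the equivariance check in full, half, and bfloat16 precision.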
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("e3nn_compat_mode", [True, False])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize(
    "mlp_channels, mlp_activation, scalar_sizes",
    [
        ...
    ],
)
def test_tensor_product_conv_equivariance(
-    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
+    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode, dtype
):
    torch.manual_seed(12345)
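+    # One kwargs dict so every .to()/tensor-factory call below lands on the
+    # same device and dtype.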
+    to_kwargs = {"device": device, "dtype": dtype}

    in_irreps = o3.Irreps("10x0e + 10x1e")
    out_irreps = o3.Irreps("20x0e + 10x1e")
@@ -55,68 +72,65 @@ def test_tensor_product_conv_equivariance(
        mlp_activation=mlp_activation,
        batch_norm=batch_norm,
        e3nn_compat_mode=e3nn_compat_mode,
-    ).to(device)
+    ).to(**to_kwargs)

    num_src_nodes, num_dst_nodes = 9, 7
    num_edges = 40
-    src = torch.randint(num_src_nodes, (num_edges,), device=device)
-    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
-    edge_index = torch.vstack((src, dst))
-
-    src_pos = torch.randn(num_src_nodes, 3, device=device)
-    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
-    edge_vec = dst_pos[dst] - src_pos[src]
-    edge_sh = o3.spherical_harmonics(
-        tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
-    ).to(device)
-    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
+    graph = create_random_graph(num_src_nodes, num_dst_nodes, num_edges, device=device)
+
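+    # Random stand-in for real spherical harmonics: the equivariance identity
+    # holds for any tensor transforming under sh_irreps, so positions are not needed.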
+    edge_sh = torch.randn(num_edges, sh_irreps.dim, **to_kwargs)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, **to_kwargs)

    rot = o3.rand_matrix()
-    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
-    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
-    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(**to_kwargs)

    if mlp_channels is None:
-        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, **to_kwargs)
        src_scalars = dst_scalars = None
    else:
        if scalar_sizes:
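            # scalar_sizes = (edge_emb_dim, src_scalar_dim, dst_scalar_dim); 0 disables a field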
-            edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
+            edge_emb = torch.randn(num_edges, scalar_sizes[0], **to_kwargs)
            src_scalars = (
                None
                if scalar_sizes[1] == 0
-                else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
+                else torch.randn(num_src_nodes, scalar_sizes[1], **to_kwargs)
            )
            dst_scalars = (
                None
                if scalar_sizes[2] == 0
-                else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
+                else torch.randn(num_dst_nodes, scalar_sizes[2], **to_kwargs)
            )
        else:
-            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, **to_kwargs)
            src_scalars = dst_scalars = None
98
108
99
109
# rotate before
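+    # re-seed so any stochastic layers (e.g. dropout in the MLP activation)
+    # produce identical masks in both forward passes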
+    torch.manual_seed(12345)
    out_before = tp_conv(
        src_features=src_features @ D_in.T,
        edge_sh=edge_sh @ D_sh.T,
        edge_emb=edge_emb,
-        graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+        graph=graph,
        src_scalars=src_scalars,
        dst_scalars=dst_scalars,
    )

    # rotate after
+    torch.manual_seed(12345)
    out_after = (
        tp_conv(
            src_features=src_features,
            edge_sh=edge_sh,
            edge_emb=edge_emb,
-            graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+            graph=graph,
            src_scalars=src_scalars,
            dst_scalars=dst_scalars,
        )
        @ D_out.T
    )

-    torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
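+    # reduced-precision dtypes need a looser absolute tolerance; note the old
+    # line also discarded its result (missing assert)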
+    atol = 1e-3 if dtype == torch.float32 else 1e-1
+    if e3nn_compat_mode:
+        assert torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)