Commit 407cdab

Fix TensorProductConv test and improve docs (#4480)
Closes #4459

Authors:
  - Tingyu Wang (https://github.com/tingyu66)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - https://github.com/DejunL
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4480
1 parent 2e969da commit 407cdab

File tree: 2 files changed, +74 / -47 lines


python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py

Lines changed: 31 additions & 18 deletions
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, NamedTuple
 
 import torch
 from torch import nn
@@ -31,6 +31,11 @@
 ) from exc
 
 
+class Graph(NamedTuple):
+    edge_index: torch.Tensor
+    size: tuple[int, int]
+
+
 class FullyConnectedTensorProductConv(nn.Module):
     r"""Message passing layer for tensor products in DiffDock-like architectures.
     The left operand of tensor product is the spherical harmonic representation
@@ -81,27 +86,35 @@ class FullyConnectedTensorProductConv(nn.Module):
 
     Examples
     --------
-    >>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
-    >>> # having 16 channels. edge_emb.size(1) must match the size of
-    >>> # the input layer: 6
-    >>>
+    Case 1: MLP with the input layer having 6 channels and 2 hidden layers
+    having 16 channels. edge_emb.size(1) must match the size of the input layer: 6
+
     >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
     >>> out = conv1(src_features, edge_sh, edge_emb, graph)
-    >>>
-    >>> # Case 2: Same as case 1 but with the scalar features from edges, sources
-    >>> # and destinations passed in separately.
-    >>>
+
+    Case 2: If `edge_emb` is constructed by concatenating scalar features from
+    edges, sources and destinations, as in DiffDock, the layer can accept each
+    scalar component separately:
+
     >>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
-    >>> out = conv3(src_features, edge_sh, edge_scalars, graph,
+    >>> out = conv2(src_features, edge_sh, edge_scalars, graph,
     >>>     src_scalars=src_scalars, dst_scalars=dst_scalars)
-    >>>
-    >>> # Case 3: No MLP, edge_emb will be directly used as the tensor product weights
-    >>>
+
+    This allows a smaller GEMM in the first MLP layer by performing GEMM on each
+    component before indexing. The first-layer weights are split into sections
+    for edges, sources and destinations, in that order. This is equivalent to
+
+    >>> src, dst = graph.edge_index
+    >>> edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
+    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+
+    Case 3: No MLP, `edge_emb` will be directly used as the tensor product weights:
+
     >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
     >>>     mlp_channels=None).cuda()
-    >>> out = conv2(src_features, edge_sh, edge_emb, graph)
+    >>> out = conv3(src_features, edge_sh, edge_emb, graph)
 
     """
 
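To make the "equivalent to" claim in the docstring above concrete, here is a small standalone check of the split-GEMM identity it describes. The shapes and the names W, W_e, W_s, W_d are illustrative assumptions made for this note, not library code:

import torch

num_edges, num_src, num_dst = 40, 9, 7
edge_scalars = torch.randn(num_edges, 2)
src_scalars = torch.randn(num_src, 2)
dst_scalars = torch.randn(num_dst, 2)
src = torch.randint(num_src, (num_edges,))
dst = torch.randint(num_dst, (num_edges,))

# A hypothetical first MLP layer with mlp_channels[0] == 6; its weight splits
# into sections for edge, source and destination scalars, in that order.
W = torch.randn(16, 6)
W_e, W_s, W_d = W.split([2, 2, 2], dim=1)

# Concatenate-then-GEMM over num_edges rows ...
full = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst])) @ W.T
# ... equals per-component GEMMs on the smaller node tables, indexed afterwards.
split = edge_scalars @ W_e.T + (src_scalars @ W_s.T)[src] + (dst_scalars @ W_d.T)[dst]
assert torch.allclose(full, split, atol=1e-5)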
@@ -174,20 +187,20 @@ def forward(
             Edge embeddings that are fed into MLPs to generate tensor product weights.
             Shape: (num_edges, dim), where `dim` should be:
             - `tp.weight_numel` when the layer does not contain MLPs.
-            - num_edge_scalars, with the sum of num_[edge/src/dst]_scalars being
-              mlp_channels[0]
+            - num_edge_scalars, when scalar features from edges, sources and
+              destinations are passed in separately.
 
         graph : tuple
             A tuple that stores the graph information, with the first element being
             the adjacency matrix in COO, and the second element being its shape:
             (num_src_nodes, num_dst_nodes).
 
         src_scalars: torch.Tensor, optional
-            Scalar features of source nodes.
+            Scalar features of source nodes. See examples for usage.
            Shape: (num_src_nodes, num_src_scalars)
 
         dst_scalars: torch.Tensor, optional
-            Scalar features of destination nodes.
+            Scalar features of destination nodes. See examples for usage.
            Shape: (num_dst_nodes, num_dst_scalars)
 
         reduce : str, optional (default="mean")
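For orientation, a minimal end-to-end sketch of how the Graph container introduced in this commit is built and passed to the layer's forward call. The irreps, channel sizes, and tensor shapes below are illustrative assumptions taken from the docstring examples (a CUDA device is assumed), not values fixed by the library:

import torch
from torch import nn
from e3nn import o3
from cugraph_equivariant.nn import FullyConnectedTensorProductConv
from cugraph_equivariant.nn.tensor_product_conv import Graph

in_irreps = o3.Irreps("10x0e + 10x1e")   # illustrative
sh_irreps = o3.Irreps("1x0e + 1x1e")     # illustrative
out_irreps = o3.Irreps("20x0e + 10x1e")  # illustrative

conv = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps,
    mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(),
).cuda()

# Graph bundles the COO edge_index with (num_src_nodes, num_dst_nodes).
num_src, num_dst, num_edges = 9, 7, 40
edge_index = torch.stack(
    [torch.randint(num_src, (num_edges,)), torch.randint(num_dst, (num_edges,))]
).cuda()
graph = Graph(edge_index, (num_src, num_dst))

src_features = torch.randn(num_src, in_irreps.dim, device="cuda")
edge_sh = torch.randn(num_edges, sh_irreps.dim, device="cuda")
edge_emb = torch.randn(num_edges, 6, device="cuda")  # 6 == mlp_channels[0]

out = conv(src_features, edge_sh, edge_emb, graph)  # shape: (num_dst, out_irreps.dim)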

python/cugraph-equivariant/cugraph_equivariant/tests/test_tensor_product_conv.py

Lines changed: 43 additions & 29 deletions
@@ -13,10 +13,6 @@
 
 import pytest
 
-import torch
-from torch import nn
-from e3nn import o3
-
 try:
     from cugraph_equivariant.nn import FullyConnectedTensorProductConv
 except RuntimeError:
@@ -25,9 +21,29 @@
         allow_module_level=True,
     )
 
-device = torch.device("cuda:0")
+import torch
+from torch import nn
+from e3nn import o3
+from cugraph_equivariant.nn.tensor_product_conv import Graph
+
+device = torch.device("cuda")
 
 
+def create_random_graph(
+    num_src_nodes,
+    num_dst_nodes,
+    num_edges,
+    dtype=None,
+    device=None,
+):
+    row = torch.randint(num_src_nodes, (num_edges,), dtype=dtype, device=device)
+    col = torch.randint(num_dst_nodes, (num_edges,), dtype=dtype, device=device)
+    edge_index = torch.stack([row, col], dim=0)
+
+    return Graph(edge_index, (num_src_nodes, num_dst_nodes))
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("e3nn_compat_mode", [True, False])
 @pytest.mark.parametrize("batch_norm", [True, False])
 @pytest.mark.parametrize(
@@ -39,9 +55,10 @@
     ],
 )
 def test_tensor_product_conv_equivariance(
-    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
+    mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode, dtype
 ):
     torch.manual_seed(12345)
+    to_kwargs = {"device": device, "dtype": dtype}
 
     in_irreps = o3.Irreps("10x0e + 10x1e")
     out_irreps = o3.Irreps("20x0e + 10x1e")
@@ -55,68 +72,65 @@ def test_tensor_product_conv_equivariance(
         mlp_activation=mlp_activation,
         batch_norm=batch_norm,
         e3nn_compat_mode=e3nn_compat_mode,
-    ).to(device)
+    ).to(**to_kwargs)
 
     num_src_nodes, num_dst_nodes = 9, 7
     num_edges = 40
-    src = torch.randint(num_src_nodes, (num_edges,), device=device)
-    dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
-    edge_index = torch.vstack((src, dst))
-
-    src_pos = torch.randn(num_src_nodes, 3, device=device)
-    dst_pos = torch.randn(num_dst_nodes, 3, device=device)
-    edge_vec = dst_pos[dst] - src_pos[src]
-    edge_sh = o3.spherical_harmonics(
-        tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
-    ).to(device)
-    src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
+    graph = create_random_graph(num_src_nodes, num_dst_nodes, num_edges, device=device)
+
+    edge_sh = torch.randn(num_edges, sh_irreps.dim, **to_kwargs)
+    src_features = torch.randn(num_src_nodes, in_irreps.dim, **to_kwargs)
 
     rot = o3.rand_matrix()
-    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
-    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
-    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
+    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(**to_kwargs)
+    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(**to_kwargs)
 
     if mlp_channels is None:
-        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
+        edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, **to_kwargs)
         src_scalars = dst_scalars = None
     else:
         if scalar_sizes:
-            edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
+            edge_emb = torch.randn(num_edges, scalar_sizes[0], **to_kwargs)
             src_scalars = (
                 None
                 if scalar_sizes[1] == 0
-                else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
+                else torch.randn(num_src_nodes, scalar_sizes[1], **to_kwargs)
             )
             dst_scalars = (
                 None
                 if scalar_sizes[2] == 0
-                else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
+                else torch.randn(num_dst_nodes, scalar_sizes[2], **to_kwargs)
             )
         else:
-            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
+            edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, **to_kwargs)
             src_scalars = dst_scalars = None
 
     # rotate before
+    torch.manual_seed(12345)
     out_before = tp_conv(
         src_features=src_features @ D_in.T,
         edge_sh=edge_sh @ D_sh.T,
         edge_emb=edge_emb,
-        graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+        graph=graph,
         src_scalars=src_scalars,
         dst_scalars=dst_scalars,
     )
 
     # rotate after
+    torch.manual_seed(12345)
     out_after = (
         tp_conv(
             src_features=src_features,
             edge_sh=edge_sh,
             edge_emb=edge_emb,
-            graph=(edge_index, (num_src_nodes, num_dst_nodes)),
+            graph=graph,
             src_scalars=src_scalars,
             dst_scalars=dst_scalars,
         )
         @ D_out.T
     )
 
-    torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
+    atol = 1e-3 if dtype == torch.float32 else 1e-1
+    if e3nn_compat_mode:
+        assert torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)
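As a companion to the test above, a condensed sketch of the rotate-before versus rotate-after equivariance identity it asserts. check_equivariance is an illustrative helper written for this note (it omits the seed resets the test uses around stochastic layers), not part of the test file:

import torch
from e3nn import o3

def check_equivariance(conv, src_features, edge_sh, edge_emb, graph, atol=1e-3):
    # Wigner D matrices for a random rotation, in the layer's input/sh/output irreps.
    rot = o3.rand_matrix()
    D_in = conv.in_irreps.D_from_matrix(rot).to(src_features)
    D_sh = conv.sh_irreps.D_from_matrix(rot).to(edge_sh)
    D_out = conv.out_irreps.D_from_matrix(rot).to(src_features)

    # Rotating the inputs first ...
    out_before = conv(src_features @ D_in.T, edge_sh @ D_sh.T, edge_emb, graph)
    # ... should match rotating the output afterwards.
    out_after = conv(src_features, edge_sh, edge_emb, graph) @ D_out.T
    return torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)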
