
Fix TensorProductConv test and improve docs #4480


Merged: 10 commits, Jul 9, 2024
@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, NamedTuple

import torch
from torch import nn
@@ -31,6 +31,11 @@
) from exc


class Graph(NamedTuple):
edge_index: torch.Tensor
size: tuple[int, int]


class FullyConnectedTensorProductConv(nn.Module):
r"""Message passing layer for tensor products in DiffDock-like architectures.
The left operand of tensor product is the spherical harmonic representation
@@ -81,27 +86,35 @@ class FullyConnectedTensorProductConv(nn.Module):

Examples
--------
>>> # Case 1: MLP with the input layer having 6 channels and 2 hidden layers
>>> # having 16 channels. edge_emb.size(1) must match the size of
>>> # the input layer: 6
>>>
Case 1: MLP with the input layer having 6 channels and 2 hidden layers
having 16 channels. edge_emb.size(1) must match the size of the input layer: 6

>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
>>> out = conv1(src_features, edge_sh, edge_emb, graph)
>>>
>>> # Case 2: Same as case 1 but with the scalar features from edges, sources
>>> # and destinations passed in separately.
>>>

Case 2: If `edge_emb` is constructed by concatenating scalar features from
edges, sources and destinations, as in DiffDock, the layer can accept each
scalar component separately:

>>> conv2 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU()).cuda()
>>> out = conv3(src_features, edge_sh, edge_scalars, graph,
>>> out = conv2(src_features, edge_sh, edge_scalars, graph,
>>> src_scalars=src_scalars, dst_scalars=dst_scalars)
>>>
>>> # Case 3: No MLP, edge_emb will be directly used as the tensor product weights
>>>

This allows a smaller GEMM in the first MLP layer by performing GEMM on each
component before indexing. The first-layer weights are split into sections
for edges, sources and destinations, in that order. This is equivalent to

>>> src, dst = graph.edge_index
>>> edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
>>> out = conv2(src_features, edge_sh, edge_emb, graph)

Case 3: No MLP, `edge_emb` will be directly used as the tensor product weights:

>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
>>> mlp_channels=None).cuda()
>>> out = conv2(src_features, edge_sh, edge_emb, graph)
>>> out = conv3(src_features, edge_sh, edge_emb, graph)

"""

@@ -174,20 +187,20 @@ def forward(
Edge embeddings that are fed into MLPs to generate tensor product weights.
Shape: (num_edges, dim), where `dim` should be:
- `tp.weight_numel` when the layer does not contain MLPs.
- num_edge_scalars, with the sum of num_[edge/src/dst]_scalars being
mlp_channels[0]
- num_edge_scalars, when scalar features from edges, sources and
destinations are passed in separately.

graph : tuple
A tuple that stores the graph information, with the first element being
the adjacency matrix in COO, and the second element being its shape:
(num_src_nodes, num_dst_nodes).

src_scalars: torch.Tensor, optional
Scalar features of source nodes.
Scalar features of source nodes. See examples for usage.
Shape: (num_src_nodes, num_src_scalars)

dst_scalars: torch.Tensor, optional
Scalar features of destination nodes.
Scalar features of destination nodes. See examples for usage.
Shape: (num_dst_nodes, num_dst_scalars)

reduce : str, optional (default="mean")
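For reference, the docstring cases above fit together as follows. This is a minimal, hypothetical sketch of Case 1 using the new `Graph` NamedTuple; the irreps, tensor sizes, and the CUDA device are illustrative assumptions, not part of this PR.

```python
# Hypothetical end-to-end sketch of Case 1 from the updated docstring.
# Assumes a CUDA build of cugraph-equivariant plus e3nn; all sizes and
# irreps below are made up for illustration.
import torch
from torch import nn
from e3nn import o3

from cugraph_equivariant.nn import FullyConnectedTensorProductConv
from cugraph_equivariant.nn.tensor_product_conv import Graph

device = torch.device("cuda")

in_irreps = o3.Irreps("10x0e + 10x1e")
sh_irreps = o3.Irreps("1x0e + 1x1e")   # assumed spherical-harmonic irreps
out_irreps = o3.Irreps("20x0e + 10x1e")

conv = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps,
    mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(),
).to(device)

num_src, num_dst, num_edges = 9, 7, 40
edge_index = torch.stack(
    [
        torch.randint(num_src, (num_edges,), device=device),
        torch.randint(num_dst, (num_edges,), device=device),
    ]
)
graph = Graph(edge_index, (num_src, num_dst))  # NamedTuple introduced in this PR

src_features = torch.randn(num_src, in_irreps.dim, device=device)
edge_sh = torch.randn(num_edges, sh_irreps.dim, device=device)
edge_emb = torch.randn(num_edges, 6, device=device)  # must match mlp_channels[0]

out = conv(src_features, edge_sh, edge_emb, graph)  # expected: (num_dst, out_irreps.dim)
```

Case 2 only changes the call site: the same layer takes `edge_scalars`, `src_scalars` and `dst_scalars` separately instead of a pre-concatenated `edge_emb`.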
@@ -13,10 +13,6 @@

import pytest

import torch
from torch import nn
from e3nn import o3

try:
from cugraph_equivariant.nn import FullyConnectedTensorProductConv
except RuntimeError:
@@ -25,9 +21,29 @@
allow_module_level=True,
)

device = torch.device("cuda:0")
import torch
from torch import nn
from e3nn import o3
from cugraph_equivariant.nn.tensor_product_conv import Graph

device = torch.device("cuda")


def create_random_graph(
num_src_nodes,
num_dst_nodes,
num_edges,
dtype=None,
device=None,
):
row = torch.randint(num_src_nodes, (num_edges,), dtype=dtype, device=device)
col = torch.randint(num_dst_nodes, (num_edges,), dtype=dtype, device=device)
edge_index = torch.stack([row, col], dim=0)

return Graph(edge_index, (num_src_nodes, num_dst_nodes))


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("e3nn_compat_mode", [True, False])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize(
@@ -39,9 +55,10 @@
],
)
def test_tensor_product_conv_equivariance(
mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode
mlp_channels, mlp_activation, scalar_sizes, batch_norm, e3nn_compat_mode, dtype
):
torch.manual_seed(12345)
to_kwargs = {"device": device, "dtype": dtype}

in_irreps = o3.Irreps("10x0e + 10x1e")
out_irreps = o3.Irreps("20x0e + 10x1e")
@@ -55,68 +72,65 @@ def test_tensor_product_conv_equivariance(
mlp_activation=mlp_activation,
batch_norm=batch_norm,
e3nn_compat_mode=e3nn_compat_mode,
).to(device)
).to(**to_kwargs)

num_src_nodes, num_dst_nodes = 9, 7
num_edges = 40
src = torch.randint(num_src_nodes, (num_edges,), device=device)
dst = torch.randint(num_dst_nodes, (num_edges,), device=device)
edge_index = torch.vstack((src, dst))

src_pos = torch.randn(num_src_nodes, 3, device=device)
dst_pos = torch.randn(num_dst_nodes, 3, device=device)
edge_vec = dst_pos[dst] - src_pos[src]
edge_sh = o3.spherical_harmonics(
tp_conv.sh_irreps, edge_vec, normalize=True, normalization="component"
).to(device)
src_features = torch.randn(num_src_nodes, in_irreps.dim, device=device)
graph = create_random_graph(num_src_nodes, num_dst_nodes, num_edges, device=device)

edge_sh = torch.randn(num_edges, sh_irreps.dim, **to_kwargs)
src_features = torch.randn(num_src_nodes, in_irreps.dim, **to_kwargs)

rot = o3.rand_matrix()
D_in = tp_conv.in_irreps.D_from_matrix(rot).to(device)
D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(device)
D_out = tp_conv.out_irreps.D_from_matrix(rot).to(device)
D_in = tp_conv.in_irreps.D_from_matrix(rot).to(**to_kwargs)
D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(**to_kwargs)
D_out = tp_conv.out_irreps.D_from_matrix(rot).to(**to_kwargs)

if mlp_channels is None:
edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, device=device)
edge_emb = torch.randn(num_edges, tp_conv.tp.weight_numel, **to_kwargs)
src_scalars = dst_scalars = None
else:
if scalar_sizes:
edge_emb = torch.randn(num_edges, scalar_sizes[0], device=device)
edge_emb = torch.randn(num_edges, scalar_sizes[0], **to_kwargs)
src_scalars = (
None
if scalar_sizes[1] == 0
else torch.randn(num_src_nodes, scalar_sizes[1], device=device)
else torch.randn(num_src_nodes, scalar_sizes[1], **to_kwargs)
)
dst_scalars = (
None
if scalar_sizes[2] == 0
else torch.randn(num_dst_nodes, scalar_sizes[2], device=device)
else torch.randn(num_dst_nodes, scalar_sizes[2], **to_kwargs)
)
else:
edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, device=device)
edge_emb = torch.randn(num_edges, tp_conv.mlp[0].in_features, **to_kwargs)
src_scalars = dst_scalars = None

# rotate before
torch.manual_seed(12345)
out_before = tp_conv(
src_features=src_features @ D_in.T,
edge_sh=edge_sh @ D_sh.T,
edge_emb=edge_emb,
graph=(edge_index, (num_src_nodes, num_dst_nodes)),
graph=graph,
src_scalars=src_scalars,
dst_scalars=dst_scalars,
)

# rotate after
torch.manual_seed(12345)
out_after = (
tp_conv(
src_features=src_features,
edge_sh=edge_sh,
edge_emb=edge_emb,
graph=(edge_index, (num_src_nodes, num_dst_nodes)),
graph=graph,
src_scalars=src_scalars,
dst_scalars=dst_scalars,
)
@ D_out.T
)

torch.allclose(out_before, out_after, rtol=1e-4, atol=1e-4)
atol = 1e-3 if dtype == torch.float32 else 1e-1
if e3nn_compat_mode:
assert torch.allclose(out_before, out_after, rtol=1e-4, atol=atol)
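The assertion at the end encodes the usual equivariance property, conv(D_in x, D_sh sh) ≈ D_out conv(x, sh). A compact, standalone restatement of that check is sketched below; it assumes the `tp_conv`, feature tensors, and `graph` built in the test body, and the helper name is hypothetical.

```python
# Hypothetical helper restating the equivariance check from the test above.
# Assumes tp_conv, src_features, edge_sh, edge_emb and graph are constructed
# as in test_tensor_product_conv_equivariance.
import torch
from e3nn import o3


def is_equivariant(tp_conv, src_features, edge_sh, edge_emb, graph, atol=1e-3):
    rot = o3.rand_matrix()
    D_in = tp_conv.in_irreps.D_from_matrix(rot).to(src_features)
    D_sh = tp_conv.sh_irreps.D_from_matrix(rot).to(edge_sh)
    D_out = tp_conv.out_irreps.D_from_matrix(rot).to(src_features)

    # Rotate the inputs first, then apply the layer ...
    out_rot_first = tp_conv(src_features @ D_in.T, edge_sh @ D_sh.T, edge_emb, graph)
    # ... versus applying the layer first, then rotating the output.
    out_rot_last = tp_conv(src_features, edge_sh, edge_emb, graph) @ D_out.T
    return torch.allclose(out_rot_first, out_rot_last, rtol=1e-4, atol=atol)
```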