Skip to content

Update cugraph-pyg models for PyG 2.5 #4335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/test_wheel_cugraph-pyg.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ else
fi
rapids-logger "Installing PyTorch and PyG dependencies"
rapids-retry python -m pip install torch==2.1.0 --index-url ${PYTORCH_URL}
rapids-retry python -m pip install torch-geometric==2.4.0
rapids-retry python -m pip install "torch-geometric>=2.5,<2.6"
rapids-retry python -m pip install \
ogb \
pyg_lib \
Expand Down
2 changes: 1 addition & 1 deletion conda/recipes/cugraph-pyg/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ requirements:
- cupy >=12.0.0
- cugraph ={{ version }}
- pylibcugraphops ={{ minor_version }}
- pyg >=2.3,<2.5
- pyg >=2.5,<2.6

tests:
imports:
Expand Down
2 changes: 1 addition & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ dependencies:
- cugraph==24.6.*
- pytorch>=2.0
- pytorch-cuda==11.8
- pyg>=2.4.0
- pyg>=2.5,<2.6

depends_on_rmm:
common:
Expand Down
2 changes: 1 addition & 1 deletion python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- cugraph==24.6.*
- pandas
- pre-commit
- pyg>=2.4.0
- pyg>=2.5,<2.6
- pylibcugraphops==24.6.*
- pytest
- pytest-benchmark
Expand Down
98 changes: 56 additions & 42 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -15,11 +15,15 @@
from typing import Optional, Tuple, Union

from cugraph.utilities.utils import import_optional
from pylibcugraphops.pytorch import CSC, HeteroCSC
import pylibcugraphops.pytorch


torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")

# A tuple of (row, colptr, num_src_nodes)
CSC = Tuple[torch.Tensor, torch.Tensor, int]


class BaseConv(torch.nn.Module): # pragma: no cover
r"""An abstract base class for implementing cugraph-ops message passing layers."""
Expand All @@ -33,10 +37,7 @@ def to_csc(
edge_index: torch.Tensor,
size: Optional[Tuple[int, int]] = None,
edge_attr: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor, torch.Tensor, int],
Tuple[Tuple[torch.Tensor, torch.Tensor, int], torch.Tensor],
]:
) -> Union[CSC, Tuple[CSC, torch.Tensor],]:
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
used as input to cugraph-ops conv layers.

Expand Down Expand Up @@ -71,27 +72,31 @@ def to_csc(

def get_cugraph(
self,
csc: Tuple[torch.Tensor, torch.Tensor, int],
edge_index: Union[torch_geometric.EdgeIndex, CSC],
bipartite: bool = False,
max_num_neighbors: Optional[int] = None,
) -> CSC:
) -> Tuple[pylibcugraphops.pytorch.CSC, Optional[torch.Tensor]]:
r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation.
Supports both bipartite and non-bipartite graphs.

Args:
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`to_csc` method to convert an :obj:`edge_index`
representation to the desired format.
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
CSC representation.
bipartite (bool): If set to :obj:`True`, will create the bipartite
structure in cugraph-ops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors
of a destination node. When enabled, it allows models to use
the message-flow-graph primitives in cugraph-ops.
(default: :obj:`None`)
"""
row, colptr, num_src_nodes = csc
perm = None
if isinstance(edge_index, torch_geometric.EdgeIndex):
edge_index, perm = edge_index.sort_by("col")
num_src_nodes = edge_index.get_sparse_size(0)
(colptr, row), _ = edge_index.get_csc()
else:
row, colptr, num_src_nodes = edge_index

if not row.is_cuda:
raise RuntimeError(
Expand All @@ -102,32 +107,33 @@ def get_cugraph(
if max_num_neighbors is None:
max_num_neighbors = -1

return CSC(
offsets=colptr,
indices=row,
num_src_nodes=num_src_nodes,
dst_max_in_degree=max_num_neighbors,
is_bipartite=bipartite,
return (
pylibcugraphops.pytorch.CSC(
offsets=colptr,
indices=row,
num_src_nodes=num_src_nodes,
dst_max_in_degree=max_num_neighbors,
is_bipartite=bipartite,
),
perm,
)

def get_typed_cugraph(
self,
csc: Tuple[torch.Tensor, torch.Tensor, int],
edge_index: Union[torch_geometric.EdgeIndex, CSC],
edge_type: torch.Tensor,
num_edge_types: Optional[int] = None,
bipartite: bool = False,
max_num_neighbors: Optional[int] = None,
) -> HeteroCSC:
) -> Tuple[pylibcugraphops.pytorch.HeteroCSC, Optional[torch.Tensor]]:
r"""Constructs a typed :obj:`cugraph` graph object from a CSC
representation where each edge corresponds to a given edge type.
Supports both bipartite and non-bipartite graphs.

Args:
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`to_csc` method to convert an :obj:`edge_index`
representation to the desired format.
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
CSC representation.
edge_type (torch.Tensor): The edge type.
num_edge_types (int, optional): The maximum number of edge types.
When not given, will be computed on-the-fly, leading to
Expand All @@ -145,32 +151,40 @@ def get_typed_cugraph(
if max_num_neighbors is None:
max_num_neighbors = -1

row, colptr, num_src_nodes = csc
perm = None
if isinstance(edge_index, torch_geometric.EdgeIndex):
edge_index, perm = edge_index.sort_by("col")
edge_type = edge_type[perm]
num_src_nodes = edge_index.get_sparse_size(0)
(colptr, row), _ = edge_index.get_csc()
else:
row, colptr, num_src_nodes = edge_index
edge_type = edge_type.int()

return HeteroCSC(
offsets=colptr,
indices=row,
edge_types=edge_type,
num_src_nodes=num_src_nodes,
num_edge_types=num_edge_types,
dst_max_in_degree=max_num_neighbors,
is_bipartite=bipartite,
return (
pylibcugraphops.pytorch.HeteroCSC(
offsets=colptr,
indices=row,
edge_types=edge_type,
num_src_nodes=num_src_nodes,
num_edge_types=num_edge_types,
dst_max_in_degree=max_num_neighbors,
is_bipartite=bipartite,
),
perm,
)

def forward(
self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
csc: Tuple[torch.Tensor, torch.Tensor, int],
edge_index: Union[torch_geometric.EdgeIndex, CSC],
) -> torch.Tensor:
r"""Runs the forward pass of the module.

Args:
x (torch.Tensor): The node features.
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`to_csc` method to convert an :obj:`edge_index`
representation to the desired format.
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
CSC representation.
"""
raise NotImplementedError
19 changes: 10 additions & 9 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cugraph.utilities.utils import import_optional
from pylibcugraphops.pytorch.operators import mha_gat_n2n

from .base import BaseConv
from .base import BaseConv, CSC

torch = import_optional("torch")
nn = import_optional("torch.nn")
Expand Down Expand Up @@ -159,7 +159,7 @@ def reset_parameters(self):
def forward(
self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
csc: Tuple[torch.Tensor, torch.Tensor, int],
edge_index: Union[torch_geometric.EdgeIndex, CSC],
edge_attr: Optional[torch.Tensor] = None,
max_num_neighbors: Optional[int] = None,
deterministic_dgrad: bool = False,
Expand All @@ -172,11 +172,7 @@ def forward(
Args:
x (torch.Tensor or tuple): The node features. Can be a tuple of
tensors denoting source and destination node features.
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`to_csc` method to convert an :obj:`edge_index`
representation to the desired format.
edge_index (EdgeIndex or CSC): The edge indices.
edge_attr: (torch.Tensor, optional) The edge features.
max_num_neighbors (int, optional): The maximum number of neighbors
of a destination node. When enabled, it allows models to use
Expand All @@ -198,9 +194,12 @@ def forward(
the corresponding input type at the very end.
"""
bipartite = not isinstance(x, torch.Tensor)
graph = self.get_cugraph(
csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors
graph, perm = self.get_cugraph(
edge_index=edge_index,
bipartite=bipartite,
max_num_neighbors=max_num_neighbors,
)

if deterministic_dgrad:
graph.add_reverse_graph()

Expand All @@ -212,6 +211,8 @@ def forward(
)
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
if perm is not None:
edge_attr = edge_attr[perm]
edge_attr = self.lin_edge(edge_attr)

if bipartite:
Expand Down
14 changes: 6 additions & 8 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cugraph.utilities.utils import import_optional
from pylibcugraphops.pytorch.operators import mha_gat_v2_n2n

from .base import BaseConv
from .base import BaseConv, CSC

torch = import_optional("torch")
nn = import_optional("torch.nn")
Expand Down Expand Up @@ -172,7 +172,7 @@ def reset_parameters(self):
def forward(
self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
csc: Tuple[torch.Tensor, torch.Tensor, int],
edge_index: Union[torch_geometric.EdgeIndex, CSC],
edge_attr: Optional[torch.Tensor] = None,
deterministic_dgrad: bool = False,
deterministic_wgrad: bool = False,
Expand All @@ -182,11 +182,7 @@ def forward(
Args:
x (torch.Tensor or tuple): The node features. Can be a tuple of
tensors denoting source and destination node features.
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
representation of a graph, given as a tuple of
:obj:`(row, colptr, num_src_nodes)`. Use the
:meth:`to_csc` method to convert an :obj:`edge_index`
representation to the desired format.
edge_index (EdgeIndex or CSC): The edge indices.
edge_attr: (torch.Tensor, optional) The edge features.
deterministic_dgrad : bool, default=False
Optional flag indicating whether the feature gradients
Expand All @@ -196,7 +192,7 @@ def forward(
are computed deterministically using a dedicated workspace buffer.
"""
bipartite = not isinstance(x, torch.Tensor) or not self.share_weights
graph = self.get_cugraph(csc, bipartite=bipartite)
graph, perm = self.get_cugraph(edge_index, bipartite=bipartite)
if deterministic_dgrad:
graph.add_reverse_graph()

Expand All @@ -208,6 +204,8 @@ def forward(
)
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
if perm is not None:
edge_attr = edge_attr[perm]
edge_attr = self.lin_edge(edge_attr)

if bipartite:
Expand Down
15 changes: 8 additions & 7 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -18,6 +18,7 @@
from pylibcugraphops.pytorch.operators import mha_gat_n2n

from .base import BaseConv
from cugraph_pyg.utils.imports import package_available

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")
Expand Down Expand Up @@ -74,10 +75,10 @@ def __init__(
bias: bool = True,
aggr: str = "sum",
):
major, minor, patch = torch_geometric.__version__.split(".")[:3]
pyg_version = tuple(map(int, [major, minor, patch]))
if pyg_version < (2, 4, 0):
raise RuntimeError(f"{self.__class__.__name__} requires pyg >= 2.4.0.")
if not package_available("torch_geometric>=2.4.0"):
raise RuntimeError(
f"{self.__class__.__name__} requires torch_geometric>=2.4.0."
)

super().__init__()

Expand Down Expand Up @@ -225,7 +226,7 @@ def forward(
)

if src_type == dst_type:
graph = self.get_cugraph(
graph, _ = self.get_cugraph(
csc,
bipartite=False,
)
Expand All @@ -240,7 +241,7 @@ def forward(
)

else:
graph = self.get_cugraph(
graph, _ = self.get_cugraph(
csc,
bipartite=True,
)
Expand Down
Loading
Loading