Skip to content

Commit ac98ce7

Browse files
authored
Update cugraph-pyg models for PyG 2.5 (#4335)
- Support [`torch_geometric.EdgeIndex`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.edge_index.EdgeIndex.html) in models. - Drop support for PyG < 2.5 Breaking changes: The `csc` argument in `model.forward()` has been renamed to `edge_index` to align with upstream models. For now, users can still pass in CSC tuples (generated from `to_csc` call) as `edge_index`. Authors: - Tingyu Wang (https://github.com/tingyu66) - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Alex Barghi (https://github.com/alexbarghi-nv) - Jake Awe (https://github.com/AyodeAwe) URL: #4335
1 parent 37b67c9 commit ac98ce7

23 files changed

+368
-201
lines changed

ci/test_wheel_cugraph-pyg.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ else
3333
fi
3434
rapids-logger "Installing PyTorch and PyG dependencies"
3535
rapids-retry python -m pip install torch==2.1.0 --index-url ${PYTORCH_URL}
36-
rapids-retry python -m pip install torch-geometric==2.4.0
36+
rapids-retry python -m pip install "torch-geometric>=2.5,<2.6"
3737
rapids-retry python -m pip install \
3838
ogb \
3939
pyg_lib \

conda/recipes/cugraph-pyg/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ requirements:
3434
- cupy >=12.0.0
3535
- cugraph ={{ version }}
3636
- pylibcugraphops ={{ minor_version }}
37-
- pyg >=2.3,<2.5
37+
- pyg >=2.5,<2.6
3838

3939
tests:
4040
imports:

dependencies.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ dependencies:
565565
- cugraph==24.6.*
566566
- pytorch>=2.0
567567
- pytorch-cuda==11.8
568-
- pyg>=2.4.0
568+
- pyg>=2.5,<2.6
569569

570570
depends_on_rmm:
571571
common:

python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- cugraph==24.6.*
1313
- pandas
1414
- pre-commit
15-
- pyg>=2.4.0
15+
- pyg>=2.5,<2.6
1616
- pylibcugraphops==24.6.*
1717
- pytest
1818
- pytest-benchmark

python/cugraph-pyg/cugraph_pyg/nn/conv/base.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -15,11 +15,15 @@
1515
from typing import Optional, Tuple, Union
1616

1717
from cugraph.utilities.utils import import_optional
18-
from pylibcugraphops.pytorch import CSC, HeteroCSC
18+
import pylibcugraphops.pytorch
19+
1920

2021
torch = import_optional("torch")
2122
torch_geometric = import_optional("torch_geometric")
2223

24+
# A tuple of (row, colptr, num_src_nodes)
25+
CSC = Tuple[torch.Tensor, torch.Tensor, int]
26+
2327

2428
class BaseConv(torch.nn.Module): # pragma: no cover
2529
r"""An abstract base class for implementing cugraph-ops message passing layers."""
@@ -33,10 +37,7 @@ def to_csc(
3337
edge_index: torch.Tensor,
3438
size: Optional[Tuple[int, int]] = None,
3539
edge_attr: Optional[torch.Tensor] = None,
36-
) -> Union[
37-
Tuple[torch.Tensor, torch.Tensor, int],
38-
Tuple[Tuple[torch.Tensor, torch.Tensor, int], torch.Tensor],
39-
]:
40+
) -> Union[CSC, Tuple[CSC, torch.Tensor],]:
4041
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
4142
used as input to cugraph-ops conv layers.
4243
@@ -71,27 +72,31 @@ def to_csc(
7172

7273
def get_cugraph(
7374
self,
74-
csc: Tuple[torch.Tensor, torch.Tensor, int],
75+
edge_index: Union[torch_geometric.EdgeIndex, CSC],
7576
bipartite: bool = False,
7677
max_num_neighbors: Optional[int] = None,
77-
) -> CSC:
78+
) -> Tuple[pylibcugraphops.pytorch.CSC, Optional[torch.Tensor]]:
7879
r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation.
7980
Supports both bipartite and non-bipartite graphs.
8081
8182
Args:
82-
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
83-
representation of a graph, given as a tuple of
84-
:obj:`(row, colptr, num_src_nodes)`. Use the
85-
:meth:`to_csc` method to convert an :obj:`edge_index`
86-
representation to the desired format.
83+
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
84+
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
85+
CSC representation.
8786
bipartite (bool): If set to :obj:`True`, will create the bipartite
8887
structure in cugraph-ops. (default: :obj:`False`)
8988
max_num_neighbors (int, optional): The maximum number of neighbors
9089
of a destination node. When enabled, it allows models to use
9190
the message-flow-graph primitives in cugraph-ops.
9291
(default: :obj:`None`)
9392
"""
94-
row, colptr, num_src_nodes = csc
93+
perm = None
94+
if isinstance(edge_index, torch_geometric.EdgeIndex):
95+
edge_index, perm = edge_index.sort_by("col")
96+
num_src_nodes = edge_index.get_sparse_size(0)
97+
(colptr, row), _ = edge_index.get_csc()
98+
else:
99+
row, colptr, num_src_nodes = edge_index
95100

96101
if not row.is_cuda:
97102
raise RuntimeError(
@@ -102,32 +107,33 @@ def get_cugraph(
102107
if max_num_neighbors is None:
103108
max_num_neighbors = -1
104109

105-
return CSC(
106-
offsets=colptr,
107-
indices=row,
108-
num_src_nodes=num_src_nodes,
109-
dst_max_in_degree=max_num_neighbors,
110-
is_bipartite=bipartite,
110+
return (
111+
pylibcugraphops.pytorch.CSC(
112+
offsets=colptr,
113+
indices=row,
114+
num_src_nodes=num_src_nodes,
115+
dst_max_in_degree=max_num_neighbors,
116+
is_bipartite=bipartite,
117+
),
118+
perm,
111119
)
112120

113121
def get_typed_cugraph(
114122
self,
115-
csc: Tuple[torch.Tensor, torch.Tensor, int],
123+
edge_index: Union[torch_geometric.EdgeIndex, CSC],
116124
edge_type: torch.Tensor,
117125
num_edge_types: Optional[int] = None,
118126
bipartite: bool = False,
119127
max_num_neighbors: Optional[int] = None,
120-
) -> HeteroCSC:
128+
) -> Tuple[pylibcugraphops.pytorch.HeteroCSC, Optional[torch.Tensor]]:
121129
r"""Constructs a typed :obj:`cugraph` graph object from a CSC
122130
representation where each edge corresponds to a given edge type.
123131
Supports both bipartite and non-bipartite graphs.
124132
125133
Args:
126-
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
127-
representation of a graph, given as a tuple of
128-
:obj:`(row, colptr, num_src_nodes)`. Use the
129-
:meth:`to_csc` method to convert an :obj:`edge_index`
130-
representation to the desired format.
134+
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
135+
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
136+
CSC representation.
131137
edge_type (torch.Tensor): The edge type.
132138
num_edge_types (int, optional): The maximum number of edge types.
133139
When not given, will be computed on-the-fly, leading to
@@ -145,32 +151,40 @@ def get_typed_cugraph(
145151
if max_num_neighbors is None:
146152
max_num_neighbors = -1
147153

148-
row, colptr, num_src_nodes = csc
154+
perm = None
155+
if isinstance(edge_index, torch_geometric.EdgeIndex):
156+
edge_index, perm = edge_index.sort_by("col")
157+
edge_type = edge_type[perm]
158+
num_src_nodes = edge_index.get_sparse_size(0)
159+
(colptr, row), _ = edge_index.get_csc()
160+
else:
161+
row, colptr, num_src_nodes = edge_index
149162
edge_type = edge_type.int()
150163

151-
return HeteroCSC(
152-
offsets=colptr,
153-
indices=row,
154-
edge_types=edge_type,
155-
num_src_nodes=num_src_nodes,
156-
num_edge_types=num_edge_types,
157-
dst_max_in_degree=max_num_neighbors,
158-
is_bipartite=bipartite,
164+
return (
165+
pylibcugraphops.pytorch.HeteroCSC(
166+
offsets=colptr,
167+
indices=row,
168+
edge_types=edge_type,
169+
num_src_nodes=num_src_nodes,
170+
num_edge_types=num_edge_types,
171+
dst_max_in_degree=max_num_neighbors,
172+
is_bipartite=bipartite,
173+
),
174+
perm,
159175
)
160176

161177
def forward(
162178
self,
163179
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
164-
csc: Tuple[torch.Tensor, torch.Tensor, int],
180+
edge_index: Union[torch_geometric.EdgeIndex, CSC],
165181
) -> torch.Tensor:
166182
r"""Runs the forward pass of the module.
167183
168184
Args:
169185
x (torch.Tensor): The node features.
170-
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
171-
representation of a graph, given as a tuple of
172-
:obj:`(row, colptr, num_src_nodes)`. Use the
173-
:meth:`to_csc` method to convert an :obj:`edge_index`
174-
representation to the desired format.
186+
edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
187+
indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
188+
CSC representation.
175189
"""
176190
raise NotImplementedError

python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from cugraph.utilities.utils import import_optional
1717
from pylibcugraphops.pytorch.operators import mha_gat_n2n
1818

19-
from .base import BaseConv
19+
from .base import BaseConv, CSC
2020

2121
torch = import_optional("torch")
2222
nn = import_optional("torch.nn")
@@ -159,7 +159,7 @@ def reset_parameters(self):
159159
def forward(
160160
self,
161161
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
162-
csc: Tuple[torch.Tensor, torch.Tensor, int],
162+
edge_index: Union[torch_geometric.EdgeIndex, CSC],
163163
edge_attr: Optional[torch.Tensor] = None,
164164
max_num_neighbors: Optional[int] = None,
165165
deterministic_dgrad: bool = False,
@@ -172,11 +172,7 @@ def forward(
172172
Args:
173173
x (torch.Tensor or tuple): The node features. Can be a tuple of
174174
tensors denoting source and destination node features.
175-
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
176-
representation of a graph, given as a tuple of
177-
:obj:`(row, colptr, num_src_nodes)`. Use the
178-
:meth:`to_csc` method to convert an :obj:`edge_index`
179-
representation to the desired format.
175+
edge_index (EdgeIndex or CSC): The edge indices.
180176
edge_attr: (torch.Tensor, optional) The edge features.
181177
max_num_neighbors (int, optional): The maximum number of neighbors
182178
of a destination node. When enabled, it allows models to use
@@ -198,9 +194,12 @@ def forward(
198194
the corresponding input type at the very end.
199195
"""
200196
bipartite = not isinstance(x, torch.Tensor)
201-
graph = self.get_cugraph(
202-
csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors
197+
graph, perm = self.get_cugraph(
198+
edge_index=edge_index,
199+
bipartite=bipartite,
200+
max_num_neighbors=max_num_neighbors,
203201
)
202+
204203
if deterministic_dgrad:
205204
graph.add_reverse_graph()
206205

@@ -212,6 +211,8 @@ def forward(
212211
)
213212
if edge_attr.dim() == 1:
214213
edge_attr = edge_attr.view(-1, 1)
214+
if perm is not None:
215+
edge_attr = edge_attr[perm]
215216
edge_attr = self.lin_edge(edge_attr)
216217

217218
if bipartite:

python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from cugraph.utilities.utils import import_optional
1717
from pylibcugraphops.pytorch.operators import mha_gat_v2_n2n
1818

19-
from .base import BaseConv
19+
from .base import BaseConv, CSC
2020

2121
torch = import_optional("torch")
2222
nn = import_optional("torch.nn")
@@ -172,7 +172,7 @@ def reset_parameters(self):
172172
def forward(
173173
self,
174174
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
175-
csc: Tuple[torch.Tensor, torch.Tensor, int],
175+
edge_index: Union[torch_geometric.EdgeIndex, CSC],
176176
edge_attr: Optional[torch.Tensor] = None,
177177
deterministic_dgrad: bool = False,
178178
deterministic_wgrad: bool = False,
@@ -182,11 +182,7 @@ def forward(
182182
Args:
183183
x (torch.Tensor or tuple): The node features. Can be a tuple of
184184
tensors denoting source and destination node features.
185-
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
186-
representation of a graph, given as a tuple of
187-
:obj:`(row, colptr, num_src_nodes)`. Use the
188-
:meth:`to_csc` method to convert an :obj:`edge_index`
189-
representation to the desired format.
185+
edge_index (EdgeIndex or CSC): The edge indices.
190186
edge_attr: (torch.Tensor, optional) The edge features.
191187
deterministic_dgrad : bool, default=False
192188
Optional flag indicating whether the feature gradients
@@ -196,7 +192,7 @@ def forward(
196192
are computed deterministically using a dedicated workspace buffer.
197193
"""
198194
bipartite = not isinstance(x, torch.Tensor) or not self.share_weights
199-
graph = self.get_cugraph(csc, bipartite=bipartite)
195+
graph, perm = self.get_cugraph(edge_index, bipartite=bipartite)
200196
if deterministic_dgrad:
201197
graph.add_reverse_graph()
202198

@@ -208,6 +204,8 @@ def forward(
208204
)
209205
if edge_attr.dim() == 1:
210206
edge_attr = edge_attr.view(-1, 1)
207+
if perm is not None:
208+
edge_attr = edge_attr[perm]
211209
edge_attr = self.lin_edge(edge_attr)
212210

213211
if bipartite:

python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -18,6 +18,7 @@
1818
from pylibcugraphops.pytorch.operators import mha_gat_n2n
1919

2020
from .base import BaseConv
21+
from cugraph_pyg.utils.imports import package_available
2122

2223
torch = import_optional("torch")
2324
torch_geometric = import_optional("torch_geometric")
@@ -74,10 +75,10 @@ def __init__(
7475
bias: bool = True,
7576
aggr: str = "sum",
7677
):
77-
major, minor, patch = torch_geometric.__version__.split(".")[:3]
78-
pyg_version = tuple(map(int, [major, minor, patch]))
79-
if pyg_version < (2, 4, 0):
80-
raise RuntimeError(f"{self.__class__.__name__} requires pyg >= 2.4.0.")
78+
if not package_available("torch_geometric>=2.4.0"):
79+
raise RuntimeError(
80+
f"{self.__class__.__name__} requires torch_geometric>=2.4.0."
81+
)
8182

8283
super().__init__()
8384

@@ -225,7 +226,7 @@ def forward(
225226
)
226227

227228
if src_type == dst_type:
228-
graph = self.get_cugraph(
229+
graph, _ = self.get_cugraph(
229230
csc,
230231
bipartite=False,
231232
)
@@ -240,7 +241,7 @@ def forward(
240241
)
241242

242243
else:
243-
graph = self.get_cugraph(
244+
graph, _ = self.get_cugraph(
244245
csc,
245246
bipartite=True,
246247
)

0 commit comments

Comments
 (0)