1
- # Copyright (c) 2023, NVIDIA CORPORATION.
1
+ # Copyright (c) 2023-2024 , NVIDIA CORPORATION.
2
2
# Licensed under the Apache License, Version 2.0 (the "License");
3
3
# you may not use this file except in compliance with the License.
4
4
# You may obtain a copy of the License at
15
15
from typing import Optional , Tuple , Union
16
16
17
17
from cugraph .utilities .utils import import_optional
18
- from pylibcugraphops .pytorch import CSC , HeteroCSC
18
+ import pylibcugraphops .pytorch
19
+
19
20
20
21
torch = import_optional ("torch" )
21
22
torch_geometric = import_optional ("torch_geometric" )
22
23
24
+ # A tuple of (row, colptr, num_src_nodes)
25
+ CSC = Tuple [torch .Tensor , torch .Tensor , int ]
26
+
23
27
24
28
class BaseConv (torch .nn .Module ): # pragma: no cover
25
29
r"""An abstract base class for implementing cugraph-ops message passing layers."""
@@ -33,10 +37,7 @@ def to_csc(
33
37
edge_index : torch .Tensor ,
34
38
size : Optional [Tuple [int , int ]] = None ,
35
39
edge_attr : Optional [torch .Tensor ] = None ,
36
- ) -> Union [
37
- Tuple [torch .Tensor , torch .Tensor , int ],
38
- Tuple [Tuple [torch .Tensor , torch .Tensor , int ], torch .Tensor ],
39
- ]:
40
+ ) -> Union [CSC , Tuple [CSC , torch .Tensor ],]:
40
41
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
41
42
used as input to cugraph-ops conv layers.
42
43
@@ -71,27 +72,31 @@ def to_csc(
71
72
72
73
def get_cugraph (
73
74
self ,
74
- csc : Tuple [ torch . Tensor , torch . Tensor , int ],
75
+ edge_index : Union [ torch_geometric . EdgeIndex , CSC ],
75
76
bipartite : bool = False ,
76
77
max_num_neighbors : Optional [int ] = None ,
77
- ) -> CSC :
78
+ ) -> Tuple [ pylibcugraphops . pytorch . CSC , Optional [ torch . Tensor ]] :
78
79
r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation.
79
80
Supports both bipartite and non-bipartite graphs.
80
81
81
82
Args:
82
- csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
83
- representation of a graph, given as a tuple of
84
- :obj:`(row, colptr, num_src_nodes)`. Use the
85
- :meth:`to_csc` method to convert an :obj:`edge_index`
86
- representation to the desired format.
83
+ edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
84
+ indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
85
+ CSC representation.
87
86
bipartite (bool): If set to :obj:`True`, will create the bipartite
88
87
structure in cugraph-ops. (default: :obj:`False`)
89
88
max_num_neighbors (int, optional): The maximum number of neighbors
90
89
of a destination node. When enabled, it allows models to use
91
90
the message-flow-graph primitives in cugraph-ops.
92
91
(default: :obj:`None`)
93
92
"""
94
- row , colptr , num_src_nodes = csc
93
+ perm = None
94
+ if isinstance (edge_index , torch_geometric .EdgeIndex ):
95
+ edge_index , perm = edge_index .sort_by ("col" )
96
+ num_src_nodes = edge_index .get_sparse_size (0 )
97
+ (colptr , row ), _ = edge_index .get_csc ()
98
+ else :
99
+ row , colptr , num_src_nodes = edge_index
95
100
96
101
if not row .is_cuda :
97
102
raise RuntimeError (
@@ -102,32 +107,33 @@ def get_cugraph(
102
107
if max_num_neighbors is None :
103
108
max_num_neighbors = - 1
104
109
105
- return CSC (
106
- offsets = colptr ,
107
- indices = row ,
108
- num_src_nodes = num_src_nodes ,
109
- dst_max_in_degree = max_num_neighbors ,
110
- is_bipartite = bipartite ,
110
+ return (
111
+ pylibcugraphops .pytorch .CSC (
112
+ offsets = colptr ,
113
+ indices = row ,
114
+ num_src_nodes = num_src_nodes ,
115
+ dst_max_in_degree = max_num_neighbors ,
116
+ is_bipartite = bipartite ,
117
+ ),
118
+ perm ,
111
119
)
112
120
113
121
def get_typed_cugraph (
114
122
self ,
115
- csc : Tuple [ torch . Tensor , torch . Tensor , int ],
123
+ edge_index : Union [ torch_geometric . EdgeIndex , CSC ],
116
124
edge_type : torch .Tensor ,
117
125
num_edge_types : Optional [int ] = None ,
118
126
bipartite : bool = False ,
119
127
max_num_neighbors : Optional [int ] = None ,
120
- ) -> HeteroCSC :
128
+ ) -> Tuple [ pylibcugraphops . pytorch . HeteroCSC , Optional [ torch . Tensor ]] :
121
129
r"""Constructs a typed :obj:`cugraph` graph object from a CSC
122
130
representation where each edge corresponds to a given edge type.
123
131
Supports both bipartite and non-bipartite graphs.
124
132
125
133
Args:
126
- csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
127
- representation of a graph, given as a tuple of
128
- :obj:`(row, colptr, num_src_nodes)`. Use the
129
- :meth:`to_csc` method to convert an :obj:`edge_index`
130
- representation to the desired format.
134
+ edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
135
+ indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
136
+ CSC representation.
131
137
edge_type (torch.Tensor): The edge type.
132
138
num_edge_types (int, optional): The maximum number of edge types.
133
139
When not given, will be computed on-the-fly, leading to
@@ -145,32 +151,40 @@ def get_typed_cugraph(
145
151
if max_num_neighbors is None :
146
152
max_num_neighbors = - 1
147
153
148
- row , colptr , num_src_nodes = csc
154
+ perm = None
155
+ if isinstance (edge_index , torch_geometric .EdgeIndex ):
156
+ edge_index , perm = edge_index .sort_by ("col" )
157
+ edge_type = edge_type [perm ]
158
+ num_src_nodes = edge_index .get_sparse_size (0 )
159
+ (colptr , row ), _ = edge_index .get_csc ()
160
+ else :
161
+ row , colptr , num_src_nodes = edge_index
149
162
edge_type = edge_type .int ()
150
163
151
- return HeteroCSC (
152
- offsets = colptr ,
153
- indices = row ,
154
- edge_types = edge_type ,
155
- num_src_nodes = num_src_nodes ,
156
- num_edge_types = num_edge_types ,
157
- dst_max_in_degree = max_num_neighbors ,
158
- is_bipartite = bipartite ,
164
+ return (
165
+ pylibcugraphops .pytorch .HeteroCSC (
166
+ offsets = colptr ,
167
+ indices = row ,
168
+ edge_types = edge_type ,
169
+ num_src_nodes = num_src_nodes ,
170
+ num_edge_types = num_edge_types ,
171
+ dst_max_in_degree = max_num_neighbors ,
172
+ is_bipartite = bipartite ,
173
+ ),
174
+ perm ,
159
175
)
160
176
161
177
def forward (
162
178
self ,
163
179
x : Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]],
164
- csc : Tuple [ torch . Tensor , torch . Tensor , int ],
180
+ edge_index : Union [ torch_geometric . EdgeIndex , CSC ],
165
181
) -> torch .Tensor :
166
182
r"""Runs the forward pass of the module.
167
183
168
184
Args:
169
185
x (torch.Tensor): The node features.
170
- csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
171
- representation of a graph, given as a tuple of
172
- :obj:`(row, colptr, num_src_nodes)`. Use the
173
- :meth:`to_csc` method to convert an :obj:`edge_index`
174
- representation to the desired format.
186
+ edge_index (EdgeIndex, (torch.Tensor, torch.Tensor, int)): The edge
187
+ indices, or a tuple of :obj:`(row, colptr, num_src_nodes)` for
188
+ CSC representation.
175
189
"""
176
190
raise NotImplementedError
0 commit comments