Skip to content

Commit 797a036

Browse files
Distributed Sampling in cuGraph-PyG (#4384)
Distributed sampling in cuGraph-PyG. Also renames the existing API to clarify that it is dask based. Adds a dependency on `tensordict` for `cuGraph-PyG` which supports the new `TensorDictFeatureStore`. Also no longer installs `torch-cluster` and `torch-spline-conv` in CI for testing since that results in an `ImportError` and neither of those packages are needed. Requires PyG 2.5. Should be merged after #4335 Merge after #4355 Closes #4248 Closes #4249 Closes #3383 Closes #3942 Closes #3836 Closes #4202 Closes #4051 Closes #4326 Closes #4252 Partially addresses #3805 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Seunghwa Kang (https://github.com/seunghwak) - Tingyu Wang (https://github.com/tingyu66) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Tingyu Wang (https://github.com/tingyu66) - Brad Rees (https://github.com/BradReesWork) - Jake Awe (https://github.com/AyodeAwe) URL: #4384
1 parent 563c06e commit 797a036

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2465
-229
lines changed

ci/run_cugraph_pyg_pytests.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ set -euo pipefail
66
# Support invoking run_cugraph_pyg_pytests.sh outside the script directory
77
cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cugraph-pyg/cugraph_pyg
88

9-
pytest --cache-clear --ignore=tests/mg "$@" .
9+
pytest --cache-clear --benchmark-disable "$@" .
10+
11+
# Used to skip certain examples in CI due to memory limitations
12+
export CI_RUN=1
1013

1114
# Test examples
1215
for e in "$(pwd)"/examples/*.py; do

ci/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ if hasArg "--run-python-tests"; then
103103
conda list
104104
cd ${CUGRAPH_ROOT}/python/cugraph-pyg/cugraph_pyg
105105
# rmat is not tested because of MG testing
106-
pytest --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --ignore=tests/mg --ignore=tests/int --ignore=tests/generators --benchmark-disable
106+
pytest -sv -m sg --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --benchmark-disable
107107
echo "Ran Python pytest for cugraph_pyg : return code was: $?, test script exit code is now: $EXITCODE"
108108

109109
echo "Python pytest for cugraph-service (single-GPU only)..."

ci/test_python.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,14 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
215215

216216
# Install pyg dependencies (which requires pip)
217217

218-
pip install ogb
218+
pip install \
219+
ogb \
220+
tensordict
221+
219222
pip install \
220223
pyg_lib \
221224
torch_scatter \
222225
torch_sparse \
223-
torch_cluster \
224-
torch_spline_conv \
225226
-f ${PYG_URL}
226227

227228
rapids-print-env

ci/test_wheel_cugraph-pyg.sh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ python -m pip install $(ls ./dist/${python_package_name}*.whl)[test]
2424
# RAPIDS_DATASET_ROOT_DIR is used by test scripts
2525
export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)"
2626

27+
# Used to skip certain examples in CI due to memory limitations
28+
export CI_RUN=1
29+
2730
if [[ "${CUDA_VERSION}" == "11.8.0" ]]; then
2831
PYTORCH_URL="https://download.pytorch.org/whl/cu118"
2932
PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu118.html"
@@ -39,15 +42,14 @@ rapids-retry python -m pip install \
3942
pyg_lib \
4043
torch_scatter \
4144
torch_sparse \
42-
torch_cluster \
43-
torch_spline_conv \
45+
tensordict \
4446
-f ${PYG_URL}
4547

4648
rapids-logger "pytest cugraph-pyg (single GPU)"
4749
pushd python/cugraph-pyg/cugraph_pyg
4850
python -m pytest \
4951
--cache-clear \
50-
--ignore=tests/mg \
52+
--benchmark-disable \
5153
tests
5254
# Test examples
5355
for e in "$(pwd)"/examples/*.py; do

conda/recipes/cugraph-pyg/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ requirements:
3434
- cupy >=12.0.0
3535
- cugraph ={{ version }}
3636
- pylibcugraphops ={{ minor_version }}
37+
- tensordict >=0.1.2
3738
- pyg >=2.5,<2.6
3839

3940
tests:

dependencies.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ dependencies:
565565
- cugraph==24.6.*
566566
- pytorch>=2.0
567567
- pytorch-cuda==11.8
568+
- tensordict>=0.1.2
568569
- pyg>=2.5,<2.6
569570

570571
depends_on_rmm:

docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,37 @@ cugraph-pyg
66

77
.. currentmodule:: cugraph_pyg
88

9+
Graph Storage
10+
-------------
911
.. autosummary::
1012
:toctree: ../api/cugraph-pyg/
1113

12-
.. cugraph_pyg.data.cugraph_store.EXPERIMENTAL__CuGraphStore
13-
.. cugraph_pyg.sampler.cugraph_sampler.EXPERIMENTAL__CuGraphSampler
14+
cugraph_pyg.data.dask_graph_store.DaskGraphStore
15+
cugraph_pyg.data.graph_store.GraphStore
16+
17+
Feature Storage
18+
---------------
19+
.. autosummary::
20+
:toctree: ../api/cugraph-pyg/
21+
22+
cugraph_pyg.data.feature_store.TensorDictFeatureStore
23+
24+
Data Loaders
25+
------------
26+
.. autosummary::
27+
:toctree: ../api/cugraph-pyg/
28+
29+
cugraph_pyg.loader.dask_node_loader.DaskNeighborLoader
30+
cugraph_pyg.loader.dask_node_loader.BulkSampleLoader
31+
cugraph_pyg.loader.node_loader.NodeLoader
32+
cugraph_pyg.loader.neighbor_loader.NeighborLoader
33+
34+
Samplers
35+
--------
36+
.. autosummary::
37+
:toctree: ../api/cugraph-pyg/
38+
39+
cugraph_pyg.sampler.sampler.BaseSampler
40+
cugraph_pyg.sampler.sampler.SampleReader
41+
cugraph_pyg.sampler.sampler.HomogeneousSampleReader
42+
cugraph_pyg.sampler.sampler.SampleIterator

python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ dependencies:
2121
- pytorch-cuda==11.8
2222
- pytorch>=2.0
2323
- scipy
24+
- tensordict>=0.1.2
2425
name: cugraph_pyg_dev_cuda-118

python/cugraph-pyg/cugraph_pyg/data/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -11,4 +11,13 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from cugraph_pyg.data.cugraph_store import CuGraphStore
14+
import warnings
15+
16+
from cugraph_pyg.data.dask_graph_store import DaskGraphStore
17+
from cugraph_pyg.data.graph_store import GraphStore
18+
from cugraph_pyg.data.feature_store import TensorDictFeatureStore
19+
20+
21+
def CuGraphStore(*args, **kwargs):
22+
warnings.warn("CuGraphStore has been renamed to DaskGraphStore", FutureWarning)
23+
return DaskGraphStore(*args, **kwargs)

python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py renamed to python/cugraph-pyg/cugraph_pyg/data/dask_graph_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def cast(cls, *args, **kwargs):
199199
return cls(*args, **kwargs)
200200

201201

202-
class CuGraphStore:
202+
class DaskGraphStore:
203203
"""
204204
Duck-typed version of PyG's GraphStore and FeatureStore.
205205
"""
@@ -221,7 +221,7 @@ def __init__(
221221
order: str = "CSR",
222222
):
223223
"""
224-
Constructs a new CuGraphStore from the provided
224+
Constructs a new DaskGraphStore from the provided
225225
arguments.
226226
227227
Parameters
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import warnings
15+
16+
from typing import Optional, Tuple, List
17+
18+
from cugraph.utilities.utils import import_optional, MissingModule
19+
20+
torch = import_optional("torch")
21+
torch_geometric = import_optional("torch_geometric")
22+
tensordict = import_optional("tensordict")
23+
24+
25+
class TensorDictFeatureStore(
26+
object
27+
if isinstance(torch_geometric, MissingModule)
28+
else torch_geometric.data.FeatureStore
29+
):
30+
"""
31+
A basic implementation of the PyG FeatureStore interface that stores
32+
feature data in a single TensorDict. This type of feature store is
33+
not distributed, so each node will have to load the entire graph's
34+
features into memory.
35+
"""
36+
37+
def __init__(self):
38+
super().__init__()
39+
40+
self.__features = {}
41+
42+
def _put_tensor(
43+
self,
44+
tensor: "torch_geometric.typing.FeatureTensorType",
45+
attr: "torch_geometric.data.feature_store.TensorAttr",
46+
) -> bool:
47+
if attr.group_name in self.__features:
48+
td = self.__features[attr.group_name]
49+
batch_size = td.batch_size[0]
50+
51+
if attr.is_set("index"):
52+
if attr.attr_name in td.keys():
53+
if attr.index.shape[0] != batch_size:
54+
raise ValueError(
55+
"Leading size of index tensor "
56+
"does not match existing tensors for group name "
57+
f"{attr.group_name}; Expected {batch_size}, "
58+
f"got {attr.index.shape[0]}"
59+
)
60+
td[attr.attr_name][attr.index] = tensor
61+
return True
62+
else:
63+
warnings.warn(
64+
"Ignoring index parameter "
65+
f"(attribute does not exist for group {attr.group_name})"
66+
)
67+
68+
if tensor.shape[0] != batch_size:
69+
raise ValueError(
70+
"Leading size of input tensor does not match "
71+
f"existing tensors for group name {attr.group_name};"
72+
f" Expected {batch_size}, got {tensor.shape[0]}"
73+
)
74+
else:
75+
batch_size = tensor.shape[0]
76+
self.__features[attr.group_name] = tensordict.TensorDict(
77+
{}, batch_size=batch_size
78+
)
79+
80+
self.__features[attr.group_name][attr.attr_name] = tensor
81+
return True
82+
83+
def _get_tensor(
84+
self, attr: "torch_geometric.data.feature_store.TensorAttr"
85+
) -> Optional["torch_geometric.typing.FeatureTensorType"]:
86+
if attr.group_name not in self.__features:
87+
return None
88+
89+
if attr.attr_name not in self.__features[attr.group_name].keys():
90+
return None
91+
92+
tensor = self.__features[attr.group_name][attr.attr_name]
93+
return (
94+
tensor
95+
if (attr.index is None or (not attr.is_set("index")))
96+
else tensor[attr.index]
97+
)
98+
99+
def _remove_tensor(
100+
self, attr: "torch_geometric.data.feature_store.TensorAttr"
101+
) -> bool:
102+
if attr.group_name not in self.__features:
103+
return False
104+
105+
if attr.attr_name not in self.__features[attr.group_name].keys():
106+
return False
107+
108+
del self.__features[attr.group_name][attr.attr_name]
109+
return True
110+
111+
def _get_tensor_size(
112+
self, attr: "torch_geometric.data.feature_store.TensorAttr"
113+
) -> Tuple:
114+
return self._get_tensor(attr).size()
115+
116+
def get_all_tensor_attrs(
117+
self,
118+
) -> List["torch_geometric.data.feature_store.TensorAttr"]:
119+
attrs = []
120+
for group_name, td in self.__features.items():
121+
for attr_name in td.keys():
122+
attrs.append(
123+
torch_geometric.data.feature_store.TensorAttr(
124+
group_name,
125+
attr_name,
126+
)
127+
)
128+
129+
return attrs

0 commit comments

Comments
 (0)