Skip to content

Commit 5c7cb2b

Browse files
[FEA] cuGraph GNN NCCL-only Setup and Distributed Sampling (#4278)
* Adds the ability to run `pylibcugraph` without UCX/dask within PyTorch DDP. * Adds the new distributed sampler which uses the new nccl+ddp path to perform bulk sampling. Closes #4200 Closes #4201 Closes #4246 Closes #3851 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Rick Ratzel (https://github.com/rlratzel) - Chuck Hastings (https://github.com/ChuckHastings) - Jake Awe (https://github.com/AyodeAwe) - Joseph Nke (https://github.com/jnke2016) URL: #4278
1 parent 80d0ecb commit 5c7cb2b

File tree

14 files changed

+1398
-5
lines changed

14 files changed

+1398
-5
lines changed

ci/run_cugraph_pyg_pytests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ pytest --cache-clear --ignore=tests/mg "$@" .
1111
# Test examples
1212
for e in "$(pwd)"/examples/*.py; do
1313
rapids-logger "running example $e"
14-
python $e
14+
(yes || true) | python $e
1515
done

ci/test_wheel_cugraph-pyg.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ python -m pytest \
5252
# Test examples
5353
for e in "$(pwd)"/examples/*.py; do
5454
rapids-logger "running example $e"
55-
python $e
55+
(yes || true) | python $e
5656
done
5757
popd
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
15+
# and PyTorch DDP to run a multi-GPU sampling workflow. Most users of the
16+
# GNN packages will not interact with cuGraph directly. This example
17+
# is intended for users who want to extend cuGraph within a DDP workflow.
18+
19+
import os
20+
import re
21+
import tempfile
22+
23+
import numpy as np
24+
import torch
25+
import torch.multiprocessing as tmp
26+
import torch.distributed as dist
27+
28+
import cudf
29+
30+
from cugraph.gnn import (
31+
cugraph_comms_init,
32+
cugraph_comms_shutdown,
33+
cugraph_comms_create_unique_id,
34+
cugraph_comms_get_raft_handle,
35+
DistSampleWriter,
36+
UniformNeighborSampler,
37+
)
38+
39+
from pylibcugraph import MGGraph, ResourceHandle, GraphProperties
40+
41+
from ogb.nodeproppred import NodePropPredDataset
42+
43+
44+
def init_pytorch(rank, world_size):
    """Join this process to the torch.distributed process group over NCCL.

    The rendezvous address is hard-coded to ``localhost:12355`` and is
    written into the process environment before the group is initialized.

    Parameters
    ----------
    rank
        Rank of the calling process within the group.
    world_size
        Total number of processes in the group.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
    os.environ.update(rendezvous)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
48+
49+
50+
def sample(rank: int, world_size: int, uid, edgelist, directory):
    """Per-rank worker: build a multi-GPU graph and sample from it.

    Initializes torch.distributed and the cuGraph comms, constructs an
    ``MGGraph`` from this rank's share of the edge list, runs uniform
    neighbor sampling for 50 seed vertices, and shuts the comms down.

    Parameters
    ----------
    rank
        Rank of this worker; also used as its device index.
    world_size
        Total number of workers.
    uid
        Unique id produced by ``cugraph_comms_create_unique_id`` in the
        parent process and shared by every worker.
    edgelist
        Indexable pair where ``edgelist[0]`` holds source vertices and
        ``edgelist[1]`` holds destination vertices.
    directory
        Passed to ``DistSampleWriter`` as its output directory.
    """
    init_pytorch(rank, world_size)

    # One process per GPU: the device index is the rank itself.
    gpu_id = rank
    cugraph_comms_init(rank, world_size, uid, gpu_id)

    print(f"rank {rank} initialized cugraph")

    # Every rank splits the full edge list the same way and keeps only
    # the chunk at its own index.
    src_part = cudf.Series(np.array_split(edgelist[0], world_size)[rank])
    dst_part = cudf.Series(np.array_split(edgelist[1], world_size)[rank])

    # Ranks sample from disjoint, consecutive blocks of seed vertex ids.
    seeds_per_rank = 50
    first_seed = rank * seeds_per_rank
    seeds = cudf.Series(np.arange(first_seed, first_seed + seeds_per_rank))
    handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle())

    print("constructing graph")
    properties = GraphProperties(is_multigraph=True, is_symmetric=False)
    G = MGGraph(handle, properties, [src_part], [dst_part])
    print("graph constructed")

    writer = DistSampleWriter(directory=directory, batches_per_partition=2)
    sampler = UniformNeighborSampler(G, writer, fanout=[5, 5])
    sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)

    # Wait for all ranks to finish sampling before tearing down comms.
    dist.barrier()
    cugraph_comms_shutdown()
    print(f"rank {rank} shut down cugraph")
86+
87+
88+
def main():
    """Run multi-GPU sampling on ogbn-products and print the results.

    Spawns one ``sample`` worker per visible CUDA device. The workers
    write their output into a temporary directory, which is then listed
    and each parquet file is read back with cuDF and printed. The
    directory is removed when the ``with`` block exits.
    """
    world_size = torch.cuda.device_count()
    uid = cugraph_comms_create_unique_id()

    dataset = NodePropPredDataset("ogbn-products")
    el = dataset[0][0]["edge_index"].astype("int64")

    # File names are expected to look like
    # batch=<rank>.<start>-<rank>.<end>.parquet
    pattern = re.compile(r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet")

    with tempfile.TemporaryDirectory() as directory:
        tmp.spawn(
            sample,
            # The workers must write into the same directory that is
            # listed below; passing "." would leave the temporary
            # directory empty and the output files in the CWD.
            args=(world_size, uid, el, directory),
            nprocs=world_size,
        )

        print("Printing samples...")
        for file in os.listdir(directory):
            m = pattern.match(file)
            if m is None:
                # Not a sample partition file; nothing to report.
                continue
            rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4])
            print(f"File: {file} (batches {start} to {end} for rank {rank})")
            print(cudf.read_parquet(os.path.join(directory, file)))
            print("\n")
110+
111+
if __name__ == "__main__":
112+
main()
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
15+
# and PyTorch to run a single-GPU sampling workflow. Most users of the
16+
# GNN packages will not interact with cuGraph directly. This example
17+
# is intended for users who want to extend cuGraph within a PyTorch workflow.
18+
19+
import os
20+
import re
21+
import tempfile
22+
23+
import numpy as np
24+
25+
import cudf
26+
27+
from cugraph.gnn import (
28+
DistSampleWriter,
29+
UniformNeighborSampler,
30+
)
31+
32+
from pylibcugraph import SGGraph, ResourceHandle, GraphProperties
33+
34+
from ogb.nodeproppred import NodePropPredDataset
35+
36+
37+
def sample(edgelist, directory):
    """Build a single-GPU graph and run uniform neighbor sampling on it.

    Parameters
    ----------
    edgelist
        Indexable pair where ``edgelist[0]`` holds source vertices and
        ``edgelist[1]`` holds destination vertices.
    directory
        Passed to ``DistSampleWriter`` as its output directory.
    """
    sources = cudf.Series(edgelist[0])
    destinations = cudf.Series(edgelist[1])

    # Seed vertices are simply the first 50 vertex ids.
    seeds_per_rank = 50
    seeds = cudf.Series(np.arange(0, seeds_per_rank))

    print("constructing graph")
    G = SGGraph(
        ResourceHandle(),
        GraphProperties(is_multigraph=True, is_symmetric=False),
        sources,
        destinations,
    )
    print("graph constructed")

    writer = DistSampleWriter(directory=directory, batches_per_partition=2)
    neighbor_sampler = UniformNeighborSampler(G, writer, fanout=[5, 5])
    neighbor_sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)
61+
62+
63+
def main():
    """Sample from ogbn-products on a single GPU and print the results.

    The samples are written into a temporary directory; every file found
    there is then parsed by name, read back with cuDF and printed.
    """
    dataset = NodePropPredDataset("ogbn-products")
    el = dataset[0][0]["edge_index"].astype("int64")

    # Expected file name layout: batch=<rank>.<start>-<rank>.<end>.parquet
    name_pattern = r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet"

    with tempfile.TemporaryDirectory() as directory:
        sample(el, directory)

        print("Printing samples...")
        for file in os.listdir(directory):
            m = re.match(name_pattern, file)
            rank, start, _, end = (int(m[group]) for group in range(1, 5))
            print(f"File: {file} (batches {start} to {end} for rank {rank})")
            parquet_path = os.path.join(directory, file)
            print(cudf.read_parquet(parquet_path))
            print("\n")
77+
78+
79+
if __name__ == "__main__":
80+
main()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
15+
# and PyTorch DDP to run a multi-GPU workflow. Most users of the
16+
# GNN packages will not interact with cuGraph directly. This example
17+
# is intended for users who want to extend cuGraph within a DDP workflow.
18+
19+
import os
20+
21+
import pandas
22+
import numpy as np
23+
import torch
24+
import torch.multiprocessing as tmp
25+
import torch.distributed as dist
26+
27+
import cudf
28+
29+
from cugraph.gnn import (
30+
cugraph_comms_init,
31+
cugraph_comms_shutdown,
32+
cugraph_comms_create_unique_id,
33+
cugraph_comms_get_raft_handle,
34+
)
35+
36+
from pylibcugraph import MGGraph, ResourceHandle, GraphProperties, degrees
37+
38+
from ogb.nodeproppred import NodePropPredDataset
39+
40+
41+
def init_pytorch(rank, world_size):
    """Initialize the NCCL torch.distributed process group for this rank.

    Sets ``MASTER_ADDR`` and ``MASTER_PORT`` in the environment to the
    fixed endpoint ``localhost:12355`` before initializing the group.

    Parameters
    ----------
    rank
        Rank of the calling process within the group.
    world_size
        Total number of processes in the group.
    """
    for variable, value in (("MASTER_ADDR", "localhost"), ("MASTER_PORT", "12355")):
        os.environ[variable] = value
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
45+
46+
47+
def calc_degree(rank: int, world_size: int, uid, edgelist):
    """Per-rank worker: build a multi-GPU graph and print vertex degrees.

    Initializes torch.distributed and the cuGraph comms, constructs an
    ``MGGraph`` from this rank's share of the edge list, computes in- and
    out-degrees for 50 vertices, prints them as a pandas DataFrame, and
    shuts the comms down.

    Parameters
    ----------
    rank
        Rank of this worker; also used as its device index.
    world_size
        Total number of workers.
    uid
        Unique id produced by ``cugraph_comms_create_unique_id`` in the
        parent process and shared by every worker.
    edgelist
        Indexable pair where ``edgelist[0]`` holds source vertices and
        ``edgelist[1]`` holds destination vertices.
    """
    init_pytorch(rank, world_size)

    # One process per GPU: the device index is the rank itself.
    gpu_id = rank
    cugraph_comms_init(rank, world_size, uid, gpu_id)

    print(f"rank {rank} initialized cugraph")

    # Every rank splits the full edge list the same way and keeps only
    # the chunk at its own index.
    src_part = cudf.Series(np.array_split(edgelist[0], world_size)[rank])
    dst_part = cudf.Series(np.array_split(edgelist[1], world_size)[rank])

    # Each rank queries its own consecutive block of 50 vertex ids.
    first_vertex = rank * 50
    seeds = cudf.Series(np.arange(first_vertex, first_vertex + 50))
    handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle())

    print("constructing graph")
    properties = GraphProperties(is_multigraph=True, is_symmetric=False)
    G = MGGraph(handle, properties, [src_part], [dst_part])
    print("graph constructed")

    print("calculating degrees")
    vertices, in_deg, out_deg = degrees(handle, G, seeds, do_expensive_check=False)
    print("degrees calculated")

    print("constructing dataframe")
    # NOTE(review): .get() presumably copies each result array to host
    # memory so pandas can consume it -- confirm against pylibcugraph.
    columns = {"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()}
    df = pandas.DataFrame(columns)
    print(df)

    # Wait for all ranks to finish before tearing down comms.
    dist.barrier()
    cugraph_comms_shutdown()
    print(f"rank {rank} shut down cugraph")
83+
84+
85+
def main():
    """Load ogbn-products and spawn one ``calc_degree`` worker per GPU.

    The cuGraph comms unique id is created once here and handed to every
    worker together with the world size and the int64 edge index.
    """
    num_gpus = torch.cuda.device_count()
    comms_uid = cugraph_comms_create_unique_id()

    products = NodePropPredDataset("ogbn-products")
    edge_index = products[0][0]["edge_index"].astype("int64")

    # spawn() supplies the rank as the first positional argument.
    worker_args = (num_gpus, comms_uid, edge_index)
    tmp.spawn(calc_degree, args=worker_args, nprocs=num_gpus)
97+
98+
99+
if __name__ == "__main__":
100+
main()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
# This example shows how to use cuGraph and pylibcuGraph to run a
15+
# single-GPU workflow. Most users of the GNN packages will not interact
16+
# with cuGraph directly. This example is intended for users who want
17+
# to extend cuGraph within a PyTorch workflow.
18+
19+
import pandas
20+
import numpy as np
21+
22+
import cudf
23+
24+
from pylibcugraph import SGGraph, ResourceHandle, GraphProperties, degrees
25+
26+
from ogb.nodeproppred import NodePropPredDataset
27+
28+
29+
def calc_degree(edgelist):
    """Build a single-GPU graph and print degrees of vertices 0..255.

    Parameters
    ----------
    edgelist
        Indexable pair where ``edgelist[0]`` holds source vertices and
        ``edgelist[1]`` holds destination vertices.
    """
    sources = cudf.Series(edgelist[0])
    destinations = cudf.Series(edgelist[1])

    seeds = cudf.Series(np.arange(256))

    print("constructing graph")
    properties_handle = ResourceHandle()
    G = SGGraph(
        properties_handle,
        GraphProperties(is_multigraph=True, is_symmetric=False),
        sources,
        destinations,
    )
    print("graph constructed")

    print("calculating degrees")
    # A second, separate resource handle is created for the degrees call.
    degree_handle = ResourceHandle()
    vertices, in_deg, out_deg = degrees(
        degree_handle, G, seeds, do_expensive_check=False
    )
    print("degrees calculated")

    print("constructing dataframe")
    # NOTE(review): .get() presumably copies each result array to host
    # memory so pandas can consume it -- confirm against pylibcugraph.
    columns = {"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()}
    df = pandas.DataFrame(columns)
    print(df)

    print("done")
57+
58+
59+
def main():
    """Load the ogbn-products edge index as int64 and run ``calc_degree``."""
    products = NodePropPredDataset("ogbn-products")
    graph_data = products[0][0]
    calc_degree(graph_data["edge_index"].astype("int64"))
63+
64+
65+
if __name__ == "__main__":
66+
main()

python/cugraph/cugraph/gnn/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -13,3 +13,14 @@
1313

1414
from .feature_storage.feat_storage import FeatureStore
1515
from .data_loading.bulk_sampler import BulkSampler
16+
from .data_loading.dist_sampler import (
17+
DistSampler,
18+
DistSampleWriter,
19+
UniformNeighborSampler,
20+
)
21+
from .comms.cugraph_nccl_comms import (
22+
cugraph_comms_init,
23+
cugraph_comms_shutdown,
24+
cugraph_comms_create_unique_id,
25+
cugraph_comms_get_raft_handle,
26+
)

0 commit comments

Comments
 (0)