Skip to content

Commit 36c190a

Browse files
[FEA] Biased Sampling in cuGraph-DGL (#4595)
Adds support for biased sampling to cuGraph-DGL. Resolves rapidsai/cugraph-gnn#25 Merge after #4583, #4586, #4607 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Ralph Liu (https://github.com/nv-rliu) - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Ray Douglass (https://github.com/raydouglass) - Tingyu Wang (https://github.com/tingyu66) - Rick Ratzel (https://github.com/rlratzel) URL: #4595
1 parent 1e5b328 commit 36c190a

File tree

7 files changed

+224
-85
lines changed

7 files changed

+224
-85
lines changed

python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import Sequence, Optional, Union, List, Tuple, Iterator
2020

21-
from cugraph.gnn import UniformNeighborSampler, DistSampleWriter
21+
from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler, DistSampleWriter
2222
from cugraph.utilities.utils import import_optional
2323

2424
import cugraph_dgl
@@ -93,7 +93,6 @@ def __init__(
9393
If provided, the probability of each neighbor being
9494
sampled is proportional to the edge feature
9595
with the given name. Mutually exclusive with mask.
96-
Currently unsupported.
9796
mask: str
9897
Optional.
9998
If proivided, only neighbors where the edge mask
@@ -133,10 +132,6 @@ def __init__(
133132
raise NotImplementedError(
134133
"Edge masking is currently unsupported by cuGraph-DGL"
135134
)
136-
if prob:
137-
raise NotImplementedError(
138-
"Edge masking is currently unsupported by cuGraph-DGL"
139-
)
140135
if prefetch_edge_feats:
141136
warnings.warn("'prefetch_edge_feats' is ignored by cuGraph-DGL")
142137
if prefetch_node_feats:
@@ -146,6 +141,8 @@ def __init__(
146141
if fused:
147142
warnings.warn("'fused' is ignored by cuGraph-DGL")
148143

144+
self.__prob_attr = prob
145+
149146
self.fanouts = fanouts_per_layer
150147
reverse_fanouts = fanouts_per_layer.copy()
151148
reverse_fanouts.reverse()
@@ -180,8 +177,14 @@ def sample(
180177
format=kwargs.pop("format", "parquet"),
181178
)
182179

183-
ds = UniformNeighborSampler(
184-
g._graph(self.edge_dir),
180+
sampling_clx = (
181+
UniformNeighborSampler
182+
if self.__prob_attr is None
183+
else BiasedNeighborSampler
184+
)
185+
186+
ds = sampling_clx(
187+
g._graph(self.edge_dir, prob_attr=self.__prob_attr),
185188
writer,
186189
compression="CSR",
187190
fanout=self._reversed_fanout_vals,

python/cugraph-dgl/cugraph_dgl/graph.py

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def add_edges(
312312
self.__graph = None
313313
self.__vertex_offsets = None
314314

315-
def num_nodes(self, ntype: str = None) -> int:
315+
def num_nodes(self, ntype: Optional[str] = None) -> int:
316316
"""
317317
Returns the number of nodes of ntype, or if ntype is not provided,
318318
the total number of nodes in the graph.
@@ -322,7 +322,7 @@ def num_nodes(self, ntype: str = None) -> int:
322322

323323
return self.__num_nodes_dict[ntype]
324324

325-
def number_of_nodes(self, ntype: str = None) -> int:
325+
def number_of_nodes(self, ntype: Optional[str] = None) -> int:
326326
"""
327327
Alias for num_nodes.
328328
"""
@@ -381,7 +381,7 @@ def _vertex_offsets(self) -> Dict[str, int]:
381381

382382
return dict(self.__vertex_offsets)
383383

384-
def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
384+
def __get_edgelist(self, prob_attr=None) -> Dict[str, "torch.Tensor"]:
385385
"""
386386
This function always returns src/dst labels with respect
387387
to the out direction.
@@ -431,63 +431,71 @@ def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
431431
)
432432
)
433433

434+
num_edges_t = torch.tensor(
435+
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
436+
)
437+
434438
if self.is_multi_gpu:
435439
rank = torch.distributed.get_rank()
436440
world_size = torch.distributed.get_world_size()
437441

438-
num_edges_t = torch.tensor(
439-
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
440-
)
441442
num_edges_all_t = torch.empty(
442443
world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
443444
)
444445
torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)
445446

446-
if rank > 0:
447-
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
448-
edge_id_array = torch.concat(
447+
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
448+
449+
else:
450+
rank = 0
451+
start_offsets = torch.zeros(
452+
(len(sorted_keys),), dtype=torch.int64, device="cuda"
453+
)
454+
num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))
455+
456+
# Use pinned memory here for fast access to CPU/WG storage
457+
edge_id_array_per_type = [
458+
torch.arange(
459+
start_offsets[i],
460+
start_offsets[i] + num_edges_all_t[rank][i],
461+
dtype=torch.int64,
462+
device="cpu",
463+
).pin_memory()
464+
for i in range(len(sorted_keys))
465+
]
466+
467+
# Retrieve the weights from the appropriate feature(s)
468+
# DGL implicitly requires all edge types use the same
469+
# feature name.
470+
if prob_attr is None:
471+
weights = None
472+
else:
473+
if len(sorted_keys) > 1:
474+
weights = torch.concat(
449475
[
450-
torch.arange(
451-
start_offsets[i],
452-
start_offsets[i] + num_edges_all_t[rank][i],
453-
dtype=torch.int64,
454-
device="cuda",
455-
)
456-
for i in range(len(sorted_keys))
476+
self.edata[prob_attr][sorted_keys[i]][ix]
477+
for i, ix in enumerate(edge_id_array_per_type)
457478
]
458479
)
459480
else:
460-
edge_id_array = torch.concat(
461-
[
462-
torch.arange(
463-
self.__edge_indices[et].shape[1],
464-
dtype=torch.int64,
465-
device="cuda",
466-
)
467-
for et in sorted_keys
468-
]
469-
)
481+
weights = self.edata[prob_attr][edge_id_array_per_type[0]]
470482

471-
else:
472-
# single GPU
473-
edge_id_array = torch.concat(
474-
[
475-
torch.arange(
476-
self.__edge_indices[et].shape[1],
477-
dtype=torch.int64,
478-
device="cuda",
479-
)
480-
for et in sorted_keys
481-
]
482-
)
483+
# Safe to move this to cuda because the consumer will always
484+
# move it to cuda if it isn't already there.
485+
edge_id_array = torch.concat(edge_id_array_per_type).cuda()
483486

484-
return {
487+
edgelist_dict = {
485488
"src": edge_index[0],
486489
"dst": edge_index[1],
487490
"etp": edge_type_array,
488491
"eid": edge_id_array,
489492
}
490493

494+
if weights is not None:
495+
edgelist_dict["wgt"] = weights
496+
497+
return edgelist_dict
498+
491499
@property
492500
def is_homogeneous(self):
493501
return len(self.__num_edges_dict) <= 1 and len(self.__num_nodes_dict) <= 1
@@ -508,7 +516,9 @@ def _resource_handle(self):
508516
return self.__handle
509517

510518
def _graph(
511-
self, direction: str
519+
self,
520+
direction: str,
521+
prob_attr: Optional[str] = None,
512522
) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
513523
"""
514524
Gets the pylibcugraph Graph object with edges pointing in the given direction
@@ -522,12 +532,16 @@ def _graph(
522532
is_multigraph=True, is_symmetric=False
523533
)
524534

525-
if self.__graph is not None and self.__graph[1] != direction:
526-
self.__graph = None
535+
if self.__graph is not None:
536+
if (
537+
self.__graph["direction"] != direction
538+
or self.__graph["prob_attr"] != prob_attr
539+
):
540+
self.__graph = None
527541

528542
if self.__graph is None:
529543
src_col, dst_col = ("src", "dst") if direction == "out" else ("dst", "src")
530-
edgelist_dict = self.__get_edgelist()
544+
edgelist_dict = self.__get_edgelist(prob_attr=prob_attr)
531545

532546
if self.is_multi_gpu:
533547
rank = torch.distributed.get_rank()
@@ -536,33 +550,35 @@ def _graph(
536550
vertices_array = cupy.arange(self.num_nodes(), dtype="int64")
537551
vertices_array = cupy.array_split(vertices_array, world_size)[rank]
538552

539-
self.__graph = (
540-
pylibcugraph.MGGraph(
541-
self._resource_handle,
542-
graph_properties,
543-
[cupy.asarray(edgelist_dict[src_col]).astype("int64")],
544-
[cupy.asarray(edgelist_dict[dst_col]).astype("int64")],
545-
vertices_array=[vertices_array],
546-
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
547-
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
548-
),
549-
direction,
553+
graph = pylibcugraph.MGGraph(
554+
self._resource_handle,
555+
graph_properties,
556+
[cupy.asarray(edgelist_dict[src_col]).astype("int64")],
557+
[cupy.asarray(edgelist_dict[dst_col]).astype("int64")],
558+
vertices_array=[vertices_array],
559+
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
560+
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
561+
weight_array=[cupy.asarray(edgelist_dict["wgt"])]
562+
if "wgt" in edgelist_dict
563+
else None,
550564
)
551565
else:
552-
self.__graph = (
553-
pylibcugraph.SGGraph(
554-
self._resource_handle,
555-
graph_properties,
556-
cupy.asarray(edgelist_dict[src_col]).astype("int64"),
557-
cupy.asarray(edgelist_dict[dst_col]).astype("int64"),
558-
vertices_array=cupy.arange(self.num_nodes(), dtype="int64"),
559-
edge_id_array=cupy.asarray(edgelist_dict["eid"]),
560-
edge_type_array=cupy.asarray(edgelist_dict["etp"]),
561-
),
562-
direction,
566+
graph = pylibcugraph.SGGraph(
567+
self._resource_handle,
568+
graph_properties,
569+
cupy.asarray(edgelist_dict[src_col]).astype("int64"),
570+
cupy.asarray(edgelist_dict[dst_col]).astype("int64"),
571+
vertices_array=cupy.arange(self.num_nodes(), dtype="int64"),
572+
edge_id_array=cupy.asarray(edgelist_dict["eid"]),
573+
edge_type_array=cupy.asarray(edgelist_dict["etp"]),
574+
weight_array=cupy.asarray(edgelist_dict["wgt"])
575+
if "wgt" in edgelist_dict
576+
else None,
563577
)
564578

565-
return self.__graph[0]
579+
self.__graph = {"graph": graph, "direction": direction, "prob_attr": prob_attr}
580+
581+
return self.__graph["graph"]
566582

567583
def _has_n_emb(self, ntype: str, emb_name: str) -> bool:
568584
return (ntype, emb_name) in self.__ndata_storage

python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
1415
import cugraph_dgl.dataloading
1516
import pytest
1617

@@ -48,9 +49,12 @@ def test_dataloader_basic_homogeneous():
4849
assert len(out_t) <= 2
4950

5051

51-
def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1):
52+
def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None):
5253
# Single fanout to match cugraph
53-
sampler = dgl.dataloading.NeighborSampler(fanouts)
54+
sampler = dgl.dataloading.NeighborSampler(
55+
fanouts,
56+
prob=prob_attr,
57+
)
5458
dataloader = dgl.dataloading.DataLoader(
5559
g,
5660
train_nid,
@@ -71,8 +75,13 @@ def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1):
7175
return dgl_output
7276

7377

74-
def sample_cugraph_dgl_graphs(cugraph_g, train_nid, fanouts, batch_size=1):
75-
sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts)
78+
def sample_cugraph_dgl_graphs(
79+
cugraph_g, train_nid, fanouts, batch_size=1, prob_attr=None
80+
):
81+
sampler = cugraph_dgl.dataloading.NeighborSampler(
82+
fanouts,
83+
prob=prob_attr,
84+
)
7685

7786
dataloader = cugraph_dgl.dataloading.FutureDataLoader(
7887
cugraph_g,
@@ -126,3 +135,41 @@ def test_same_homogeneousgraph_results(ix, batch_size):
126135
dgl_output[0]["blocks"][0].num_edges()
127136
== cugraph_output[0]["blocks"][0].num_edges()
128137
)
138+
139+
140+
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
141+
@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available")
142+
def test_dataloader_biased_homogeneous():
143+
src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
144+
dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
145+
wgt = torch.tensor([1, 1, 2, 0, 0, 0, 2, 1], dtype=torch.float32)
146+
147+
train_nid = torch.tensor([0, 1])
148+
# Create a heterograph with 3 node types and 3 edges types.
149+
dgl_g = dgl.graph((src, dst))
150+
dgl_g.edata["wgt"] = wgt
151+
152+
cugraph_g = cugraph_dgl.Graph(is_multi_gpu=False)
153+
cugraph_g.add_nodes(9)
154+
cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt})
155+
156+
dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr="wgt")
157+
cugraph_output = sample_cugraph_dgl_graphs(
158+
cugraph_g, train_nid, [4], batch_size=2, prob_attr="wgt"
159+
)
160+
161+
cugraph_output_nodes = cugraph_output[0]["output_nodes"].cpu().numpy()
162+
dgl_output_nodes = dgl_output[0]["output_nodes"].cpu().numpy()
163+
164+
np.testing.assert_array_equal(
165+
np.sort(cugraph_output_nodes), np.sort(dgl_output_nodes)
166+
)
167+
assert (
168+
dgl_output[0]["blocks"][0].num_dst_nodes()
169+
== cugraph_output[0]["blocks"][0].num_dst_nodes()
170+
)
171+
assert (
172+
dgl_output[0]["blocks"][0].num_edges()
173+
== cugraph_output[0]["blocks"][0].num_edges()
174+
)
175+
assert 5 == cugraph_output[0]["blocks"][0].num_edges()

0 commit comments

Comments
 (0)