@@ -312,7 +312,7 @@ def add_edges(
         self.__graph = None
         self.__vertex_offsets = None
 
-    def num_nodes(self, ntype: str = None) -> int:
+    def num_nodes(self, ntype: Optional[str] = None) -> int:
         """
         Returns the number of nodes of ntype, or if ntype is not provided,
         the total number of nodes in the graph.
@@ -322,7 +322,7 @@ def num_nodes(self, ntype: str = None) -> int:
 
         return self.__num_nodes_dict[ntype]
 
-    def number_of_nodes(self, ntype: str = None) -> int:
+    def number_of_nodes(self, ntype: Optional[str] = None) -> int:
         """
         Alias for num_nodes.
         """
@@ -381,7 +381,7 @@ def _vertex_offsets(self) -> Dict[str, int]:
 
         return dict(self.__vertex_offsets)
 
-    def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
+    def __get_edgelist(self, prob_attr=None) -> Dict[str, "torch.Tensor"]:
         """
         This function always returns src/dst labels with respect
         to the out direction.
@@ -431,63 +431,71 @@ def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
             )
         )
 
+        num_edges_t = torch.tensor(
+            [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
+        )
+
         if self.is_multi_gpu:
             rank = torch.distributed.get_rank()
             world_size = torch.distributed.get_world_size()
 
-            num_edges_t = torch.tensor(
-                [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
-            )
             num_edges_all_t = torch.empty(
                 world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
             )
             torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)
 
-            if rank > 0:
-                start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
-                edge_id_array = torch.concat(
+            start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
+
+        else:
+            rank = 0
+            start_offsets = torch.zeros(
+                (len(sorted_keys),), dtype=torch.int64, device="cuda"
+            )
+            num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))
+
+        # Use pinned memory here for fast access to CPU/WG storage
+        edge_id_array_per_type = [
+            torch.arange(
+                start_offsets[i],
+                start_offsets[i] + num_edges_all_t[rank][i],
+                dtype=torch.int64,
+                device="cpu",
+            ).pin_memory()
+            for i in range(len(sorted_keys))
+        ]
+
+        # Retrieve the weights from the appropriate feature(s)
+        # DGL implicitly requires all edge types use the same
+        # feature name.
+        if prob_attr is None:
+            weights = None
+        else:
+            if len(sorted_keys) > 1:
+                weights = torch.concat(
                     [
-                        torch.arange(
-                            start_offsets[i],
-                            start_offsets[i] + num_edges_all_t[rank][i],
-                            dtype=torch.int64,
-                            device="cuda",
-                        )
-                        for i in range(len(sorted_keys))
+                        self.edata[prob_attr][sorted_keys[i]][ix]
+                        for i, ix in enumerate(edge_id_array_per_type)
                     ]
                 )
             else:
-                edge_id_array = torch.concat(
-                    [
-                        torch.arange(
-                            self.__edge_indices[et].shape[1],
-                            dtype=torch.int64,
-                            device="cuda",
-                        )
-                        for et in sorted_keys
-                    ]
-                )
+                weights = self.edata[prob_attr][edge_id_array_per_type[0]]
 
-        else:
-            # single GPU
-            edge_id_array = torch.concat(
-                [
-                    torch.arange(
-                        self.__edge_indices[et].shape[1],
-                        dtype=torch.int64,
-                        device="cuda",
-                    )
-                    for et in sorted_keys
-                ]
-            )
+        # Safe to move this to cuda because the consumer will always
+        # move it to cuda if it isn't already there.
+        edge_id_array = torch.concat(edge_id_array_per_type).cuda()
 
-        return {
+        edgelist_dict = {
             "src": edge_index[0],
             "dst": edge_index[1],
             "etp": edge_type_array,
             "eid": edge_id_array,
         }
 
+        if weights is not None:
+            edgelist_dict["wgt"] = weights
+
+        return edgelist_dict
+
     @property
     def is_homogeneous(self):
         return len(self.__num_edges_dict) <= 1 and len(self.__num_nodes_dict) <= 1
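In the multi-GPU branch above, every rank all-gathers its per-type edge counts and takes a per-type sum over the lower ranks, so edge IDs stay contiguous within each edge type across ranks. A single-process sketch of that offset arithmetic, with the gathered counts tensor faked in place of `all_gather_into_tensor` (3 ranks and 2 edge types are invented numbers):

import torch

# counts[r][t]: edges of type t held by rank r, i.e. what
# torch.distributed.all_gather_into_tensor would have produced.
num_edges_all_t = torch.tensor([[4, 7], [5, 2], [3, 6]], dtype=torch.int64)

for rank in range(num_edges_all_t.shape[0]):
    # Per-type sum over lower ranks = this rank's first global ID per type.
    start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
    edge_ids = [
        torch.arange(start_offsets[t], start_offsets[t] + num_edges_all_t[rank][t])
        for t in range(num_edges_all_t.shape[1])
    ]
    print(rank, [i.tolist() for i in edge_ids])
# rank 0 gets IDs [0..3] and [0..6]; rank 1 continues at 4 and 7;
# rank 2 at 9 and 9.

Note that the diff now builds these per-type ID ranges on the CPU and pins them (`.pin_memory()`); per the in-code comment this is for fast access to CPU/WG (presumably WholeGraph) storage.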
@@ -508,7 +516,9 @@ def _resource_handle(self):
         return self.__handle
 
     def _graph(
-        self, direction: str
+        self,
+        direction: str,
+        prob_attr: Optional[str] = None,
     ) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
         """
         Gets the pylibcugraph Graph object with edges pointing in the given direction
@@ -522,12 +532,16 @@ def _graph(
             is_multigraph=True, is_symmetric=False
         )
 
-        if self.__graph is not None and self.__graph[1] != direction:
-            self.__graph = None
+        if self.__graph is not None:
+            if (
+                self.__graph["direction"] != direction
+                or self.__graph["prob_attr"] != prob_attr
+            ):
+                self.__graph = None
 
         if self.__graph is None:
             src_col, dst_col = ("src", "dst") if direction == "out" else ("dst", "src")
-            edgelist_dict = self.__get_edgelist()
+            edgelist_dict = self.__get_edgelist(prob_attr=prob_attr)
 
             if self.is_multi_gpu:
                 rank = torch.distributed.get_rank()
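The cache key now covers both the direction and the probability attribute, so asking for a graph with a different `prob_attr` rebuilds it instead of returning a stale, differently-weighted one. Note also the retained `src_col, dst_col` swap: the edgelist is always materialized in the out direction (as `__get_edgelist`'s docstring says), and an "in"-direction graph is obtained by exchanging which column feeds the source and destination arrays. A toy version of the same memoization pattern (names invented, not the real class):

from typing import Any, Callable, Optional

class _GraphCache:
    """Hypothetical illustration of the (direction, prob_attr) cache key."""

    def __init__(self) -> None:
        self._cached: Optional[dict] = None

    def get(
        self,
        direction: str,
        prob_attr: Optional[str],
        build: Callable[[str, Optional[str]], Any],
    ) -> Any:
        # Drop the cached graph if any part of its key changed.
        if self._cached is not None and (
            self._cached["direction"] != direction
            or self._cached["prob_attr"] != prob_attr
        ):
            self._cached = None
        if self._cached is None:
            self._cached = {
                "graph": build(direction, prob_attr),
                "direction": direction,
                "prob_attr": prob_attr,
            }
        return self._cached["graph"]

cache = _GraphCache()
g1 = cache.get("out", None, lambda d, p: object())
g2 = cache.get("out", None, lambda d, p: object())
assert g1 is g2  # cache hit
g3 = cache.get("out", "prob", lambda d, p: object())
assert g3 is not g1  # new prob_attr forces a rebuild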
@@ -536,33 +550,35 @@ def _graph(
                 vertices_array = cupy.arange(self.num_nodes(), dtype="int64")
                 vertices_array = cupy.array_split(vertices_array, world_size)[rank]
 
-                self.__graph = (
-                    pylibcugraph.MGGraph(
-                        self._resource_handle,
-                        graph_properties,
-                        [cupy.asarray(edgelist_dict[src_col]).astype("int64")],
-                        [cupy.asarray(edgelist_dict[dst_col]).astype("int64")],
-                        vertices_array=[vertices_array],
-                        edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
-                        edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
-                    ),
-                    direction,
+                graph = pylibcugraph.MGGraph(
+                    self._resource_handle,
+                    graph_properties,
+                    [cupy.asarray(edgelist_dict[src_col]).astype("int64")],
+                    [cupy.asarray(edgelist_dict[dst_col]).astype("int64")],
+                    vertices_array=[vertices_array],
+                    edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
+                    edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
+                    weight_array=[cupy.asarray(edgelist_dict["wgt"])]
+                    if "wgt" in edgelist_dict
+                    else None,
                 )
             else:
-                self.__graph = (
-                    pylibcugraph.SGGraph(
-                        self._resource_handle,
-                        graph_properties,
-                        cupy.asarray(edgelist_dict[src_col]).astype("int64"),
-                        cupy.asarray(edgelist_dict[dst_col]).astype("int64"),
-                        vertices_array=cupy.arange(self.num_nodes(), dtype="int64"),
-                        edge_id_array=cupy.asarray(edgelist_dict["eid"]),
-                        edge_type_array=cupy.asarray(edgelist_dict["etp"]),
-                    ),
-                    direction,
+                graph = pylibcugraph.SGGraph(
+                    self._resource_handle,
+                    graph_properties,
+                    cupy.asarray(edgelist_dict[src_col]).astype("int64"),
+                    cupy.asarray(edgelist_dict[dst_col]).astype("int64"),
+                    vertices_array=cupy.arange(self.num_nodes(), dtype="int64"),
+                    edge_id_array=cupy.asarray(edgelist_dict["eid"]),
+                    edge_type_array=cupy.asarray(edgelist_dict["etp"]),
+                    weight_array=cupy.asarray(edgelist_dict["wgt"])
+                    if "wgt" in edgelist_dict
+                    else None,
                 )
 
-        return self.__graph[0]
+            self.__graph = {"graph": graph, "direction": direction, "prob_attr": prob_attr}
+
+        return self.__graph["graph"]
 
     def _has_n_emb(self, ntype: str, emb_name: str) -> bool:
         return (ntype, emb_name) in self.__ndata_storage
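Taken together: when a `prob_attr` is supplied, `__get_edgelist` emits a "wgt" column and the SG/MG constructors forward it as `weight_array`, which is what weight-biased sampling consumes. A small self-contained sketch of how the weights line up with the concatenated edge list (the edge types, feature name, and values are invented for illustration, standing in for `self.edata[prob_attr]`):

import torch

# Invented per-type sampling probabilities, keyed by canonical edge type.
prob = {
    ("author", "writes", "paper"): torch.tensor([0.1, 0.9, 0.5]),
    ("paper", "cites", "paper"): torch.tensor([0.3, 0.7]),
}
sorted_keys = sorted(prob.keys())

# Local per-type edge IDs, as in edge_id_array_per_type above
# (single-GPU case: each type's IDs start at 0).
edge_id_array_per_type = [torch.arange(prob[k].numel()) for k in sorted_keys]

# Same per-type concatenation order as the src/dst/eid columns, so
# row i of "wgt" describes row i of the edge list.
weights = torch.concat(
    [prob[sorted_keys[i]][ix] for i, ix in enumerate(edge_id_array_per_type)]
)
print(weights)  # tensor([0.1000, 0.9000, 0.5000, 0.3000, 0.7000])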