Skip to content

Commit ac35be3

Browse files
[BUG] Use the Correct WG Communicator (#4548)
cuGraph-PyG's WholeFeatureStore currently uses the local communicator, when it should be using the global communicator, as was originally intended. This PR modifies the feature store so it correctly calls `get_global_node_communicator()`. This also fixes another bug where torch.int32 was used to store the number of edges in the graph, which resulted in an overflow error when the number of edges exceeded that datatype's maximum value. The datatype is now correctly set to int64. Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: #4548
1 parent 94e60f0 commit ac35be3

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

python/cugraph-pyg/cugraph_pyg/data/feature_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def __init__(self, memory_type="distributed", location="cpu"):
169169

170170
self.__features = {}
171171

172-
self.__wg_comm = wgth.get_local_node_communicator()
172+
self.__wg_comm = wgth.get_global_communicator()
173173
self.__wg_type = memory_type
174174
self.__wg_location = location
175175

python/cugraph-pyg/cugraph_pyg/data/graph_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def __get_edgelist(self):
271271
torch.tensor(
272272
[self.__edge_indices[et].shape[1] for et in sorted_keys],
273273
device="cuda",
274-
dtype=torch.int32,
274+
dtype=torch.int64,
275275
)
276276
)
277277

0 commit comments

Comments
 (0)