Skip to content

Commit 581bd9c

Browse files
authored
Add Additional Checks to get_edgelist and get_dask_edgelist (#4256)
Closes #4241. This PR adds an additional check to the `get_edgelist()` and `get_dask_edgelist()` functions in the Datasets API. When retrieving an edge-list, the type of the internal `self._edgelist` object is now verified to match the requested form: SG (`cudf.DataFrame`) for `get_edgelist()` or MG (`dask_cudf.DataFrame`) for `get_dask_edgelist()`. If the cached object is missing or of the other type, the edge-list is loaded again. In addition, minor improvements have been made to `utils/test_dataset.py` to be more thorough with type checks. Authors: - Ralph Liu (https://github.com/nv-rliu) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: #4256
1 parent 1b88ea1 commit 581bd9c

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

python/cugraph/cugraph/datasets/dataset.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def get_edgelist(self, download=False, reader="cudf"):
174174
reader : 'cudf' or 'pandas' (default='cudf')
175175
The library used to read a CSV and return an edgelist DataFrame.
176176
"""
177-
if self._edgelist is None:
177+
if self._edgelist is None or not isinstance(self._edgelist, cudf.DataFrame):
178178
full_path = self.get_path()
179179
if not full_path.is_file():
180180
if download:
@@ -223,7 +223,9 @@ def get_dask_edgelist(self, download=False):
223223
Automatically download the dataset from the 'url' location within
224224
the YAML file.
225225
"""
226-
if self._edgelist is None:
226+
if self._edgelist is None or not isinstance(
227+
self._edgelist, dask_cudf.DataFrame
228+
):
227229
full_path = self.get_path()
228230
if not full_path.is_file():
229231
if download:
@@ -286,7 +288,7 @@ def get_graph(
286288
for certain algorithms, such as pagerank.
287289
"""
288290
if self._edgelist is None:
289-
self.get_edgelist(download)
291+
self.get_edgelist(download=download)
290292

291293
if create_using is None:
292294
G = Graph()
@@ -351,7 +353,7 @@ def get_dask_graph(
351353
for certain algorithms.
352354
"""
353355
if self._edgelist is None:
354-
self.get_dask_edgelist(download)
356+
self.get_dask_edgelist(download=download)
355357

356358
if create_using is None:
357359
G = Graph()
@@ -367,7 +369,7 @@ def get_dask_graph(
367369
f"{type(create_using)}"
368370
)
369371

370-
if len(self.metadata["col_names"]) > 2 and not (ignore_weights):
372+
if len(self.metadata["col_names"]) > 2 and not ignore_weights:
371373
G.from_dask_cudf_edgelist(
372374
self._edgelist,
373375
source=self.metadata["col_names"][0],

python/cugraph/cugraph/tests/utils/test_dataset.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
import pytest
2121

2222
import cudf
23+
import cugraph
2324
import dask_cudf
25+
from cugraph import datasets
26+
from cugraph.dask.common.mg_utils import is_single_gpu
27+
from cugraph.datasets import karate
2428
from cugraph.structure import Graph
2529
from cugraph.testing import (
2630
RAPIDS_DATASET_ROOT_DIR_PATH,
@@ -29,15 +33,15 @@
2933
SMALL_DATASETS,
3034
BENCHMARKING_DATASETS,
3135
)
32-
from cugraph import datasets
33-
from cugraph.dask.common.mg_utils import is_single_gpu
36+
3437

3538
# Add the sg marker to all tests in this module.
3639
pytestmark = pytest.mark.sg
3740

3841

3942
###############################################################################
4043
# Fixtures
44+
###############################################################################
4145

4246

4347
# module fixture - called once for this module
@@ -201,21 +205,26 @@ def test_reader_dask(dask_client, dataset):
201205
@pytest.mark.parametrize("dataset", ALL_DATASETS)
202206
def test_get_edgelist(dataset):
203207
E = dataset.get_edgelist(download=True)
208+
204209
assert E is not None
210+
assert isinstance(E, cudf.DataFrame)
205211

206212

207213
@pytest.mark.skipif(is_single_gpu(), reason="skipping MG testing on Single GPU system")
208214
@pytest.mark.skip(reason="MG not supported on CI")
209215
@pytest.mark.parametrize("dataset", ALL_DATASETS)
210216
def test_get_dask_edgelist(dask_client, dataset):
211217
E = dataset.get_dask_edgelist(download=True)
218+
212219
assert E is not None
220+
assert isinstance(E, dask_cudf.DataFrame)
213221

214222

215223
@pytest.mark.parametrize("dataset", ALL_DATASETS)
216224
def test_get_graph(dataset):
217225
G = dataset.get_graph(download=True)
218226
assert G is not None
227+
assert isinstance(G, cugraph.Graph)
219228

220229

221230
@pytest.mark.skipif(is_single_gpu(), reason="skipping MG testing on Single GPU system")
@@ -224,12 +233,14 @@ def test_get_graph(dataset):
224233
def test_get_dask_graph(dask_client, dataset):
225234
G = dataset.get_dask_graph(download=True)
226235
assert G is not None
236+
# TODO Check G is a DistributedGraph
227237

228238

229239
@pytest.mark.parametrize("dataset", ALL_DATASETS)
230240
def test_metadata(dataset):
231241
M = dataset.metadata
232242
assert M is not None
243+
assert isinstance(M, dict)
233244

234245

235246
@pytest.mark.parametrize("dataset", ALL_DATASETS)
@@ -346,32 +357,19 @@ def test_ctor_with_datafile():
346357
assert ds.get_path() == karate_csv
347358

348359

349-
def test_unload():
350-
email_csv = RAPIDS_DATASET_ROOT_DIR_PATH / "email-Eu-core.csv"
360+
@pytest.mark.parametrize("dataset", [karate])
361+
def test_unload(dataset):
362+
assert dataset._edgelist is None
351363

352-
ds = datasets.Dataset(
353-
csv_file=email_csv.as_posix(),
354-
csv_col_names=["src", "dst", "wgt"],
355-
csv_col_dtypes=["int32", "int32", "float32"],
356-
)
364+
dataset.get_edgelist()
365+
assert dataset._edgelist is not None
366+
dataset.unload()
367+
assert dataset._edgelist is None
357368

358-
# FIXME: another (better?) test would be to check free memory and assert
359-
# the memory use increases after get_*(), then returns to the pre-get_*()
360-
# level after unload(). However, that type of test may fail for several
361-
# reasons (the device being monitored is accidentally also being used by
362-
# another process, and the use of memory pools to name two). Instead, just
363-
# test that the internal members get cleared on unload().
364-
assert ds._edgelist is None
365-
366-
ds.get_edgelist()
367-
assert ds._edgelist is not None
368-
ds.unload()
369-
assert ds._edgelist is None
370-
371-
ds.get_graph()
372-
assert ds._edgelist is not None
373-
ds.unload()
374-
assert ds._edgelist is None
369+
dataset.get_graph()
370+
assert dataset._edgelist is not None
371+
dataset.unload()
372+
assert dataset._edgelist is None
375373

376374

377375
@pytest.mark.parametrize("dataset", ALL_DATASETS)

0 commit comments

Comments
 (0)