Skip to content

Commit 450f987

Browse files
Erdos-Renyi generator had bad logic in thrust calls (#4362)
The `random_iterator`, a thrust transform iterator that was used for counting and filtering returns the probability, but they `copy_if` call was assuming it was returning the index. Modified the logic and then consolidated the `copy_if` and `transform` into a single call with an output iterator. Closes #4359 Authors: - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Naim (https://github.com/naimnv) - Joseph Nke (https://github.com/jnke2016) URL: #4362
1 parent ca88a47 commit 450f987

File tree

2 files changed

+26
-33
lines changed

2 files changed

+26
-33
lines changed

cpp/src/generators/erdos_renyi_generator.cu

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323
#include <thrust/copy.h>
2424
#include <thrust/count.h>
2525
#include <thrust/iterator/counting_iterator.h>
26-
#include <thrust/iterator/transform_iterator.h>
26+
#include <thrust/iterator/transform_output_iterator.h>
2727
#include <thrust/iterator/zip_iterator.h>
2828
#include <thrust/random.h>
29-
#include <thrust/transform.h>
3029
#include <thrust/tuple.h>
3130

3231
namespace cugraph {
@@ -42,45 +41,38 @@ generate_erdos_renyi_graph_edgelist_gnp(raft::handle_t const& handle,
4241
CUGRAPH_EXPECTS(num_vertices < std::numeric_limits<int32_t>::max(),
4342
"Implementation cannot support specified value");
4443

45-
auto random_iterator = thrust::make_transform_iterator(
46-
thrust::make_counting_iterator<size_t>(0),
47-
cuda::proclaim_return_type<float>([seed] __device__(size_t index) {
48-
thrust::default_random_engine rng(seed);
49-
thrust::uniform_real_distribution<float> dist(0.0, 1.0);
50-
rng.discard(index);
51-
return dist(rng);
52-
}));
44+
size_t max_num_edges = static_cast<size_t>(num_vertices) * num_vertices;
5345

54-
size_t count = thrust::count_if(handle.get_thrust_policy(),
55-
random_iterator,
56-
random_iterator + num_vertices * num_vertices,
57-
[p] __device__(float prob) { return prob < p; });
58-
59-
rmm::device_uvector<size_t> indices_v(count, handle.get_stream());
46+
auto generate_random_value = cuda::proclaim_return_type<float>([seed] __device__(size_t index) {
47+
thrust::default_random_engine rng(seed);
48+
thrust::uniform_real_distribution<float> dist(0.0, 1.0);
49+
rng.discard(index);
50+
return dist(rng);
51+
});
6052

61-
thrust::copy_if(handle.get_thrust_policy(),
62-
random_iterator,
63-
random_iterator + num_vertices * num_vertices,
64-
indices_v.begin(),
65-
[p] __device__(float prob) { return prob < p; });
53+
size_t count = thrust::count_if(handle.get_thrust_policy(),
54+
thrust::make_counting_iterator<size_t>(0),
55+
thrust::make_counting_iterator<size_t>(max_num_edges),
56+
[generate_random_value, p] __device__(size_t index) {
57+
return generate_random_value(index) < p;
58+
});
6659

6760
rmm::device_uvector<vertex_t> src_v(count, handle.get_stream());
6861
rmm::device_uvector<vertex_t> dst_v(count, handle.get_stream());
6962

70-
thrust::transform(handle.get_thrust_policy(),
71-
indices_v.begin(),
72-
indices_v.end(),
73-
thrust::make_zip_iterator(thrust::make_tuple(src_v.begin(), src_v.end())),
63+
thrust::copy_if(handle.get_thrust_policy(),
64+
thrust::make_counting_iterator<size_t>(0),
65+
thrust::make_counting_iterator<size_t>(max_num_edges),
66+
thrust::make_transform_output_iterator(
67+
thrust::make_zip_iterator(src_v.begin(), dst_v.begin()),
7468
cuda::proclaim_return_type<thrust::tuple<vertex_t, vertex_t>>(
7569
[num_vertices] __device__(size_t index) {
76-
size_t src = index / num_vertices;
77-
size_t dst = index % num_vertices;
78-
79-
return thrust::make_tuple(static_cast<vertex_t>(src),
80-
static_cast<vertex_t>(dst));
81-
}));
82-
83-
handle.sync_stream();
70+
return thrust::make_tuple(static_cast<vertex_t>(index / num_vertices),
71+
static_cast<vertex_t>(index % num_vertices));
72+
})),
73+
[generate_random_value, p] __device__(size_t index) {
74+
return generate_random_value(index) < p;
75+
});
8476

8577
return std::make_tuple(std::move(src_v), std::move(dst_v));
8678
}

cpp/tests/generators/erdos_renyi_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ void er_test(size_t num_vertices, float p)
8787
TEST_F(GenerateErdosRenyiTest, ERTest)
8888
{
8989
er_test<int32_t>(size_t{10}, float{0.1});
90+
er_test<int32_t>(size_t{10}, float{0.5});
9091
er_test<int32_t>(size_t{20}, float{0.1});
9192
er_test<int32_t>(size_t{50}, float{0.1});
9293
er_test<int32_t>(size_t{10000}, float{0.1});

0 commit comments

Comments
 (0)