23
23
#include < thrust/copy.h>
24
24
#include < thrust/count.h>
25
25
#include < thrust/iterator/counting_iterator.h>
26
- #include < thrust/iterator/transform_iterator .h>
26
+ #include < thrust/iterator/transform_output_iterator .h>
27
27
#include < thrust/iterator/zip_iterator.h>
28
28
#include < thrust/random.h>
29
- #include < thrust/transform.h>
30
29
#include < thrust/tuple.h>
31
30
32
31
namespace cugraph {
@@ -42,45 +41,38 @@ generate_erdos_renyi_graph_edgelist_gnp(raft::handle_t const& handle,
42
41
CUGRAPH_EXPECTS (num_vertices < std::numeric_limits<int32_t >::max (),
43
42
" Implementation cannot support specified value" );
44
43
45
- auto random_iterator = thrust::make_transform_iterator (
46
- thrust::make_counting_iterator<size_t >(0 ),
47
- cuda::proclaim_return_type<float >([seed] __device__ (size_t index) {
48
- thrust::default_random_engine rng (seed);
49
- thrust::uniform_real_distribution<float > dist (0.0 , 1.0 );
50
- rng.discard (index);
51
- return dist (rng);
52
- }));
44
+ size_t max_num_edges = static_cast <size_t >(num_vertices) * num_vertices;
53
45
54
- size_t count = thrust::count_if (handle. get_thrust_policy (),
55
- random_iterator,
56
- random_iterator + num_vertices * num_vertices,
57
- [p] __device__ ( float prob) { return prob < p; } );
58
-
59
- rmm::device_uvector< size_t > indices_v (count, handle. get_stream () );
46
+ auto generate_random_value = cuda::proclaim_return_type< float >([seed] __device__ ( size_t index) {
47
+ thrust::default_random_engine rng (seed);
48
+ thrust::uniform_real_distribution< float > dist ( 0.0 , 1.0 );
49
+ rng. discard (index );
50
+ return dist (rng);
51
+ } );
60
52
61
- thrust::copy_if (handle.get_thrust_policy (),
62
- random_iterator,
63
- random_iterator + num_vertices * num_vertices,
64
- indices_v.begin (),
65
- [p] __device__ (float prob) { return prob < p; });
53
+ size_t count = thrust::count_if (handle.get_thrust_policy (),
54
+ thrust::make_counting_iterator<size_t >(0 ),
55
+ thrust::make_counting_iterator<size_t >(max_num_edges),
56
+ [generate_random_value, p] __device__ (size_t index) {
57
+ return generate_random_value (index) < p;
58
+ });
66
59
67
60
rmm::device_uvector<vertex_t > src_v (count, handle.get_stream ());
68
61
rmm::device_uvector<vertex_t > dst_v (count, handle.get_stream ());
69
62
70
- thrust::transform (handle.get_thrust_policy (),
71
- indices_v.begin (),
72
- indices_v.end (),
73
- thrust::make_zip_iterator (thrust::make_tuple (src_v.begin (), src_v.end ())),
63
+ thrust::copy_if (handle.get_thrust_policy (),
64
+ thrust::make_counting_iterator<size_t >(0 ),
65
+ thrust::make_counting_iterator<size_t >(max_num_edges),
66
+ thrust::make_transform_output_iterator (
67
+ thrust::make_zip_iterator (src_v.begin (), dst_v.begin ()),
74
68
cuda::proclaim_return_type<thrust::tuple<vertex_t , vertex_t >>(
75
69
[num_vertices] __device__ (size_t index) {
76
- size_t src = index / num_vertices;
77
- size_t dst = index % num_vertices;
78
-
79
- return thrust::make_tuple (static_cast <vertex_t >(src),
80
- static_cast <vertex_t >(dst));
81
- }));
82
-
83
- handle.sync_stream ();
70
+ return thrust::make_tuple (static_cast <vertex_t >(index / num_vertices),
71
+ static_cast <vertex_t >(index % num_vertices));
72
+ })),
73
+ [generate_random_value, p] __device__ (size_t index) {
74
+ return generate_random_value (index) < p;
75
+ });
84
76
85
77
return std::make_tuple (std::move (src_v), std::move (dst_v));
86
78
}
0 commit comments