Skip to content

Commit b2b44b0

Browse files
authored
Update collect_comm to handle value of tuple type (#4410)
Update collect_comm to handle value of tuple type Authors: - Naim (https://github.com/naimnv) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) - Joseph Nke (https://github.com/jnke2016) URL: #4410
1 parent dbd558f commit b2b44b0

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

cpp/src/utilities/collect_comm.cuh

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ collect_values_for_keys(raft::handle_t const& handle,
100100

101101
auto rx_values_for_unique_keys = allocate_dataframe_buffer<value_t>(0, handle.get_stream());
102102
std::tie(rx_values_for_unique_keys, std::ignore) =
103-
shuffle_values(comm, values_for_rx_unique_keys.begin(), rx_value_counts, handle.get_stream());
103+
shuffle_values(comm,
104+
get_dataframe_buffer_begin(values_for_rx_unique_keys),
105+
rx_value_counts,
106+
handle.get_stream());
104107

105108
values_for_unique_keys = std::move(rx_values_for_unique_keys);
106109
}
@@ -136,9 +139,9 @@ collect_values_for_keys(raft::handle_t const& handle,
136139
handle.get_stream());
137140

138141
unique_keys.resize(0, handle.get_stream());
139-
values_for_unique_keys.resize(0, handle.get_stream());
142+
resize_dataframe_buffer(values_for_unique_keys, 0, handle.get_stream());
140143
unique_keys.shrink_to_fit(handle.get_stream());
141-
values_for_unique_keys.shrink_to_fit(handle.get_stream());
144+
shrink_to_fit_dataframe_buffer(values_for_unique_keys, handle.get_stream());
142145
}
143146
auto unique_key_value_store_view = unique_key_value_store.view();
144147

@@ -248,15 +251,15 @@ collect_values_for_unique_int_vertices(raft::handle_t const& handle,
248251
thrust::transform(handle.get_thrust_policy(),
249252
rx_int_vertices.begin(),
250253
rx_int_vertices.end(),
251-
value_buffer.begin(),
254+
get_dataframe_buffer_begin(value_buffer),
252255
[local_value_first, local_int_vertex_first] __device__(auto v) {
253256
return local_value_first[v - local_int_vertex_first];
254257
});
255258

256259
// 3: Shuffle results back to original GPU
257260

258-
std::tie(value_buffer, std::ignore) =
259-
shuffle_values(comm, value_buffer.begin(), rx_int_vertex_counts, handle.get_stream());
261+
std::tie(value_buffer, std::ignore) = shuffle_values(
262+
comm, get_dataframe_buffer_begin(value_buffer), rx_int_vertex_counts, handle.get_stream());
260263

261264
return std::make_tuple(std::move(collect_unique_int_vertices), std::move(value_buffer));
262265
}
@@ -305,7 +308,7 @@ collect_values_for_int_vertices(
305308
thrust::transform(handle.get_thrust_policy(),
306309
collect_vertex_first,
307310
collect_vertex_last,
308-
value_buffer.begin(),
311+
get_dataframe_buffer_begin(value_buffer),
309312
[device_view] __device__(auto v) { return device_view.find(v); });
310313

311314
return value_buffer;

0 commit comments

Comments
 (0)