Skip to content

Commit 5c04b9c

Browse files
author
Naim
committed
Debug edge masking bug
1 parent 8dd7111 commit 5c04b9c

File tree

2 files changed

+144
-71
lines changed

2 files changed

+144
-71
lines changed

cpp/src/community/approx_weighted_matching_impl.cuh

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
8686
local_vertices.size(),
8787
current_graph_view.local_vertex_partition_range_first());
8888

89-
using flag_t = uint8_t;
89+
using flag_t = uint32_t;
9090
edge_src_property_t<graph_view_t, vertex_t> src_key_cache(handle);
9191
cugraph::edge_src_property_t<graph_view_t, flag_t> src_match_flags(handle);
9292
cugraph::edge_dst_property_t<graph_view_t, flag_t> dst_match_flags(handle);
@@ -101,6 +101,28 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
101101

102102
vertex_t loop_counter = 0;
103103
while (true) {
104+
std::cout << "#V: " << current_graph_view.number_of_vertices()
105+
<< " #E: " << current_graph_view.compute_number_of_edges(handle) << std::endl;
106+
cugraph::edge_property_t<graph_view_t, bool> temp_eps(handle, current_graph_view);
107+
auto sg = graph_view_t::is_multi_gpu;
108+
cugraph::transform_e(
109+
handle,
110+
current_graph_view,
111+
cugraph::edge_src_dummy_property_t{}.view(),
112+
cugraph::edge_dst_dummy_property_t{}.view(),
113+
edge_weight_view,
114+
[loop_counter, sg] __device__(
115+
auto src, auto dst, thrust::nullopt_t, thrust::nullopt_t, auto wgt) {
116+
printf("\n %d => %d %d %f [%d]\n",
117+
static_cast<int>(loop_counter),
118+
static_cast<int>(src),
119+
static_cast<int>(dst),
120+
static_cast<float>(wgt),
121+
static_cast<int>(sg));
122+
return false;
123+
},
124+
temp_eps.mutable_view());
125+
104126
if constexpr (graph_view_t::is_multi_gpu) {
105127
update_edge_src_property(handle, current_graph_view, local_vertices.begin(), src_key_cache);
106128
}
@@ -150,34 +172,21 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
150172
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
151173
auto const minor_comm_size = minor_comm.get_size();
152174

153-
auto func = cugraph::detail::compute_gpu_id_from_int_vertex_t<vertex_t>{
175+
auto key_func = cugraph::detail::compute_gpu_id_from_int_vertex_t<vertex_t>{
154176
raft::device_span<vertex_t const>(d_vertex_partition_range_lasts.data(),
155177
d_vertex_partition_range_lasts.size()),
156178
major_comm_size,
157179
minor_comm_size};
158180

159-
rmm::device_uvector<size_t> d_tx_value_counts(0, handle.get_stream());
160-
161-
auto triplet_first = thrust::make_zip_iterator(
162-
candidates.begin(), offers_from_candidates.begin(), targets.begin());
163-
164-
d_tx_value_counts = cugraph::groupby_and_count(
165-
triplet_first,
166-
triplet_first + candidates.size(),
167-
[func] __device__(auto val) { return func(thrust::get<2>(val)); },
168-
handle.get_comms().get_size(),
169-
std::numeric_limits<vertex_t>::max(),
170-
handle.get_stream());
171-
172-
std::vector<size_t> h_tx_value_counts(d_tx_value_counts.size());
173-
raft::update_host(h_tx_value_counts.data(),
174-
d_tx_value_counts.data(),
175-
d_tx_value_counts.size(),
176-
handle.get_stream());
177-
handle.sync_stream();
178-
179181
std::forward_as_tuple(std::tie(candidates, offers_from_candidates, targets), std::ignore) =
180-
shuffle_values(handle.get_comms(), triplet_first, h_tx_value_counts, handle.get_stream());
182+
cugraph::groupby_gpu_id_and_shuffle_values(
183+
handle.get_comms(),
184+
thrust::make_zip_iterator(thrust::make_tuple(
185+
candidates.begin(), offers_from_candidates.begin(), targets.begin())),
186+
thrust::make_zip_iterator(
187+
thrust::make_tuple(candidates.end(), offers_from_candidates.end(), targets.end())),
188+
[key_func] __device__(auto val) { return key_func(thrust::get<2>(val)); },
189+
handle.get_stream());
181190
}
182191

183192
auto itr_to_tuples = thrust::make_zip_iterator(
@@ -254,6 +263,7 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
254263
candidates.begin(),
255264
candidates.end(),
256265
vertex_to_gpu_id_op);
266+
257267
} else {
258268
candidates_of_candidates.resize(candidates.size(), handle.get_stream());
259269

@@ -263,6 +273,35 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
263273
handle.get_stream());
264274
}
265275

276+
auto const comm_rank = graph_view_t::is_multi_gpu ? handle.get_comms().get_rank() : 0;
277+
278+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
279+
auto targetss_title = std::string("targets_").append(std::to_string(comm_rank)).append("_");
280+
281+
raft::print_device_vector(targetss_title.c_str(), targets.begin(), targets.size(), std::cout);
282+
283+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
284+
auto cands_title = std::string("cands_").append(std::to_string(comm_rank)).append("_");
285+
286+
raft::print_device_vector(
287+
cands_title.c_str(), candidates.begin(), candidates.size(), std::cout);
288+
289+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
290+
auto offers_title = std::string("offers_").append(std::to_string(comm_rank)).append("_");
291+
292+
raft::print_device_vector(offers_title.c_str(),
293+
offers_from_candidates.begin(),
294+
offers_from_candidates.size(),
295+
std::cout);
296+
297+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
298+
auto ccs_title = std::string("ccs_").append(std::to_string(comm_rank)).append("_");
299+
300+
raft::print_device_vector(ccs_title.c_str(),
301+
candidates_of_candidates.begin(),
302+
candidates_of_candidates.size(),
303+
std::cout);
304+
266305
//
267306
// Mask out neighborhood of matched vertices
268307
//
@@ -302,6 +341,12 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
302341
}
303342
});
304343

344+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
345+
auto ivm_title = std::string("ivm_").append(std::to_string(comm_rank)).append("_");
346+
347+
raft::print_device_vector(
348+
ivm_title.c_str(), is_vertex_matched.begin(), is_vertex_matched.size(), std::cout);
349+
305350
if (current_graph_view.compute_number_of_edges(handle) == 0) { break; }
306351

307352
if constexpr (graph_view_t::is_multi_gpu) {
@@ -324,6 +369,15 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
324369
cugraph::edge_dummy_property_t{}.view(),
325370
[loop_counter] __device__(
326371
auto src, auto dst, auto is_src_matched, auto is_dst_matched, thrust::nullopt_t) {
372+
bool flag = !((is_src_matched == uint8_t{true}) || (is_dst_matched == uint8_t{true}));
373+
if (flag) {
374+
printf("\n** %d => src %d dst %d sm %d dm %d\n",
375+
static_cast<int>(loop_counter),
376+
static_cast<int>(src),
377+
static_cast<int>(dst),
378+
static_cast<int>(is_src_matched),
379+
static_cast<int>(is_dst_matched));
380+
}
327381
return !((is_src_matched == uint8_t{true}) || (is_dst_matched == uint8_t{true}));
328382
},
329383
edge_masks_odd.mutable_view());
@@ -356,6 +410,14 @@ std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matchin
356410
loop_counter++;
357411
}
358412

413+
auto const comm_rank = graph_view_t::is_multi_gpu ? handle.get_comms().get_rank() : 0;
414+
415+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
416+
auto ofp_title = std::string("ofp_").append(std::to_string(comm_rank)).append("_");
417+
418+
raft::print_device_vector(
419+
ofp_title.c_str(), offers_from_partners.begin(), offers_from_partners.size(), std::cout);
420+
359421
weight_t sum_matched_edge_weights = thrust::reduce(
360422
handle.get_thrust_policy(), offers_from_partners.begin(), offers_from_partners.end());
361423

cpp/tests/community/mg_weighted_matching_test.cpp

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "utilities/base_fixture.hpp"
1818
#include "utilities/conversion_utilities.hpp"
19+
#include "utilities/property_generator_utilities.hpp"
1920
#include "utilities/test_graphs.hpp"
2021

2122
#include <cugraph/algorithms.hpp>
@@ -37,6 +38,7 @@
3738
#include <random>
3839

3940
struct WeightedMatching_UseCase {
41+
bool edge_masking{false};
4042
bool check_correctness{true};
4143
};
4244

@@ -95,6 +97,13 @@ class Tests_MGWeightedMatching
9597
auto mg_edge_weight_view =
9698
mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt;
9799

100+
std::optional<cugraph::edge_property_t<decltype(mg_graph_view), bool>> edge_mask{std::nullopt};
101+
if (weighted_matching_usecase.edge_masking) {
102+
edge_mask = cugraph::test::generate<decltype(mg_graph_view), bool>::edge_property(
103+
*handle_, mg_graph_view, 2);
104+
mg_graph_view.attach_edge_mask((*edge_mask).view());
105+
}
106+
98107
rmm::device_uvector<vertex_t> mg_partners(0, handle_->get_stream());
99108
weight_t mg_matching_weights;
100109

@@ -155,71 +164,73 @@ class Tests_MGWeightedMatching
155164
template <typename input_usecase_t>
156165
std::unique_ptr<raft::handle_t> Tests_MGWeightedMatching<input_usecase_t>::handle_ = nullptr;
157166

158-
using Tests_MGWeightedMatching_File = Tests_MGWeightedMatching<cugraph::test::File_Usecase>;
167+
// using Tests_MGWeightedMatching_File = Tests_MGWeightedMatching<cugraph::test::File_Usecase>;
159168
using Tests_MGWeightedMatching_Rmat = Tests_MGWeightedMatching<cugraph::test::Rmat_Usecase>;
160169

161-
TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int32FloatFloat)
162-
{
163-
run_current_test<int32_t, int32_t, float, int>(
164-
override_File_Usecase_with_cmd_line_arguments(GetParam()));
165-
}
170+
// TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int32FloatFloat)
171+
// {
172+
// run_current_test<int32_t, int32_t, float, int>(
173+
// override_File_Usecase_with_cmd_line_arguments(GetParam()));
174+
// }
166175

167-
TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int64FloatFloat)
168-
{
169-
run_current_test<int32_t, int64_t, float, int>(
170-
override_File_Usecase_with_cmd_line_arguments(GetParam()));
171-
}
176+
// TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int64FloatFloat)
177+
// {
178+
// run_current_test<int32_t, int64_t, float, int>(
179+
// override_File_Usecase_with_cmd_line_arguments(GetParam()));
180+
// }
172181

173-
TEST_P(Tests_MGWeightedMatching_File, CheckInt64Int64FloatFloat)
174-
{
175-
run_current_test<int64_t, int64_t, float, int>(
176-
override_File_Usecase_with_cmd_line_arguments(GetParam()));
177-
}
182+
// TEST_P(Tests_MGWeightedMatching_File, CheckInt64Int64FloatFloat)
183+
// {
184+
// run_current_test<int64_t, int64_t, float, int>(
185+
// override_File_Usecase_with_cmd_line_arguments(GetParam()));
186+
// }
178187

179188
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int32FloatFloat)
180189
{
181190
run_current_test<int32_t, int32_t, float, int>(
182191
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
183192
}
184193

185-
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int64FloatFloat)
186-
{
187-
run_current_test<int32_t, int64_t, float, int>(
188-
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
189-
}
190-
191-
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt64Int64FloatFloat)
192-
{
193-
run_current_test<int64_t, int64_t, float, int>(
194-
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
195-
}
196-
197-
bool constexpr check_correctness = false;
198-
199-
INSTANTIATE_TEST_SUITE_P(
200-
file_test,
201-
Tests_MGWeightedMatching_File,
202-
::testing::Combine(::testing::Values(WeightedMatching_UseCase{check_correctness},
203-
WeightedMatching_UseCase{check_correctness}),
204-
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))));
194+
// TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int64FloatFloat)
195+
// {
196+
// run_current_test<int32_t, int64_t, float, int>(
197+
// override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
198+
// }
199+
200+
// TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt64Int64FloatFloat)
201+
// {
202+
// run_current_test<int64_t, int64_t, float, int>(
203+
// override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
204+
// }
205+
206+
bool constexpr check_correctness = true;
207+
bool constexpr edge_masking = true;
208+
209+
// INSTANTIATE_TEST_SUITE_P(
210+
// file_test,
211+
// Tests_MGWeightedMatching_File,
212+
// ::testing::Combine(::testing::Values(WeightedMatching_UseCase{edge_masking, check_correctness},
213+
// WeightedMatching_UseCase{edge_masking,
214+
// check_correctness}),
215+
// ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))));
205216

206217
INSTANTIATE_TEST_SUITE_P(
207218
rmat_small_test,
208219
Tests_MGWeightedMatching_Rmat,
209220
::testing::Combine(
210-
::testing::Values(WeightedMatching_UseCase{check_correctness}),
221+
::testing::Values(WeightedMatching_UseCase{edge_masking, check_correctness}),
211222
::testing::Values(cugraph::test::Rmat_Usecase(3, 2, 0.57, 0.19, 0.19, 0, true, false))));
212223

213-
INSTANTIATE_TEST_SUITE_P(
214-
rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with
215-
--gtest_filter to select only the rmat_benchmark_test with a specific
216-
vertex & edge type combination) by command line arguments and do not
217-
include more than one Rmat_Usecase that differ only in scale or edge
218-
factor (to avoid running same benchmarks more than once) */
219-
Tests_MGWeightedMatching_Rmat,
220-
::testing::Combine(
221-
::testing::Values(WeightedMatching_UseCase{check_correctness},
222-
WeightedMatching_UseCase{check_correctness}),
223-
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, true, false))));
224+
// INSTANTIATE_TEST_SUITE_P(
225+
// rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with
226+
// --gtest_filter to select only the rmat_benchmark_test with a specific
227+
// vertex & edge type combination) by command line arguments and do not
228+
// include more than one Rmat_Usecase that differ only in scale or edge
229+
// factor (to avoid running same benchmarks more than once) */
230+
// Tests_MGWeightedMatching_Rmat,
231+
// ::testing::Combine(
232+
// ::testing::Values(WeightedMatching_UseCase{edge_masking, check_correctness},
233+
// WeightedMatching_UseCase{edge_masking, check_correctness}),
234+
// ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, true, false))));
224235

225236
CUGRAPH_MG_TEST_PROGRAM_MAIN()

0 commit comments

Comments
 (0)