Skip to content

Commit 2779f32

Browse files
authored
MNMG Approximation Algorithm for the Weighted Matching Problem (#4315)
MNMG [Approximation Algorithm for the Weighted Matching Problem](https://web.archive.org/web/20081031230449id_/http://www.ii.uib.no/~fredrikm/fredrik/papers/CP75.pdf) Authors: - Naim (https://github.com/naimnv) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) URL: #4315
1 parent 6bd08d2 commit 2779f32

File tree

8 files changed

+942
-0
lines changed

8 files changed

+942
-0
lines changed

cpp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ set(CUGRAPH_SOURCES
288288
src/structure/symmetrize_edgelist_mg.cu
289289
src/community/triangle_count_sg.cu
290290
src/community/triangle_count_mg.cu
291+
src/community/approx_weighted_matching_sg.cu
292+
src/community/approx_weighted_matching_mg.cu
291293
src/traversal/k_hop_nbrs_sg.cu
292294
src/traversal/k_hop_nbrs_mg.cu
293295
src/mtmg/vertex_result.cu

cpp/include/cugraph/algorithms.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,6 +2368,32 @@ rmm::device_uvector<vertex_t> vertex_coloring(
23682368
graph_view_t<vertex_t, edge_t, false, multi_gpu> const& graph_view,
23692369
raft::random::RngState& rng_state);
23702370

2371+
/*
2372+
* @brief Approximate Weighted Matching
2373+
*
2374+
* A matching in an undirected graph G = (V, E) is a pairing of adjacent vertices
2375+
* such that each vertex is matched with at most one other vertex, the objective
2376+
* being to match as many vertices as possible or to maximise the sum of the
2377+
* weights of the matched edges. Here we provide an implementation of an
2378+
* approximation algorithm to the weighted Maximum matching. See
2379+
* https://web.archive.org/web/20081031230449id_/http://www.ii.uib.no/~fredrikm/fredrik/papers/CP75.pdf
2380+
* for further information.
2381+
*
2382+
* @tparam vertex_t Type of vertex identifiers. Needs to be an integral type.
2383+
* @tparam edge_t Type of edge identifiers. Needs to be an integral type.
2384+
* @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false)
2385+
* @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
2386+
* and handles to various CUDA libraries) to run graph algorithms.
2387+
* @param[in] graph_view Graph view object.
2388+
* @param[in] edge_weight_view View object holding edge weights for @p graph_view.
2389+
* @return A tuple of device vector of matched vertex ids and sum of the weights of the matched
2390+
* edges.
2391+
*/
2392+
template <typename vertex_t, typename edge_t, typename weight_t, bool multi_gpu>
2393+
std::tuple<rmm::device_uvector<vertex_t>, weight_t> approximate_weighted_matching(
2394+
raft::handle_t const& handle,
2395+
graph_view_t<vertex_t, edge_t, false, multi_gpu> const& graph_view,
2396+
edge_property_view_t<edge_t, weight_t const*> edge_weight_view);
23712397
} // namespace cugraph
23722398

23732399
/**

cpp/src/community/approx_weighted_matching_impl.cuh

Lines changed: 392 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "approx_weighted_matching_impl.cuh"
17+
18+
namespace cugraph {
19+
20+
template std::tuple<rmm::device_uvector<int32_t>, float> approximate_weighted_matching(
21+
raft::handle_t const& handle,
22+
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
23+
edge_property_view_t<int32_t, float const*> edge_weight_view);
24+
25+
template std::tuple<rmm::device_uvector<int32_t>, double> approximate_weighted_matching(
26+
raft::handle_t const& handle,
27+
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
28+
edge_property_view_t<int32_t, double const*> edge_weight_view);
29+
30+
template std::tuple<rmm::device_uvector<int32_t>, float> approximate_weighted_matching(
31+
raft::handle_t const& handle,
32+
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
33+
edge_property_view_t<int64_t, float const*> edge_weight_view);
34+
35+
template std::tuple<rmm::device_uvector<int64_t>, float> approximate_weighted_matching(
36+
raft::handle_t const& handle,
37+
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
38+
edge_property_view_t<int64_t, float const*> edge_weight_view);
39+
40+
template std::tuple<rmm::device_uvector<int32_t>, double> approximate_weighted_matching(
41+
raft::handle_t const& handle,
42+
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
43+
edge_property_view_t<int64_t, double const*> edge_weight_view);
44+
45+
template std::tuple<rmm::device_uvector<int64_t>, double> approximate_weighted_matching(
46+
raft::handle_t const& handle,
47+
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
48+
edge_property_view_t<int64_t, double const*> edge_weight_view);
49+
50+
} // namespace cugraph
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "approx_weighted_matching_impl.cuh"
17+
18+
namespace cugraph {
19+
20+
template std::tuple<rmm::device_uvector<int32_t>, float> approximate_weighted_matching(
21+
raft::handle_t const& handle,
22+
graph_view_t<int32_t, int32_t, false, false> const& graph_view,
23+
edge_property_view_t<int32_t, float const*> edge_weight_view);
24+
25+
template std::tuple<rmm::device_uvector<int32_t>, double> approximate_weighted_matching(
26+
raft::handle_t const& handle,
27+
graph_view_t<int32_t, int32_t, false, false> const& graph_view,
28+
edge_property_view_t<int32_t, double const*> edge_weight_view);
29+
30+
template std::tuple<rmm::device_uvector<int32_t>, float> approximate_weighted_matching(
31+
raft::handle_t const& handle,
32+
graph_view_t<int32_t, int64_t, false, false> const& graph_view,
33+
edge_property_view_t<int64_t, float const*> edge_weight_view);
34+
35+
template std::tuple<rmm::device_uvector<int64_t>, float> approximate_weighted_matching(
36+
raft::handle_t const& handle,
37+
graph_view_t<int64_t, int64_t, false, false> const& graph_view,
38+
edge_property_view_t<int64_t, float const*> edge_weight_view);
39+
40+
template std::tuple<rmm::device_uvector<int32_t>, double> approximate_weighted_matching(
41+
raft::handle_t const& handle,
42+
graph_view_t<int32_t, int64_t, false, false> const& graph_view,
43+
edge_property_view_t<int64_t, double const*> edge_weight_view);
44+
45+
template std::tuple<rmm::device_uvector<int64_t>, double> approximate_weighted_matching(
46+
raft::handle_t const& handle,
47+
graph_view_t<int64_t, int64_t, false, false> const& graph_view,
48+
edge_property_view_t<int64_t, double const*> edge_weight_view);
49+
50+
} // namespace cugraph

cpp/tests/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ ConfigureTest(LOUVAIN_TEST community/louvain_test.cpp)
309309
# - LEIDEN tests ----------------------------------------------------------------------------------
310310
ConfigureTest(LEIDEN_TEST community/leiden_test.cpp)
311311

312+
###################################################################################################
313+
# - WEIGHTED MATCHING tests ----------------------------------------------------------------------------------
314+
ConfigureTest(WEIGHTED_MATCHING_TEST community/weighted_matching_test.cpp)
315+
312316
###################################################################################################
313317
# - Legacy ECG tests -------------------------------------------------------------------------------------
314318
ConfigureTest(LEGACY_ECG_TEST community/legacy_ecg_test.cpp)
@@ -570,6 +574,10 @@ if(BUILD_CUGRAPH_MG_TESTS)
570574
# - MG LEIDEN tests --------------------------------------------------------------------------
571575
ConfigureTestMG(MG_LEIDEN_TEST community/mg_leiden_test.cpp)
572576

577+
###############################################################################################
578+
# - MG WEIGHTED MATCHING tests --------------------------------------------------------------------------
579+
ConfigureTestMG(MG_WEIGHTED_MATCHING_TEST community/mg_weighted_matching_test.cpp)
580+
573581
###############################################################################################
574582
# - MG ECG tests --------------------------------------------------------------------------
575583
ConfigureTestMG(MG_ECG_TEST community/mg_ecg_test.cpp)
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governin_from_mtxg permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "utilities/base_fixture.hpp"
18+
#include "utilities/conversion_utilities.hpp"
19+
#include "utilities/property_generator_utilities.hpp"
20+
#include "utilities/test_graphs.hpp"
21+
22+
#include <cugraph/algorithms.hpp>
23+
#include <cugraph/edge_partition_view.hpp>
24+
#include <cugraph/edge_property.hpp>
25+
#include <cugraph/edge_src_dst_property.hpp>
26+
#include <cugraph/graph_functions.hpp>
27+
#include <cugraph/graph_view.hpp>
28+
#include <cugraph/utilities/dataframe_buffer.hpp>
29+
#include <cugraph/utilities/high_res_timer.hpp>
30+
#include <cugraph/utilities/host_scalar_comm.hpp>
31+
32+
#include <raft/random/rng_state.hpp>
33+
34+
#include <gtest/gtest.h>
35+
36+
#include <chrono>
37+
#include <iostream>
38+
#include <random>
39+
40+
struct WeightedMatching_UseCase {
41+
bool edge_masking{false};
42+
bool check_correctness{true};
43+
};
44+
45+
template <typename input_usecase_t>
46+
class Tests_MGWeightedMatching
47+
: public ::testing::TestWithParam<std::tuple<WeightedMatching_UseCase, input_usecase_t>> {
48+
public:
49+
Tests_MGWeightedMatching() {}
50+
51+
static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); }
52+
static void TearDownTestCase() { handle_.reset(); }
53+
54+
virtual void SetUp() {}
55+
virtual void TearDown() {}
56+
57+
template <typename vertex_t, typename edge_t, typename weight_t, typename result_t>
58+
void run_current_test(std::tuple<WeightedMatching_UseCase, input_usecase_t> const& param)
59+
{
60+
auto [weighted_matching_usecase, input_usecase] = param;
61+
62+
HighResTimer hr_timer{};
63+
64+
if (cugraph::test::g_perf) {
65+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
66+
handle_->get_comms().barrier();
67+
hr_timer.start("MG Construct graph");
68+
}
69+
70+
constexpr bool multi_gpu = true;
71+
72+
bool test_weighted = true;
73+
bool renumber = true;
74+
bool drop_self_loops = false;
75+
bool drop_multi_edges = false;
76+
77+
auto [mg_graph, mg_edge_weights, mg_renumber_map] =
78+
cugraph::test::construct_graph<vertex_t, edge_t, weight_t, false, multi_gpu>(
79+
*handle_, input_usecase, test_weighted, renumber, drop_self_loops, drop_multi_edges);
80+
81+
std::tie(mg_graph, mg_edge_weights, mg_renumber_map) = cugraph::symmetrize_graph(
82+
*handle_,
83+
std::move(mg_graph),
84+
std::move(mg_edge_weights),
85+
mg_renumber_map ? std::optional<rmm::device_uvector<vertex_t>>(std::move(*mg_renumber_map))
86+
: std::nullopt,
87+
false);
88+
89+
if (cugraph::test::g_perf) {
90+
RAFT_CUDA_TRY(cudaDeviceSynchronize());
91+
handle_->get_comms().barrier();
92+
hr_timer.stop();
93+
hr_timer.display_and_clear(std::cout);
94+
}
95+
96+
auto mg_graph_view = mg_graph.view();
97+
auto mg_edge_weight_view =
98+
mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt;
99+
100+
std::optional<cugraph::edge_property_t<decltype(mg_graph_view), bool>> edge_mask{std::nullopt};
101+
if (weighted_matching_usecase.edge_masking) {
102+
edge_mask = cugraph::test::generate<decltype(mg_graph_view), bool>::edge_property(
103+
*handle_, mg_graph_view, 2);
104+
mg_graph_view.attach_edge_mask((*edge_mask).view());
105+
}
106+
107+
rmm::device_uvector<vertex_t> mg_partners(0, handle_->get_stream());
108+
weight_t mg_matching_weights;
109+
110+
std::forward_as_tuple(mg_partners, mg_matching_weights) =
111+
cugraph::approximate_weighted_matching<vertex_t, edge_t, weight_t, multi_gpu>(
112+
*handle_, mg_graph_view, (*mg_edge_weights).view());
113+
114+
if (weighted_matching_usecase.check_correctness) {
115+
auto h_mg_partners = cugraph::test::to_host(*handle_, mg_partners);
116+
117+
auto constexpr invalid_partner = cugraph::invalid_vertex_id<vertex_t>::value;
118+
119+
rmm::device_uvector<vertex_t> mg_aggregate_partners(0, handle_->get_stream());
120+
std::tie(std::ignore, mg_aggregate_partners) =
121+
cugraph::test::mg_vertex_property_values_to_sg_vertex_property_values(
122+
*handle_,
123+
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
124+
mg_graph_view.local_vertex_partition_range(),
125+
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
126+
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
127+
raft::device_span<vertex_t const>(mg_partners.data(), mg_partners.size()));
128+
129+
cugraph::graph_t<vertex_t, edge_t, false, false> sg_graph(*handle_);
130+
std::optional<
131+
cugraph::edge_property_t<cugraph::graph_view_t<vertex_t, edge_t, false, false>, weight_t>>
132+
sg_edge_weights{std::nullopt};
133+
std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph(
134+
*handle_,
135+
mg_graph_view,
136+
mg_edge_weight_view,
137+
std::optional<raft::device_span<vertex_t const>>(std::nullopt),
138+
false);
139+
140+
if (handle_->get_comms().get_rank() == 0) {
141+
auto sg_graph_view = sg_graph.view();
142+
143+
rmm::device_uvector<vertex_t> sg_partners(0, handle_->get_stream());
144+
weight_t sg_matching_weights;
145+
146+
std::forward_as_tuple(sg_partners, sg_matching_weights) =
147+
cugraph::approximate_weighted_matching<vertex_t, edge_t, weight_t, false>(
148+
*handle_, sg_graph_view, (*sg_edge_weights).view());
149+
auto h_sg_partners = cugraph::test::to_host(*handle_, sg_partners);
150+
auto h_mg_aggregate_partners = cugraph::test::to_host(*handle_, mg_aggregate_partners);
151+
152+
ASSERT_FLOAT_EQ(mg_matching_weights, sg_matching_weights)
153+
<< "SG and MG matching weights are different";
154+
ASSERT_TRUE(
155+
std::equal(h_sg_partners.begin(), h_sg_partners.end(), h_mg_aggregate_partners.begin()));
156+
}
157+
}
158+
}
159+
160+
private:
161+
static std::unique_ptr<raft::handle_t> handle_;
162+
};
163+
164+
template <typename input_usecase_t>
165+
std::unique_ptr<raft::handle_t> Tests_MGWeightedMatching<input_usecase_t>::handle_ = nullptr;
166+
167+
using Tests_MGWeightedMatching_File = Tests_MGWeightedMatching<cugraph::test::File_Usecase>;
168+
using Tests_MGWeightedMatching_Rmat = Tests_MGWeightedMatching<cugraph::test::Rmat_Usecase>;
169+
170+
TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int32FloatFloat)
171+
{
172+
run_current_test<int32_t, int32_t, float, int>(
173+
override_File_Usecase_with_cmd_line_arguments(GetParam()));
174+
}
175+
176+
TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int64FloatFloat)
177+
{
178+
run_current_test<int32_t, int64_t, float, int>(
179+
override_File_Usecase_with_cmd_line_arguments(GetParam()));
180+
}
181+
182+
TEST_P(Tests_MGWeightedMatching_File, CheckInt64Int64FloatFloat)
183+
{
184+
run_current_test<int64_t, int64_t, float, int>(
185+
override_File_Usecase_with_cmd_line_arguments(GetParam()));
186+
}
187+
188+
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int32FloatFloat)
189+
{
190+
run_current_test<int32_t, int32_t, float, int>(
191+
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
192+
}
193+
194+
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int64FloatFloat)
195+
{
196+
run_current_test<int32_t, int64_t, float, int>(
197+
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
198+
}
199+
200+
TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt64Int64FloatFloat)
201+
{
202+
run_current_test<int64_t, int64_t, float, int>(
203+
override_Rmat_Usecase_with_cmd_line_arguments(GetParam()));
204+
}
205+
206+
INSTANTIATE_TEST_SUITE_P(
207+
file_test,
208+
Tests_MGWeightedMatching_File,
209+
::testing::Combine(::testing::Values(WeightedMatching_UseCase{false},
210+
WeightedMatching_UseCase{true}),
211+
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))));
212+
213+
INSTANTIATE_TEST_SUITE_P(rmat_small_test,
214+
Tests_MGWeightedMatching_Rmat,
215+
::testing::Combine(::testing::Values(WeightedMatching_UseCase{false},
216+
WeightedMatching_UseCase{true}),
217+
::testing::Values(cugraph::test::Rmat_Usecase(
218+
3, 2, 0.57, 0.19, 0.19, 0, true, false))));
219+
220+
INSTANTIATE_TEST_SUITE_P(
221+
rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with
222+
--gtest_filter to select only the rmat_benchmark_test with a specific
223+
vertex & edge type combination) by command line arguments and do not
224+
include more than one Rmat_Usecase that differ only in scale or edge
225+
factor (to avoid running same benchmarks more than once) */
226+
Tests_MGWeightedMatching_Rmat,
227+
::testing::Combine(
228+
::testing::Values(WeightedMatching_UseCase{false, false},
229+
WeightedMatching_UseCase{true, false}),
230+
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, true, false))));
231+
232+
CUGRAPH_MG_TEST_PROGRAM_MAIN()

0 commit comments

Comments
 (0)