|
| 1 | +/* |
| 2 | + * Copyright (c) 2024, NVIDIA CORPORATION. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governin_from_mtxg permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "utilities/base_fixture.hpp" |
| 18 | +#include "utilities/conversion_utilities.hpp" |
| 19 | +#include "utilities/property_generator_utilities.hpp" |
| 20 | +#include "utilities/test_graphs.hpp" |
| 21 | + |
| 22 | +#include <cugraph/algorithms.hpp> |
| 23 | +#include <cugraph/edge_partition_view.hpp> |
| 24 | +#include <cugraph/edge_property.hpp> |
| 25 | +#include <cugraph/edge_src_dst_property.hpp> |
| 26 | +#include <cugraph/graph_functions.hpp> |
| 27 | +#include <cugraph/graph_view.hpp> |
| 28 | +#include <cugraph/utilities/dataframe_buffer.hpp> |
| 29 | +#include <cugraph/utilities/high_res_timer.hpp> |
| 30 | +#include <cugraph/utilities/host_scalar_comm.hpp> |
| 31 | + |
| 32 | +#include <raft/random/rng_state.hpp> |
| 33 | + |
| 34 | +#include <gtest/gtest.h> |
| 35 | + |
| 36 | +#include <chrono> |
| 37 | +#include <iostream> |
| 38 | +#include <random> |
| 39 | + |
| 40 | +struct WeightedMatching_UseCase { |
| 41 | + bool edge_masking{false}; |
| 42 | + bool check_correctness{true}; |
| 43 | +}; |
| 44 | + |
| 45 | +template <typename input_usecase_t> |
| 46 | +class Tests_MGWeightedMatching |
| 47 | + : public ::testing::TestWithParam<std::tuple<WeightedMatching_UseCase, input_usecase_t>> { |
| 48 | + public: |
| 49 | + Tests_MGWeightedMatching() {} |
| 50 | + |
| 51 | + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } |
| 52 | + static void TearDownTestCase() { handle_.reset(); } |
| 53 | + |
| 54 | + virtual void SetUp() {} |
| 55 | + virtual void TearDown() {} |
| 56 | + |
| 57 | + template <typename vertex_t, typename edge_t, typename weight_t, typename result_t> |
| 58 | + void run_current_test(std::tuple<WeightedMatching_UseCase, input_usecase_t> const& param) |
| 59 | + { |
| 60 | + auto [weighted_matching_usecase, input_usecase] = param; |
| 61 | + |
| 62 | + HighResTimer hr_timer{}; |
| 63 | + |
| 64 | + if (cugraph::test::g_perf) { |
| 65 | + RAFT_CUDA_TRY(cudaDeviceSynchronize()); |
| 66 | + handle_->get_comms().barrier(); |
| 67 | + hr_timer.start("MG Construct graph"); |
| 68 | + } |
| 69 | + |
| 70 | + constexpr bool multi_gpu = true; |
| 71 | + |
| 72 | + bool test_weighted = true; |
| 73 | + bool renumber = true; |
| 74 | + bool drop_self_loops = false; |
| 75 | + bool drop_multi_edges = false; |
| 76 | + |
| 77 | + auto [mg_graph, mg_edge_weights, mg_renumber_map] = |
| 78 | + cugraph::test::construct_graph<vertex_t, edge_t, weight_t, false, multi_gpu>( |
| 79 | + *handle_, input_usecase, test_weighted, renumber, drop_self_loops, drop_multi_edges); |
| 80 | + |
| 81 | + std::tie(mg_graph, mg_edge_weights, mg_renumber_map) = cugraph::symmetrize_graph( |
| 82 | + *handle_, |
| 83 | + std::move(mg_graph), |
| 84 | + std::move(mg_edge_weights), |
| 85 | + mg_renumber_map ? std::optional<rmm::device_uvector<vertex_t>>(std::move(*mg_renumber_map)) |
| 86 | + : std::nullopt, |
| 87 | + false); |
| 88 | + |
| 89 | + if (cugraph::test::g_perf) { |
| 90 | + RAFT_CUDA_TRY(cudaDeviceSynchronize()); |
| 91 | + handle_->get_comms().barrier(); |
| 92 | + hr_timer.stop(); |
| 93 | + hr_timer.display_and_clear(std::cout); |
| 94 | + } |
| 95 | + |
| 96 | + auto mg_graph_view = mg_graph.view(); |
| 97 | + auto mg_edge_weight_view = |
| 98 | + mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; |
| 99 | + |
| 100 | + std::optional<cugraph::edge_property_t<decltype(mg_graph_view), bool>> edge_mask{std::nullopt}; |
| 101 | + if (weighted_matching_usecase.edge_masking) { |
| 102 | + edge_mask = cugraph::test::generate<decltype(mg_graph_view), bool>::edge_property( |
| 103 | + *handle_, mg_graph_view, 2); |
| 104 | + mg_graph_view.attach_edge_mask((*edge_mask).view()); |
| 105 | + } |
| 106 | + |
| 107 | + rmm::device_uvector<vertex_t> mg_partners(0, handle_->get_stream()); |
| 108 | + weight_t mg_matching_weights; |
| 109 | + |
| 110 | + std::forward_as_tuple(mg_partners, mg_matching_weights) = |
| 111 | + cugraph::approximate_weighted_matching<vertex_t, edge_t, weight_t, multi_gpu>( |
| 112 | + *handle_, mg_graph_view, (*mg_edge_weights).view()); |
| 113 | + |
| 114 | + if (weighted_matching_usecase.check_correctness) { |
| 115 | + auto h_mg_partners = cugraph::test::to_host(*handle_, mg_partners); |
| 116 | + |
| 117 | + auto constexpr invalid_partner = cugraph::invalid_vertex_id<vertex_t>::value; |
| 118 | + |
| 119 | + rmm::device_uvector<vertex_t> mg_aggregate_partners(0, handle_->get_stream()); |
| 120 | + std::tie(std::ignore, mg_aggregate_partners) = |
| 121 | + cugraph::test::mg_vertex_property_values_to_sg_vertex_property_values( |
| 122 | + *handle_, |
| 123 | + std::optional<raft::device_span<vertex_t const>>{std::nullopt}, |
| 124 | + mg_graph_view.local_vertex_partition_range(), |
| 125 | + std::optional<raft::device_span<vertex_t const>>{std::nullopt}, |
| 126 | + std::optional<raft::device_span<vertex_t const>>{std::nullopt}, |
| 127 | + raft::device_span<vertex_t const>(mg_partners.data(), mg_partners.size())); |
| 128 | + |
| 129 | + cugraph::graph_t<vertex_t, edge_t, false, false> sg_graph(*handle_); |
| 130 | + std::optional< |
| 131 | + cugraph::edge_property_t<cugraph::graph_view_t<vertex_t, edge_t, false, false>, weight_t>> |
| 132 | + sg_edge_weights{std::nullopt}; |
| 133 | + std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( |
| 134 | + *handle_, |
| 135 | + mg_graph_view, |
| 136 | + mg_edge_weight_view, |
| 137 | + std::optional<raft::device_span<vertex_t const>>(std::nullopt), |
| 138 | + false); |
| 139 | + |
| 140 | + if (handle_->get_comms().get_rank() == 0) { |
| 141 | + auto sg_graph_view = sg_graph.view(); |
| 142 | + |
| 143 | + rmm::device_uvector<vertex_t> sg_partners(0, handle_->get_stream()); |
| 144 | + weight_t sg_matching_weights; |
| 145 | + |
| 146 | + std::forward_as_tuple(sg_partners, sg_matching_weights) = |
| 147 | + cugraph::approximate_weighted_matching<vertex_t, edge_t, weight_t, false>( |
| 148 | + *handle_, sg_graph_view, (*sg_edge_weights).view()); |
| 149 | + auto h_sg_partners = cugraph::test::to_host(*handle_, sg_partners); |
| 150 | + auto h_mg_aggregate_partners = cugraph::test::to_host(*handle_, mg_aggregate_partners); |
| 151 | + |
| 152 | + ASSERT_FLOAT_EQ(mg_matching_weights, sg_matching_weights) |
| 153 | + << "SG and MG matching weights are different"; |
| 154 | + ASSERT_TRUE( |
| 155 | + std::equal(h_sg_partners.begin(), h_sg_partners.end(), h_mg_aggregate_partners.begin())); |
| 156 | + } |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + private: |
| 161 | + static std::unique_ptr<raft::handle_t> handle_; |
| 162 | +}; |
| 163 | + |
| 164 | +template <typename input_usecase_t> |
| 165 | +std::unique_ptr<raft::handle_t> Tests_MGWeightedMatching<input_usecase_t>::handle_ = nullptr; |
| 166 | + |
| 167 | +using Tests_MGWeightedMatching_File = Tests_MGWeightedMatching<cugraph::test::File_Usecase>; |
| 168 | +using Tests_MGWeightedMatching_Rmat = Tests_MGWeightedMatching<cugraph::test::Rmat_Usecase>; |
| 169 | + |
| 170 | +TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int32FloatFloat) |
| 171 | +{ |
| 172 | + run_current_test<int32_t, int32_t, float, int>( |
| 173 | + override_File_Usecase_with_cmd_line_arguments(GetParam())); |
| 174 | +} |
| 175 | + |
| 176 | +TEST_P(Tests_MGWeightedMatching_File, CheckInt32Int64FloatFloat) |
| 177 | +{ |
| 178 | + run_current_test<int32_t, int64_t, float, int>( |
| 179 | + override_File_Usecase_with_cmd_line_arguments(GetParam())); |
| 180 | +} |
| 181 | + |
| 182 | +TEST_P(Tests_MGWeightedMatching_File, CheckInt64Int64FloatFloat) |
| 183 | +{ |
| 184 | + run_current_test<int64_t, int64_t, float, int>( |
| 185 | + override_File_Usecase_with_cmd_line_arguments(GetParam())); |
| 186 | +} |
| 187 | + |
| 188 | +TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int32FloatFloat) |
| 189 | +{ |
| 190 | + run_current_test<int32_t, int32_t, float, int>( |
| 191 | + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); |
| 192 | +} |
| 193 | + |
| 194 | +TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt32Int64FloatFloat) |
| 195 | +{ |
| 196 | + run_current_test<int32_t, int64_t, float, int>( |
| 197 | + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); |
| 198 | +} |
| 199 | + |
| 200 | +TEST_P(Tests_MGWeightedMatching_Rmat, CheckInt64Int64FloatFloat) |
| 201 | +{ |
| 202 | + run_current_test<int64_t, int64_t, float, int>( |
| 203 | + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); |
| 204 | +} |
| 205 | + |
| 206 | +INSTANTIATE_TEST_SUITE_P( |
| 207 | + file_test, |
| 208 | + Tests_MGWeightedMatching_File, |
| 209 | + ::testing::Combine(::testing::Values(WeightedMatching_UseCase{false}, |
| 210 | + WeightedMatching_UseCase{true}), |
| 211 | + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); |
| 212 | + |
| 213 | +INSTANTIATE_TEST_SUITE_P(rmat_small_test, |
| 214 | + Tests_MGWeightedMatching_Rmat, |
| 215 | + ::testing::Combine(::testing::Values(WeightedMatching_UseCase{false}, |
| 216 | + WeightedMatching_UseCase{true}), |
| 217 | + ::testing::Values(cugraph::test::Rmat_Usecase( |
| 218 | + 3, 2, 0.57, 0.19, 0.19, 0, true, false)))); |
| 219 | + |
| 220 | +INSTANTIATE_TEST_SUITE_P( |
| 221 | + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with |
| 222 | + --gtest_filter to select only the rmat_benchmark_test with a specific |
| 223 | + vertex & edge type combination) by command line arguments and do not |
| 224 | + include more than one Rmat_Usecase that differ only in scale or edge |
| 225 | + factor (to avoid running same benchmarks more than once) */ |
| 226 | + Tests_MGWeightedMatching_Rmat, |
| 227 | + ::testing::Combine( |
| 228 | + ::testing::Values(WeightedMatching_UseCase{false, false}, |
| 229 | + WeightedMatching_UseCase{true, false}), |
| 230 | + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, true, false)))); |
| 231 | + |
| 232 | +CUGRAPH_MG_TEST_PROGRAM_MAIN() |
0 commit comments