Skip to content

Commit d3fa655

Browse files
Compilation optimizations - batch 1 (#5112)
Breaking previous PR into smaller chunks. Should be easier to diagnose issues causing failures of tests in CI. Batch 1: add arithmetic variant types, first use in groupby and count functions. 2% reduction in size of libcugraph.so Authors: - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Joseph Nke (https://github.com/jnke2016) URL: #5112
1 parent 2d1fa05 commit d3fa655

File tree

6 files changed

+373
-319
lines changed

6 files changed

+373
-319
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <cugraph/edge_property.hpp>
19+
#include <cugraph/utilities/error.hpp>
20+
21+
#include <raft/core/device_span.hpp>
22+
23+
#include <rmm/device_uvector.hpp>
24+
25+
#include <variant>
26+
27+
namespace cugraph {
28+
29+
using arithmetic_device_uvector_t = std::variant<std::monostate,
30+
rmm::device_uvector<float>,
31+
rmm::device_uvector<double>,
32+
rmm::device_uvector<int32_t>,
33+
rmm::device_uvector<int64_t>,
34+
rmm::device_uvector<size_t>>;
35+
using arithmetic_device_span_t = std::variant<std::monostate,
36+
raft::device_span<float>,
37+
raft::device_span<double>,
38+
raft::device_span<int32_t>,
39+
raft::device_span<int64_t>,
40+
raft::device_span<size_t>>;
41+
using const_arithmetic_device_span_t = std::variant<std::monostate,
42+
raft::device_span<float const>,
43+
raft::device_span<double const>,
44+
raft::device_span<int32_t const>,
45+
raft::device_span<int64_t const>,
46+
raft::device_span<size_t const>>;
47+
48+
template <typename edge_t>
49+
using edge_arithmetic_property_view_t =
50+
std::variant<std::monostate,
51+
cugraph::edge_property_view_t<edge_t, float const*>,
52+
cugraph::edge_property_view_t<edge_t, double const*>,
53+
cugraph::edge_property_view_t<edge_t, int32_t const*>,
54+
cugraph::edge_property_view_t<edge_t, int64_t const*>,
55+
cugraph::edge_property_view_t<edge_t, size_t const*>>;
56+
57+
template <typename edge_t>
58+
using edge_arithmetic_property_mutable_view_t =
59+
std::variant<std::monostate,
60+
cugraph::edge_property_view_t<edge_t, float*>,
61+
cugraph::edge_property_view_t<edge_t, double*>,
62+
cugraph::edge_property_view_t<edge_t, int32_t*>,
63+
cugraph::edge_property_view_t<edge_t, int64_t*>,
64+
cugraph::edge_property_view_t<edge_t, size_t*>>;
65+
66+
template <typename func_t>
67+
auto variant_type_dispatch(arithmetic_device_uvector_t& property, func_t func)
68+
{
69+
if (std::holds_alternative<rmm::device_uvector<float>>(property)) {
70+
auto& prop = std::get<rmm::device_uvector<float>>(property);
71+
return func(prop);
72+
} else if (std::holds_alternative<rmm::device_uvector<double>>(property)) {
73+
auto& prop = std::get<rmm::device_uvector<double>>(property);
74+
return func(prop);
75+
} else if (std::holds_alternative<rmm::device_uvector<int32_t>>(property)) {
76+
auto& prop = std::get<rmm::device_uvector<int32_t>>(property);
77+
return func(prop);
78+
} else if (std::holds_alternative<rmm::device_uvector<int64_t>>(property)) {
79+
auto& prop = std::get<rmm::device_uvector<int64_t>>(property);
80+
return func(prop);
81+
} else {
82+
CUGRAPH_EXPECTS(std::holds_alternative<rmm::device_uvector<size_t>>(property),
83+
"unsupported variant type -- shouldn't happen");
84+
85+
auto& prop = std::get<rmm::device_uvector<size_t>>(property);
86+
return func(prop);
87+
}
88+
}
89+
90+
template <typename func_t>
91+
auto variant_type_dispatch(arithmetic_device_uvector_t const& property, func_t func)
92+
{
93+
if (std::holds_alternative<rmm::device_uvector<float>>(property)) {
94+
auto& prop = std::get<rmm::device_uvector<float>>(property);
95+
return func(prop);
96+
} else if (std::holds_alternative<rmm::device_uvector<double>>(property)) {
97+
auto& prop = std::get<rmm::device_uvector<double>>(property);
98+
return func(prop);
99+
} else if (std::holds_alternative<rmm::device_uvector<int32_t>>(property)) {
100+
auto& prop = std::get<rmm::device_uvector<int32_t>>(property);
101+
return func(prop);
102+
} else if (std::holds_alternative<rmm::device_uvector<int64_t>>(property)) {
103+
auto& prop = std::get<rmm::device_uvector<int64_t>>(property);
104+
return func(prop);
105+
} else {
106+
CUGRAPH_EXPECTS(std::holds_alternative<rmm::device_uvector<size_t>>(property),
107+
"unsupported variant type -- shouldn't happen");
108+
auto& prop = std::get<rmm::device_uvector<size_t>>(property);
109+
return func(prop);
110+
}
111+
}
112+
113+
template <typename func_t>
114+
auto variant_type_dispatch(arithmetic_device_span_t& property, func_t func)
115+
{
116+
if (std::holds_alternative<raft::device_span<float>>(property)) {
117+
auto& prop = std::get<raft::device_span<float>>(property);
118+
return func(prop);
119+
} else if (std::holds_alternative<raft::device_span<double>>(property)) {
120+
auto& prop = std::get<raft::device_span<double>>(property);
121+
return func(prop);
122+
} else if (std::holds_alternative<raft::device_span<int32_t>>(property)) {
123+
auto& prop = std::get<raft::device_span<int32_t>>(property);
124+
return func(prop);
125+
} else if (std::holds_alternative<raft::device_span<int64_t>>(property)) {
126+
auto& prop = std::get<raft::device_span<int64_t>>(property);
127+
return func(prop);
128+
} else {
129+
CUGRAPH_EXPECTS(std::holds_alternative<raft::device_span<size_t>>(property),
130+
"unsupported variant type -- shouldn't happen");
131+
132+
auto& prop = std::get<raft::device_span<size_t>>(property);
133+
return func(prop);
134+
}
135+
}
136+
137+
template <typename func_t>
138+
auto variant_type_dispatch(const_arithmetic_device_span_t& property, func_t func)
139+
{
140+
if (std::holds_alternative<raft::device_span<float const>>(property)) {
141+
auto& prop = std::get<raft::device_span<float const>>(property);
142+
return func(prop);
143+
} else if (std::holds_alternative<raft::device_span<double const>>(property)) {
144+
auto& prop = std::get<raft::device_span<double const>>(property);
145+
return func(prop);
146+
} else if (std::holds_alternative<raft::device_span<int32_t const>>(property)) {
147+
auto& prop = std::get<raft::device_span<int32_t const>>(property);
148+
return func(prop);
149+
} else if (std::holds_alternative<raft::device_span<int64_t const>>(property)) {
150+
auto& prop = std::get<raft::device_span<int64_t const>>(property);
151+
return func(prop);
152+
} else {
153+
CUGRAPH_EXPECTS(std::holds_alternative<raft::device_span<size_t const>>(property),
154+
"unsupported variant type -- shouldn't happen");
155+
156+
auto& prop = std::get<raft::device_span<size_t const>>(property);
157+
return func(prop);
158+
}
159+
}
160+
161+
template <typename edge_t, typename func_t>
162+
auto variant_type_dispatch(edge_arithmetic_property_view_t<edge_t>& property, func_t func)
163+
{
164+
if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, float const*>>(property)) {
165+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, float const*>>(property);
166+
return func(prop);
167+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, double const*>>(
168+
property)) {
169+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, double const*>>(property);
170+
return func(prop);
171+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, int32_t const*>>(
172+
property)) {
173+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, int32_t const*>>(property);
174+
return func(prop);
175+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, int64_t const*>>(
176+
property)) {
177+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, int64_t const*>>(property);
178+
return func(prop);
179+
} else {
180+
CUGRAPH_EXPECTS(
181+
(std::holds_alternative<cugraph::edge_property_view_t<edge_t, size_t const*>>(property)),
182+
"unsupported variant type -- shouldn't happen");
183+
184+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, size_t const*>>(property);
185+
return func(prop);
186+
}
187+
}
188+
189+
template <typename edge_t, typename func_t>
190+
auto variant_type_dispatch(edge_arithmetic_property_mutable_view_t<edge_t>& property, func_t func)
191+
{
192+
if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, float*>>(property)) {
193+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, float*>>(property);
194+
return func(prop);
195+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, double*>>(property)) {
196+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, double*>>(property);
197+
return func(prop);
198+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, int32_t*>>(property)) {
199+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, int32_t*>>(property);
200+
return func(prop);
201+
} else if (std::holds_alternative<cugraph::edge_property_view_t<edge_t, int64_t*>>(property)) {
202+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, int64_t*>>(property);
203+
return func(prop);
204+
} else {
205+
CUGRAPH_EXPECTS(
206+
(std::holds_alternative<cugraph::edge_property_view_t<edge_t, size_t*>>(property)),
207+
"unsupported variant type -- shouldn't happen");
208+
209+
auto& prop = std::get<cugraph::edge_property_view_t<edge_t, size_t const*>>(property);
210+
return func(prop);
211+
}
212+
}
213+
214+
struct sizeof_arithmetic_element {
215+
template <typename T>
216+
size_t operator()(rmm::device_uvector<T> const&) const
217+
{
218+
return sizeof(T);
219+
}
220+
template <typename T>
221+
size_t operator()(raft::device_span<T> const&) const
222+
{
223+
return sizeof(T);
224+
}
225+
template <typename T>
226+
size_t operator()(raft::device_span<T const> const&) const
227+
{
228+
return sizeof(T);
229+
}
230+
};
231+
232+
inline arithmetic_device_span_t make_arithmetic_device_span(arithmetic_device_uvector_t& v)
233+
{
234+
return variant_type_dispatch(v, [](auto& v) {
235+
using T = typename std::remove_reference<decltype(v)>::type::value_type;
236+
return static_cast<arithmetic_device_span_t>(raft::device_span<T>(v.data(), v.size()));
237+
});
238+
}
239+
240+
inline std::vector<arithmetic_device_span_t> make_arithmetic_device_span_vector(
241+
std::vector<arithmetic_device_uvector_t>& v)
242+
{
243+
std::vector<arithmetic_device_span_t> results(v.size());
244+
std::transform(
245+
v.begin(), v.end(), results.begin(), [](auto& c) { return make_arithmetic_device_span(c); });
246+
return results;
247+
}
248+
249+
} // namespace cugraph

cpp/include/cugraph/detail/shuffle_wrappers.hpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
#pragma once
1717

18+
#include <cugraph/arithmetic_variant_types.hpp>
19+
1820
#include <raft/core/handle.hpp>
1921
#include <raft/core/host_span.hpp>
2022
#include <raft/random/rng_state.hpp>
@@ -231,20 +233,12 @@ shuffle_int_vertex_value_pairs_to_local_gpu_by_vertex_partitioning(
231233
* groupby_and_count_local_partition is false) or in each segment with the same (local partition ID,
232234
* GPU ID) pair.
233235
*/
234-
template <typename vertex_t,
235-
typename edge_t,
236-
typename weight_t,
237-
typename edge_type_t,
238-
typename edge_time_t>
236+
template <typename vertex_t>
239237
rmm::device_uvector<size_t> groupby_and_count_edgelist_by_local_partition_id(
240238
raft::handle_t const& handle,
241-
rmm::device_uvector<vertex_t>& d_edgelist_majors,
242-
rmm::device_uvector<vertex_t>& d_edgelist_minors,
243-
std::optional<rmm::device_uvector<weight_t>>& d_edgelist_weights,
244-
std::optional<rmm::device_uvector<edge_t>>& d_edgelist_edge_ids,
245-
std::optional<rmm::device_uvector<edge_type_t>>& d_edgelist_edge_types,
246-
std::optional<rmm::device_uvector<edge_time_t>>& d_edgelist_edge_start_times,
247-
std::optional<rmm::device_uvector<edge_time_t>>& d_edgelist_edge_end_times,
239+
raft::device_span<vertex_t> edgelist_majors,
240+
raft::device_span<vertex_t> edgelist_minors,
241+
raft::host_span<cugraph::arithmetic_device_span_t> edgelist_properties,
248242
bool groupby_and_count_local_partition_by_minor = false);
249243

250244
/**

0 commit comments

Comments
 (0)