Skip to content

Commit 1321752

Browse files
committed
addressed #106 now all thread pools uses Borrow<ThreadPool>
1 parent d60b02f commit 1321752

File tree

11 files changed

+51
-72
lines changed

11 files changed

+51
-72
lines changed

src/algo/llp/gap_cost.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
* SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
66
*/
77

8+
use std::borrow::Borrow;
9+
810
use crate::traits::*;
911
use dsi_progress_logger::prelude::*;
1012
use lender::prelude::*;
@@ -19,7 +21,7 @@ pub(crate) fn compute_log_gap_cost<G: SequentialGraph + Sync>(
1921
graph: &G,
2022
arc_granularity: usize,
2123
deg_cumul: &(impl Succ<Input = usize, Output = usize> + Send + Sync),
22-
thread_pool: &rayon::ThreadPool,
24+
thread_pool: impl Borrow<rayon::ThreadPool>,
2325
pr: Option<&mut ProgressLogger>,
2426
) -> f64 {
2527
graph.par_apply(

src/cli/from_csv.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ pub fn from_csv(args: CliArgs) -> Result<()> {
131131
&g,
132132
args.num_nodes,
133133
args.ca.into(),
134-
Threads::Num(args.num_cpus.num_cpus),
134+
rayon::ThreadPoolBuilder::new()
135+
.num_threads(args.num_cpus.num_cpus)
136+
.build()
137+
.expect("Failed to create thread pool"),
135138
dir,
136139
&target_endianness.unwrap_or_else(|| BE::NAME.into()),
137140
)

src/cli/recompress.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ where
8080
{
8181
let dir = Builder::new().prefix("Recompress").tempdir()?;
8282

83+
let thread_pool = rayon::ThreadPoolBuilder::new()
84+
.num_threads(args.num_cpus.num_cpus)
85+
.build()
86+
.expect("Failed to create thread pool");
87+
8388
if args.basename.with_extension(EF_EXTENSION).exists() {
8489
let graph = BVGraph::with_basename(&args.basename)
8590
.endianness::<E>()
@@ -107,12 +112,7 @@ where
107112
>,
108113
>,
109114
JavaPermutation,
110-
>(
111-
&graph,
112-
&permutation,
113-
batch_size,
114-
Threads::Num(args.num_cpus.num_cpus),
115-
)?;
115+
>(&graph, &permutation, batch_size, &thread_pool)?;
116116
log::info!(
117117
"Permuted the graph. It took {:.3} seconds",
118118
start.elapsed().as_secs_f64()
@@ -122,7 +122,7 @@ where
122122
&sorted,
123123
sorted.num_nodes(),
124124
args.ca.into(),
125-
Threads::Num(args.num_cpus.num_cpus),
125+
&thread_pool,
126126
dir,
127127
&target_endianness.unwrap_or_else(|| E::NAME.into()),
128128
)?;
@@ -132,7 +132,7 @@ where
132132
&graph,
133133
graph.num_nodes(),
134134
args.ca.into(),
135-
Threads::Num(args.num_cpus.num_cpus),
135+
&thread_pool,
136136
dir,
137137
&target_endianness.unwrap_or_else(|| E::NAME.into()),
138138
)?;
@@ -159,7 +159,7 @@ where
159159
&permuted,
160160
permuted.num_nodes(),
161161
args.ca.into(),
162-
Threads::Num(args.num_cpus.num_cpus),
162+
&thread_pool,
163163
dir,
164164
&target_endianness.unwrap_or_else(|| E::NAME.into()),
165165
)?;
@@ -169,7 +169,7 @@ where
169169
&seq_graph,
170170
seq_graph.num_nodes(),
171171
args.ca.into(),
172-
Threads::Num(args.num_cpus.num_cpus),
172+
&thread_pool,
173173
dir,
174174
&target_endianness.unwrap_or_else(|| E::NAME.into()),
175175
)?;

src/cli/simplify.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ where
7878
&sorted,
7979
sorted.num_nodes(),
8080
args.ca.into(),
81-
Threads::Num(args.num_cpus.num_cpus),
81+
rayon::ThreadPoolBuilder::new()
82+
.num_threads(args.num_cpus.num_cpus)
83+
.build()
84+
.expect("Failed to create thread pool"),
8285
dir,
8386
&target_endianness.unwrap_or_else(|| E::NAME.into()),
8487
)?;

src/cli/transpose.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ where
7878
&sorted,
7979
sorted.num_nodes(),
8080
args.ca.into(),
81-
Threads::Num(args.num_cpus.num_cpus),
81+
rayon::ThreadPoolBuilder::new()
82+
.num_threads(args.num_cpus.num_cpus)
83+
.build()
84+
.expect("Failed to create thread pool"),
8285
dir,
8386
&target_endianness.unwrap_or_else(|| E::NAME.into()),
8487
)?;

src/graphs/bvgraph/comp/impls.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use anyhow::{ensure, Context, Result};
1010
use dsi_bitstream::prelude::*;
1111
use dsi_progress_logger::prelude::*;
1212
use lender::prelude::*;
13+
use std::borrow::Borrow;
1314
use std::fs::File;
1415
use std::io::{BufReader, BufWriter};
1516
use std::path::{Path, PathBuf};
@@ -196,7 +197,7 @@ impl BVComp<()> {
196197
graph: &G,
197198
num_nodes: usize,
198199
compression_flags: CompFlags,
199-
mut threads: impl AsMut<rayon::ThreadPool>,
200+
threads: impl Borrow<rayon::ThreadPool>,
200201
tmp_dir: P,
201202
endianness: &str,
202203
) -> Result<u64>
@@ -213,7 +214,7 @@ impl BVComp<()> {
213214
Self::parallel_iter::<BigEndian, _>(
214215
basename,
215216
graph
216-
.split_iter(threads.as_mut().current_num_threads())
217+
.split_iter(threads.borrow().current_num_threads())
217218
.into_iter(),
218219
num_nodes,
219220
compression_flags,
@@ -230,7 +231,7 @@ impl BVComp<()> {
230231
Self::parallel_iter::<LittleEndian, _>(
231232
basename,
232233
graph
233-
.split_iter(threads.as_mut().current_num_threads())
234+
.split_iter(threads.borrow().current_num_threads())
234235
.into_iter(),
235236
num_nodes,
236237
compression_flags,
@@ -247,7 +248,7 @@ impl BVComp<()> {
247248
basename: impl AsRef<Path> + Send + Sync,
248249
graph: &(impl SequentialGraph + SplitLabeling),
249250
compression_flags: CompFlags,
250-
mut threads: impl AsMut<rayon::ThreadPool>,
251+
threads: impl Borrow<rayon::ThreadPool>,
251252
tmp_dir: impl AsRef<Path>,
252253
) -> Result<u64>
253254
where
@@ -257,7 +258,7 @@ impl BVComp<()> {
257258
Self::parallel_iter(
258259
basename,
259260
graph
260-
.split_iter(threads.as_mut().current_num_threads())
261+
.split_iter(threads.borrow().current_num_threads())
261262
.into_iter(),
262263
graph.num_nodes(),
263264
compression_flags,
@@ -276,15 +277,14 @@ impl BVComp<()> {
276277
iter: impl Iterator<Item = L>,
277278
num_nodes: usize,
278279
compression_flags: CompFlags,
279-
mut threads: impl AsMut<rayon::ThreadPool>,
280+
threads: impl Borrow<rayon::ThreadPool>,
280281
tmp_dir: impl AsRef<Path>,
281282
) -> Result<u64>
282283
where
283284
BufBitWriter<E, WordAdapter<usize, BufWriter<std::fs::File>>>: CodeWrite<E>,
284285
BufBitReader<E, WordAdapter<u32, BufReader<std::fs::File>>>: BitRead<E>,
285286
{
286-
let thread_pool = threads.as_mut();
287-
287+
let thread_pool = threads.borrow();
288288
let tmp_dir = tmp_dir.as_ref();
289289
let basename = basename.as_ref();
290290

src/traits/labels.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use core::{
3333
use dsi_progress_logger::prelude::*;
3434
use impl_tools::autoimpl;
3535
use lender::*;
36+
use std::borrow::Borrow;
3637
use sux::traits::Succ;
3738

3839
/// A labeling that can be accessed sequentially.
@@ -104,7 +105,7 @@ pub trait SequentialLabeling {
104105
func: F,
105106
fold: R,
106107
node_granularity: usize,
107-
thread_pool: &rayon::ThreadPool,
108+
thread_pool: impl Borrow<rayon::ThreadPool>,
108109
pl: Option<&mut ProgressLogger>,
109110
) -> A
110111
where
@@ -116,6 +117,7 @@ pub trait SequentialLabeling {
116117
let pl_lock = pl.map(std::sync::Mutex::new);
117118
let num_nodes = self.num_nodes();
118119
let num_scoped_threads = thread_pool
120+
.borrow()
119121
.current_num_threads()
120122
.min(num_nodes / node_granularity)
121123
.max(1);
@@ -124,7 +126,7 @@ pub trait SequentialLabeling {
124126

125127
// create a channel to receive the result
126128
let (tx, rx) = std::sync::mpsc::channel();
127-
thread_pool.in_place_scope(|scope| {
129+
thread_pool.borrow().in_place_scope(|scope| {
128130
for _ in 0..num_scoped_threads {
129131
// create some references so that we can share them across threads
130132
let pl_lock = &pl_lock;
@@ -185,7 +187,7 @@ pub trait SequentialLabeling {
185187
fold: R,
186188
arc_granularity: usize,
187189
deg_cumul: &(impl Succ<Input = usize, Output = usize> + Send + Sync),
188-
thread_pool: &rayon::ThreadPool,
190+
thread_pool: impl Borrow<rayon::ThreadPool>,
189191
pl: Option<&mut ProgressLogger>,
190192
) -> A
191193
where
@@ -194,6 +196,7 @@ pub trait SequentialLabeling {
194196
T: Send,
195197
A: Default + Send,
196198
{
199+
let thread_pool = thread_pool.borrow();
197200
let pl_lock = pl.map(std::sync::Mutex::new);
198201
let num_nodes = self.num_nodes();
199202
let num_arcs = self.num_arcs_hint().unwrap();

src/transform/perm.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
* SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
55
*/
66

7+
use std::borrow::Borrow;
8+
79
use crate::graphs::arc_list_graph;
810
use crate::prelude::sort_pairs::{BatchIterator, KMergeIters};
911
use crate::prelude::*;
@@ -73,7 +75,7 @@ pub fn permute_split<S, P>(
7375
graph: &S,
7476
perm: &P,
7577
batch_size: usize,
76-
mut threads: impl AsMut<rayon::ThreadPool>,
78+
threads: impl Borrow<rayon::ThreadPool>,
7779
) -> Result<Left<arc_list_graph::ArcListGraph<KMergeIters<BatchIterator<()>, ()>>>>
7880
where
7981
S: SequentialGraph + SplitLabeling,
@@ -88,7 +90,7 @@ where
8890
// get a premuted view
8991
let pgraph = PermutedGraph { graph, perm };
9092

91-
let pool = threads.as_mut();
93+
let pool = threads.borrow();
9294
let num_threads = pool.current_num_threads();
9395
let mut dirs = vec![];
9496

src/transform/simplify.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
* SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
55
*/
66

7+
use std::borrow::Borrow;
8+
79
use crate::graphs::{arc_list_graph, UnionGraph};
810
use crate::labels::Left;
911
use crate::traits::{SequentialGraph, SplitLabeling};
@@ -97,12 +99,12 @@ pub fn simplify(
9799
pub fn simplify_split<S>(
98100
graph: &S,
99101
batch_size: usize,
100-
mut threads: impl AsMut<rayon::ThreadPool>,
102+
threads: impl Borrow<rayon::ThreadPool>,
101103
) -> Result<Left<arc_list_graph::ArcListGraph<itertools::Dedup<KMergeIters<BatchIterator<()>, ()>>>>>
102104
where
103105
S: SequentialGraph + SplitLabeling,
104106
{
105-
let pool = threads.as_mut();
107+
let pool = threads.borrow();
106108
let num_threads = pool.current_num_threads();
107109
let (tx, rx) = std::sync::mpsc::channel();
108110

src/utils/mod.rs

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -63,45 +63,3 @@ pub use java_perm::*;
6363

6464
pub mod sort_pairs;
6565
pub use sort_pairs::SortPairs;
66-
67-
/// An enum to specify the number of threads to use in parallel operations,
68-
/// either by [`rayon::current_num_threads()`], or by a fixed number, or by a
69-
/// custom [thread pool](rayon::ThreadPool).
70-
pub enum Threads {
71-
Default,
72-
Num(usize),
73-
Pool(rayon::ThreadPool),
74-
}
75-
76-
impl Threads {
77-
pub fn num_threads(&self) -> usize {
78-
match self {
79-
Self::Default => rayon::current_num_threads(),
80-
Self::Num(num_threads) => *num_threads,
81-
Self::Pool(thread_pool) => thread_pool.current_num_threads(),
82-
}
83-
}
84-
}
85-
86-
impl AsMut<rayon::ThreadPool> for Threads {
87-
fn as_mut(&mut self) -> &mut rayon::ThreadPool {
88-
match self {
89-
Self::Default => {
90-
let thread_pool = rayon::ThreadPoolBuilder::new()
91-
.build()
92-
.expect("Failed to create thread pool");
93-
*self = Self::Pool(thread_pool);
94-
self.as_mut()
95-
}
96-
Self::Num(num_threads) => {
97-
let thread_pool = rayon::ThreadPoolBuilder::new()
98-
.num_threads(*num_threads)
99-
.build()
100-
.expect("Failed to create thread pool");
101-
*self = Self::Pool(thread_pool);
102-
self.as_mut()
103-
}
104-
Self::Pool(thread_pool) => thread_pool,
105-
}
106-
}
107-
}

0 commit comments

Comments
 (0)