Skip to content

Commit 1b9a62c

Browse files
committed
refactor: use NonZero types
Signed-off-by: Wenxuan Zhang <[email protected]>
1 parent 8b23af0 commit 1b9a62c

File tree

4 files changed

+50
-49
lines changed

4 files changed

+50
-49
lines changed

src/cli.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@
8181
//!
8282
//! -h, --help
8383
//! Print help (see a summary with '-h')
84-
use std::io::stdout;
84+
use std::{
85+
io::stdout,
86+
num::{NonZeroU32, NonZeroU64, NonZeroU8},
87+
};
8588

8689
use clap::{Parser, ValueEnum};
8790
use crossterm::tty::IsTty;
@@ -102,14 +105,14 @@ use crate::{
102105
#[allow(missing_docs)]
103106
pub struct BenchCli {
104107
/// Number of workers to run concurrently
105-
#[clap(long, short = 'c', default_value = "1", value_parser = clap::value_parser!(u32).range(1..))]
106-
pub concurrency: u32,
108+
#[clap(long, short = 'c', default_value = "1")]
109+
pub concurrency: NonZeroU32,
107110

108111
/// Number of iterations
109112
///
110113
/// When set, benchmark stops after reaching the number of iterations.
111-
#[clap(long, short = 'n', value_parser = clap::value_parser!(u64).range(1..))]
112-
pub iterations: Option<u64>,
114+
#[clap(long, short = 'n')]
115+
pub iterations: Option<NonZeroU64>,
113116

114117
/// Duration to run the benchmark
115118
///
@@ -122,8 +125,8 @@ pub struct BenchCli {
122125
/// Rate limit for benchmarking, in iterations per second (ips)
123126
///
124127
/// When set, benchmark will try to run at the specified rate.
125-
#[clap(long, short = 'r', value_parser = clap::value_parser!(u32).range(1..))]
126-
pub rate: Option<u32>,
128+
#[clap(long, short = 'r')]
129+
pub rate: Option<NonZeroU32>,
127130

128131
/// Run benchmark in quiet mode
129132
///
@@ -136,8 +139,8 @@ pub struct BenchCli {
136139
pub collector: Option<Collector>,
137140

138141
/// Refresh rate for the tui collector, in frames per second (fps)
139-
#[clap(long, default_value = "32", value_parser = clap::value_parser!(u8).range(1..))]
140-
pub fps: u8,
142+
#[clap(long, default_value = "32")]
143+
pub fps: NonZeroU8,
141144

142145
/// Output format for the report
143146
#[clap(short, long, value_enum, default_value_t = ReportFormat::Text, ignore_case = true)]
@@ -148,8 +151,8 @@ impl BenchCli {
148151
pub(crate) fn bench_opts(&self, clock: Clock) -> BenchOpts {
149152
BenchOpts {
150153
clock,
151-
concurrency: self.concurrency,
152-
iterations: self.iterations,
154+
concurrency: self.concurrency.get(),
155+
iterations: self.iterations.map(|n| n.get()),
153156
duration: self.duration.map(|d| d.into()),
154157
rate: self.rate,
155158
}

src/collector/tui.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crossterm::{
66
terminal, ExecutableCommand,
77
};
88
use itertools::Itertools;
9+
use nonzero_ext::nonzero;
910
use ratatui::{
1011
backend::CrosstermBackend,
1112
layout::{Constraint, Direction, Layout, Margin, Rect},
@@ -14,7 +15,7 @@ use ratatui::{
1415
widgets::{block::Title, BarChart, Block, Borders, Clear, Gauge, Padding, Paragraph},
1516
CompletedFrame, Frame,
1617
};
17-
use std::{collections::HashMap, fmt, io, time::Duration};
18+
use std::{collections::HashMap, fmt, io, num::NonZeroU8, time::Duration};
1819
use tokio::{
1920
sync::{mpsc, watch},
2021
time::MissedTickBehavior,
@@ -47,7 +48,7 @@ pub struct TuiCollector {
4748
/// The benchmark options.
4849
pub bench_opts: BenchOpts,
4950
/// Refresh rate for the tui collector, in frames per second (fps)
50-
pub fps: u8,
51+
pub fps: NonZeroU8,
5152
/// The receiver for iteration reports.
5253
pub res_rx: mpsc::UnboundedReceiver<Result<IterReport>>,
5354
/// The sender for pausing the benchmark runner.
@@ -63,7 +64,7 @@ impl TuiCollector {
6364
/// Create a new TUI report collector.
6465
pub fn new(
6566
bench_opts: BenchOpts,
66-
fps: u8,
67+
fps: NonZeroU8,
6768
res_rx: mpsc::UnboundedReceiver<Result<IterReport>>,
6869
pause: watch::Sender<bool>,
6970
cancel: CancellationToken,
@@ -131,13 +132,13 @@ impl ReportCollector for TuiCollector {
131132

132133
let mut clock = self.bench_opts.clock.clone();
133134

134-
let mut latest_iters = RotateWindowGroup::new(60);
135+
let mut latest_iters = RotateWindowGroup::new(nonzero!(60usize));
135136
let mut latest_iters_ticker = clock.ticker(SECOND);
136137

137-
let mut latest_stats = RotateDiffWindowGroup::new(self.fps);
138-
let mut latest_stats_ticker = clock.ticker(SECOND / self.fps as u32);
138+
let mut latest_stats = RotateDiffWindowGroup::new(self.fps.into());
139+
let mut latest_stats_ticker = clock.ticker(SECOND / self.fps.get() as u32);
139140

140-
let mut ui_ticker = tokio::time::interval(SECOND / self.fps as u32);
141+
let mut ui_ticker = tokio::time::interval(SECOND / self.fps.get() as u32);
141142
ui_ticker.set_missed_tick_behavior(MissedTickBehavior::Burst);
142143

143144
#[cfg(feature = "log")]

src/runner.rs

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! This module defines traits for stateful and stateless benchmark suites.
2-
use anyhow::{anyhow, Result};
2+
use anyhow::Result;
33
use async_trait::async_trait;
44
use governor::{Quota, RateLimiter};
5-
use nonzero_ext::{nonzero, NonZero};
5+
use nonzero_ext::nonzero;
66
use std::{
7+
num::NonZeroU32,
78
sync::{
89
atomic::{AtomicU64, Ordering},
910
Arc,
@@ -39,7 +40,7 @@ pub struct BenchOpts {
3940
pub duration: Option<Duration>,
4041

4142
/// Rate limit for benchmarking, in iterations per second (ips).
42-
pub rate: Option<u32>,
43+
pub rate: Option<NonZeroU32>,
4344
}
4445

4546
/// A trait for benchmark suites.
@@ -156,22 +157,16 @@ where
156157
let concurrency = self.opts.concurrency;
157158
let iterations = self.opts.iterations;
158159

159-
let rli = match self.opts.rate {
160-
Some(r) => {
161-
let quota = Quota::per_second(
162-
NonZero::new(r).ok_or_else(|| anyhow!("rate limit must be greater than 0, got {}", r))?,
163-
)
164-
.allow_burst(nonzero!(1u32));
165-
let clock = &self.opts.clock;
166-
Some(Arc::new(RateLimiter::direct_with_clock(quota, clock)))
167-
}
168-
None => None,
169-
};
160+
let buckets = self.opts.rate.map(|r| {
161+
let quota = Quota::per_second(r).allow_burst(nonzero!(1u32));
162+
let clock = &self.opts.clock;
163+
Arc::new(RateLimiter::direct_with_clock(quota, clock))
164+
});
170165

171166
let mut set: JoinSet<Result<()>> = JoinSet::new();
172167
for worker in 0..concurrency {
173168
let mut b = self.clone();
174-
let rli = rli.clone();
169+
let buckets = buckets.clone();
175170
set.spawn(async move {
176171
let mut state = b.suite.state(worker).await?;
177172
let mut info = IterInfo::new(worker);
@@ -186,14 +181,16 @@ where
186181
}
187182
}
188183

189-
if let Some(rli) = &rli {
184+
if let Some(buckets) = &buckets {
190185
select! {
186+
biased;
191187
_ = cancel.cancelled() => break,
192-
_ = rli.until_ready() => (),
188+
_ = buckets.until_ready() => (),
193189
}
194190
}
195191

196192
select! {
193+
biased;
197194
_ = cancel.cancelled() => break,
198195
_ = b.iteration(&mut state, &info) => (),
199196
}
@@ -207,6 +204,7 @@ where
207204

208205
if let Some(t) = self.opts.duration {
209206
select! {
207+
biased;
210208
_ = self.cancel.cancelled() => (),
211209
_ = self.opts.clock.sleep(t) => self.cancel.cancel(),
212210
_ = join_all(&mut set) => (),

src/stats/window.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use std::collections::VecDeque;
1+
use std::{collections::VecDeque, num::NonZeroUsize};
22

3+
use nonzero_ext::nonzero;
34
use tokio::time::Duration;
45

56
use crate::report::IterReport;
@@ -8,13 +9,12 @@ use super::IterStats;
89

910
pub struct RotateWindow {
1011
buckets: VecDeque<IterStats>,
11-
size: usize,
12+
size: NonZeroUsize,
1213
}
1314

1415
impl RotateWindow {
15-
fn new(size: usize) -> Self {
16-
assert!(size > 0);
17-
let mut win = Self { buckets: VecDeque::with_capacity(size), size };
16+
fn new(size: NonZeroUsize) -> Self {
17+
let mut win = Self { buckets: VecDeque::with_capacity(size.get()), size };
1818
win.rotate(IterStats::new());
1919
win
2020
}
@@ -25,7 +25,7 @@ impl RotateWindow {
2525
}
2626

2727
fn rotate(&mut self, bucket: IterStats) {
28-
if self.buckets.len() == self.size {
28+
if self.buckets.len() == self.size.get() {
2929
self.buckets.pop_back();
3030
}
3131
self.buckets.push_front(bucket);
@@ -59,7 +59,7 @@ pub struct RotateWindowGroup {
5959
}
6060

6161
impl RotateWindowGroup {
62-
pub fn new(buckets: usize) -> Self {
62+
pub fn new(buckets: NonZeroUsize) -> Self {
6363
Self {
6464
counter: 0,
6565
stats_by_sec: RotateWindow::new(buckets),
@@ -108,15 +108,14 @@ impl RotateDiffWindowGroup {
108108
&mut self.stats_last_10min,
109109
]
110110
}
111-
pub fn new(fps: u8) -> Self {
112-
let fps = fps as usize;
113-
let interval = Duration::from_secs_f64(1.0 / fps as f64);
111+
pub fn new(fps: NonZeroUsize) -> Self {
112+
let interval = Duration::from_secs_f64(1.0 / fps.get() as f64);
114113
let mut group = Self {
115114
interval,
116-
stats_last_sec: RotateWindow::new(fps + 1),
117-
stats_last_10sec: RotateWindow::new(fps * 10 + 1),
118-
stats_last_min: RotateWindow::new(fps * 60 + 1),
119-
stats_last_10min: RotateWindow::new(fps * 600 + 1),
115+
stats_last_sec: RotateWindow::new(fps.saturating_add(1)),
116+
stats_last_10sec: RotateWindow::new(fps.saturating_mul(nonzero!(10usize)).saturating_add(1)),
117+
stats_last_min: RotateWindow::new(fps.saturating_mul(nonzero!(60usize)).saturating_add(1)),
118+
stats_last_10min: RotateWindow::new(fps.saturating_mul(nonzero!(600usize)).saturating_add(1)),
120119
};
121120
group.rotate(&IterStats::new());
122121
group

0 commit comments

Comments
 (0)