Skip to content

Commit 5b0bc14

Browse files
authored
Re-org with distr::slice, distr::weighted modules (rust-random#1548)
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` - Rename trait `DistString` -> `SampleString` - Rename `DistIter` -> `Iter`, `DistMap` -> `Map` - Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` - Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` - Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex`
1 parent b7877cd commit 5b0bc14

File tree

5 files changed

+91
-70
lines changed

5 files changed

+91
-70
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [0.5.0-beta.3] - 2025-01-03
88
- Bump `rand` version (#1547)
9+
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548)
10+
- Rename trait `DistString` -> `SampleString` (#1548)
11+
- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548)
12+
- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548)
13+
- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548)
14+
- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548)
915

1016
## [0.5.0-beta.2] - 2024-11-30
1117
- Bump `rand` version

src/lib.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//!
3434
//! The following are re-exported:
3535
//!
36-
//! - The [`Distribution`] trait and [`DistIter`] helper type
36+
//! - The [`Distribution`] trait and [`Iter`] helper type
3737
//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`],
38-
//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions
38+
//! [`Open01`], [`Bernoulli`] distributions
39+
//! - The [`weighted`] module
3940
//!
4041
//! ## Distributions
4142
//!
@@ -76,9 +77,6 @@
7677
//! - [`UnitBall`] distribution
7778
//! - [`UnitCircle`] distribution
7879
//! - [`UnitDisc`] distribution
79-
//! - Alternative implementations for weighted index sampling
80-
//! - [`WeightedAliasIndex`] distribution
81-
//! - [`WeightedTreeIndex`] distribution
8280
//! - Misc. distributions
8381
//! - [`InverseGaussian`] distribution
8482
//! - [`NormalInverseGaussian`] distribution
@@ -94,7 +92,7 @@ extern crate std;
9492
use rand::Rng;
9593

9694
pub use rand::distr::{
97-
uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01,
95+
uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01,
9896
StandardUniform, Uniform,
9997
};
10098

@@ -128,16 +126,13 @@ pub use self::unit_sphere::UnitSphere;
128126
pub use self::weibull::{Error as WeibullError, Weibull};
129127
pub use self::zeta::{Error as ZetaError, Zeta};
130128
pub use self::zipf::{Error as ZipfError, Zipf};
131-
#[cfg(feature = "alloc")]
132-
pub use rand::distr::{WeightError, WeightedIndex};
133129
pub use student_t::StudentT;
134-
#[cfg(feature = "alloc")]
135-
pub use weighted_alias::WeightedAliasIndex;
136-
#[cfg(feature = "alloc")]
137-
pub use weighted_tree::WeightedTreeIndex;
138130

139131
pub use num_traits;
140132

133+
#[cfg(feature = "alloc")]
134+
pub mod weighted;
135+
141136
#[cfg(test)]
142137
#[macro_use]
143138
mod test {
@@ -189,11 +184,6 @@ mod test {
189184
}
190185
}
191186

192-
#[cfg(feature = "alloc")]
193-
pub mod weighted_alias;
194-
#[cfg(feature = "alloc")]
195-
pub mod weighted_tree;
196-
197187
mod beta;
198188
mod binomial;
199189
mod cauchy;

src/weighted/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2018 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
//! Weighted (index) sampling
10+
//!
11+
//! This module is a superset of [`rand::distr::weighted`].
12+
//!
13+
//! Multiple implementations of weighted index sampling are provided:
14+
//!
15+
//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction
16+
//! and `O(log N)` sampling over `N` weights.
17+
//! It also supports updating weights with `O(N)` time.
18+
//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high
19+
//! construction time many samples are required to outperform [`WeightedIndex`].
20+
//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and
21+
//! update/insertion/removal of weights with `O(log N)` time.
22+
23+
mod weighted_alias;
24+
mod weighted_tree;
25+
26+
pub use rand::distr::weighted::*;
27+
pub use weighted_alias::*;
28+
pub use weighted_tree::*;

src/weighted_alias.rs renamed to src/weighted/weighted_alias.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//! This module contains an implementation of alias method for sampling random
1010
//! indices with probabilities proportional to a collection of weights.
1111
12-
use super::WeightError;
12+
use super::Error;
1313
use crate::{uniform::SampleUniform, Distribution, Uniform};
1414
use alloc::{boxed::Box, vec, vec::Vec};
1515
use core::fmt;
@@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize};
4141
/// # Example
4242
///
4343
/// ```
44-
/// use rand_distr::WeightedAliasIndex;
44+
/// use rand_distr::weighted::WeightedAliasIndex;
4545
/// use rand::prelude::*;
4646
///
4747
/// let choices = vec!['a', 'b', 'c'];
@@ -85,14 +85,14 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
8585
/// Creates a new [`WeightedAliasIndex`].
8686
///
8787
/// Error cases:
88-
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
89-
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
88+
/// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
89+
/// - [`Error::InvalidWeight`] when a weight is not-a-number,
9090
/// negative or greater than `max = W::MAX / weights.len()`.
91-
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
92-
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
91+
/// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
92+
pub fn new(weights: Vec<W>) -> Result<Self, Error> {
9393
let n = weights.len();
9494
if n == 0 || n > u32::MAX as usize {
95-
return Err(WeightError::InvalidInput);
95+
return Err(Error::InvalidInput);
9696
}
9797
let n = n as u32;
9898

@@ -103,7 +103,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
103103
.iter()
104104
.all(|&w| W::ZERO <= w && w <= max_weight_size)
105105
{
106-
return Err(WeightError::InvalidWeight);
106+
return Err(Error::InvalidWeight);
107107
}
108108

109109
// The sum of weights will represent 100% of no alias odds.
@@ -115,7 +115,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
115115
weight_sum
116116
};
117117
if weight_sum == W::ZERO {
118-
return Err(WeightError::InsufficientNonZero);
118+
return Err(Error::InsufficientNonZero);
119119
}
120120

121121
// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
@@ -384,23 +384,23 @@ mod test {
384384
// Floating point special cases
385385
assert_eq!(
386386
WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
387-
WeightError::InvalidWeight
387+
Error::InvalidWeight
388388
);
389389
assert_eq!(
390390
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
391-
WeightError::InsufficientNonZero
391+
Error::InsufficientNonZero
392392
);
393393
assert_eq!(
394394
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
395-
WeightError::InvalidWeight
395+
Error::InvalidWeight
396396
);
397397
assert_eq!(
398398
WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
399-
WeightError::InvalidWeight
399+
Error::InvalidWeight
400400
);
401401
assert_eq!(
402402
WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
403-
WeightError::InvalidWeight
403+
Error::InvalidWeight
404404
);
405405
}
406406

@@ -418,11 +418,11 @@ mod test {
418418
// Signed integer special cases
419419
assert_eq!(
420420
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
421-
WeightError::InvalidWeight
421+
Error::InvalidWeight
422422
);
423423
assert_eq!(
424424
WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
425-
WeightError::InvalidWeight
425+
Error::InvalidWeight
426426
);
427427
}
428428

@@ -440,11 +440,11 @@ mod test {
440440
// Signed integer special cases
441441
assert_eq!(
442442
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
443-
WeightError::InvalidWeight
443+
Error::InvalidWeight
444444
);
445445
assert_eq!(
446446
WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
447-
WeightError::InvalidWeight
447+
Error::InvalidWeight
448448
);
449449
}
450450

@@ -491,15 +491,15 @@ mod test {
491491

492492
assert_eq!(
493493
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
494-
WeightError::InvalidInput
494+
Error::InvalidInput
495495
);
496496
assert_eq!(
497497
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
498-
WeightError::InsufficientNonZero
498+
Error::InsufficientNonZero
499499
);
500500
assert_eq!(
501501
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
502-
WeightError::InvalidWeight
502+
Error::InvalidWeight
503503
);
504504
}
505505

0 commit comments

Comments
 (0)