Skip to content

Commit 1d7b2b8

Browse files
committed
Improve: Support more numeric operators in Rust
1 parent bf5a7d2 commit 1d7b2b8

File tree

1 file changed

+286
-0
lines changed

1 file changed

+286
-0
lines changed

rust/lib.rs

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,15 @@ extern "C" {
192192
pub struct f16(pub u16);
193193

194194
impl f16 {
195+
/// Positive zero.
196+
pub const ZERO: Self = f16(0);
197+
198+
/// Positive one.
199+
pub const ONE: Self = f16(0x3C00);
200+
201+
/// Negative one.
202+
pub const NEG_ONE: Self = f16(0xBC00);
203+
195204
/// Converts an f32 to f16 representation.
196205
///
197206
/// # Examples
@@ -200,6 +209,7 @@ impl f16 {
200209
/// use simsimd::f16;
201210
/// let half = f16::from_f32(3.14159);
202211
/// ```
212+
#[inline(always)]
203213
pub fn from_f32(value: f32) -> Self {
204214
let mut result: u16 = 0;
205215
unsafe { simsimd_f32_to_f16(value, &mut result) };
@@ -215,9 +225,61 @@ impl f16 {
215225
/// let half = f16::from_f32(3.14159);
216226
/// let float = half.to_f32();
217227
/// ```
228+
#[inline(always)]
218229
pub fn to_f32(self) -> f32 {
219230
unsafe { simsimd_f16_to_f32(&self.0) }
220231
}
232+
233+
/// Returns true if this value is NaN.
234+
#[inline(always)]
235+
pub fn is_nan(self) -> bool {
236+
self.to_f32().is_nan()
237+
}
238+
239+
/// Returns true if this value is positive or negative infinity.
240+
#[inline(always)]
241+
pub fn is_infinite(self) -> bool {
242+
self.to_f32().is_infinite()
243+
}
244+
245+
/// Returns true if this number is neither infinite nor NaN.
246+
#[inline(always)]
247+
pub fn is_finite(self) -> bool {
248+
self.to_f32().is_finite()
249+
}
250+
251+
/// Returns the absolute value of self.
252+
#[inline(always)]
253+
pub fn abs(self) -> Self {
254+
Self::from_f32(self.to_f32().abs())
255+
}
256+
257+
/// Returns the largest integer less than or equal to a number.
258+
///
259+
/// This method is only available when the `std` feature is enabled.
260+
#[cfg(feature = "std")]
261+
#[inline(always)]
262+
pub fn floor(self) -> Self {
263+
Self::from_f32(self.to_f32().floor())
264+
}
265+
266+
/// Returns the smallest integer greater than or equal to a number.
267+
///
268+
/// This method is only available when the `std` feature is enabled.
269+
#[cfg(feature = "std")]
270+
#[inline(always)]
271+
pub fn ceil(self) -> Self {
272+
Self::from_f32(self.to_f32().ceil())
273+
}
274+
275+
/// Returns the nearest integer to a number. Round half-way cases away from 0.0.
276+
///
277+
/// This method is only available when the `std` feature is enabled.
278+
#[cfg(feature = "std")]
279+
#[inline(always)]
280+
pub fn round(self) -> Self {
281+
Self::from_f32(self.to_f32().round())
282+
}
221283
}
222284

223285
#[cfg(feature = "std")]
@@ -227,6 +289,58 @@ impl std::fmt::Display for f16 {
227289
}
228290
}
229291

292+
impl core::ops::Add for f16 {
293+
type Output = Self;
294+
295+
#[inline(always)]
296+
fn add(self, rhs: Self) -> Self::Output {
297+
Self::from_f32(self.to_f32() + rhs.to_f32())
298+
}
299+
}
300+
301+
impl core::ops::Sub for f16 {
302+
type Output = Self;
303+
304+
#[inline(always)]
305+
fn sub(self, rhs: Self) -> Self::Output {
306+
Self::from_f32(self.to_f32() - rhs.to_f32())
307+
}
308+
}
309+
310+
impl core::ops::Mul for f16 {
311+
type Output = Self;
312+
313+
#[inline(always)]
314+
fn mul(self, rhs: Self) -> Self::Output {
315+
Self::from_f32(self.to_f32() * rhs.to_f32())
316+
}
317+
}
318+
319+
impl core::ops::Div for f16 {
320+
type Output = Self;
321+
322+
#[inline(always)]
323+
fn div(self, rhs: Self) -> Self::Output {
324+
Self::from_f32(self.to_f32() / rhs.to_f32())
325+
}
326+
}
327+
328+
impl core::ops::Neg for f16 {
329+
type Output = Self;
330+
331+
#[inline(always)]
332+
fn neg(self) -> Self::Output {
333+
Self::from_f32(-self.to_f32())
334+
}
335+
}
336+
337+
impl core::cmp::PartialOrd for f16 {
338+
#[inline(always)]
339+
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
340+
self.to_f32().partial_cmp(&other.to_f32())
341+
}
342+
}
343+
230344
/// A brain floating point (bfloat16) number.
231345
///
232346
/// This type represents Google's bfloat16 format, which truncates IEEE 754
@@ -252,6 +366,15 @@ impl std::fmt::Display for f16 {
252366
pub struct bf16(pub u16);
253367

254368
impl bf16 {
369+
/// Positive zero.
370+
pub const ZERO: Self = bf16(0);
371+
372+
/// Positive one.
373+
pub const ONE: Self = bf16(0x3F80);
374+
375+
/// Negative one.
376+
pub const NEG_ONE: Self = bf16(0xBF80);
377+
255378
/// Converts an f32 to bf16 representation.
256379
///
257380
/// # Examples
@@ -260,6 +383,7 @@ impl bf16 {
260383
/// use simsimd::bf16;
261384
/// let brain_half = bf16::from_f32(3.14159);
262385
/// ```
386+
#[inline(always)]
263387
pub fn from_f32(value: f32) -> Self {
264388
let mut result: u16 = 0;
265389
unsafe { simsimd_f32_to_bf16(value, &mut result) };
@@ -275,9 +399,61 @@ impl bf16 {
275399
/// let brain_half = bf16::from_f32(3.14159);
276400
/// let float = brain_half.to_f32();
277401
/// ```
402+
#[inline(always)]
278403
pub fn to_f32(self) -> f32 {
279404
unsafe { simsimd_bf16_to_f32(&self.0) }
280405
}
406+
407+
/// Returns true if this value is NaN.
408+
#[inline(always)]
409+
pub fn is_nan(self) -> bool {
410+
self.to_f32().is_nan()
411+
}
412+
413+
/// Returns true if this value is positive or negative infinity.
414+
#[inline(always)]
415+
pub fn is_infinite(self) -> bool {
416+
self.to_f32().is_infinite()
417+
}
418+
419+
/// Returns true if this number is neither infinite nor NaN.
420+
#[inline(always)]
421+
pub fn is_finite(self) -> bool {
422+
self.to_f32().is_finite()
423+
}
424+
425+
/// Returns the absolute value of self.
426+
#[inline(always)]
427+
pub fn abs(self) -> Self {
428+
Self::from_f32(self.to_f32().abs())
429+
}
430+
431+
/// Returns the largest integer less than or equal to a number.
432+
///
433+
/// This method is only available when the `std` feature is enabled.
434+
#[cfg(feature = "std")]
435+
#[inline(always)]
436+
pub fn floor(self) -> Self {
437+
Self::from_f32(self.to_f32().floor())
438+
}
439+
440+
/// Returns the smallest integer greater than or equal to a number.
441+
///
442+
/// This method is only available when the `std` feature is enabled.
443+
#[cfg(feature = "std")]
444+
#[inline(always)]
445+
pub fn ceil(self) -> Self {
446+
Self::from_f32(self.to_f32().ceil())
447+
}
448+
449+
/// Returns the nearest integer to a number. Round half-way cases away from 0.0.
450+
///
451+
/// This method is only available when the `std` feature is enabled.
452+
#[cfg(feature = "std")]
453+
#[inline(always)]
454+
pub fn round(self) -> Self {
455+
Self::from_f32(self.to_f32().round())
456+
}
281457
}
282458

283459
#[cfg(feature = "std")]
@@ -287,6 +463,58 @@ impl std::fmt::Display for bf16 {
287463
}
288464
}
289465

466+
impl core::ops::Add for bf16 {
467+
type Output = Self;
468+
469+
#[inline(always)]
470+
fn add(self, rhs: Self) -> Self::Output {
471+
Self::from_f32(self.to_f32() + rhs.to_f32())
472+
}
473+
}
474+
475+
impl core::ops::Sub for bf16 {
476+
type Output = Self;
477+
478+
#[inline(always)]
479+
fn sub(self, rhs: Self) -> Self::Output {
480+
Self::from_f32(self.to_f32() - rhs.to_f32())
481+
}
482+
}
483+
484+
impl core::ops::Mul for bf16 {
485+
type Output = Self;
486+
487+
#[inline(always)]
488+
fn mul(self, rhs: Self) -> Self::Output {
489+
Self::from_f32(self.to_f32() * rhs.to_f32())
490+
}
491+
}
492+
493+
impl core::ops::Div for bf16 {
494+
type Output = Self;
495+
496+
#[inline(always)]
497+
fn div(self, rhs: Self) -> Self::Output {
498+
Self::from_f32(self.to_f32() / rhs.to_f32())
499+
}
500+
}
501+
502+
impl core::ops::Neg for bf16 {
503+
type Output = Self;
504+
505+
#[inline(always)]
506+
fn neg(self) -> Self::Output {
507+
Self::from_f32(-self.to_f32())
508+
}
509+
}
510+
511+
impl core::cmp::PartialOrd for bf16 {
512+
#[inline(always)]
513+
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
514+
self.to_f32().partial_cmp(&other.to_f32())
515+
}
516+
}
517+
290518
/// The `capabilities` module provides functions for detecting the hardware features
291519
/// available on the current system.
292520
pub mod capabilities {
@@ -1374,4 +1602,62 @@ mod tests {
13741602
}
13751603
}
13761604
}
1605+
1606+
#[test]
1607+
fn test_f16_arithmetic() {
1608+
let a = f16::from_f32(3.5);
1609+
let b = f16::from_f32(2.0);
1610+
1611+
// Test basic arithmetic
1612+
assert!((a + b).to_f32() - 5.5 < 0.01);
1613+
assert!((a - b).to_f32() - 1.5 < 0.01);
1614+
assert!((a * b).to_f32() - 7.0 < 0.01);
1615+
assert!((a / b).to_f32() - 1.75 < 0.01);
1616+
assert!((-a).to_f32() + 3.5 < 0.01);
1617+
1618+
// Test constants
1619+
assert!(f16::ZERO.to_f32() == 0.0);
1620+
assert!((f16::ONE.to_f32() - 1.0).abs() < 0.01);
1621+
assert!((f16::NEG_ONE.to_f32() + 1.0).abs() < 0.01);
1622+
1623+
// Test comparisons
1624+
assert!(a > b);
1625+
assert!(!(a < b));
1626+
assert!(a == a);
1627+
1628+
// Test utility methods
1629+
assert!((-a).abs().to_f32() - 3.5 < 0.01);
1630+
assert!(a.is_finite());
1631+
assert!(!a.is_nan());
1632+
assert!(!a.is_infinite());
1633+
}
1634+
1635+
#[test]
1636+
fn test_bf16_arithmetic() {
1637+
let a = bf16::from_f32(3.5);
1638+
let b = bf16::from_f32(2.0);
1639+
1640+
// Test basic arithmetic
1641+
assert!((a + b).to_f32() - 5.5 < 0.1);
1642+
assert!((a - b).to_f32() - 1.5 < 0.1);
1643+
assert!((a * b).to_f32() - 7.0 < 0.1);
1644+
assert!((a / b).to_f32() - 1.75 < 0.1);
1645+
assert!((-a).to_f32() + 3.5 < 0.1);
1646+
1647+
// Test constants
1648+
assert!(bf16::ZERO.to_f32() == 0.0);
1649+
assert!((bf16::ONE.to_f32() - 1.0).abs() < 0.01);
1650+
assert!((bf16::NEG_ONE.to_f32() + 1.0).abs() < 0.01);
1651+
1652+
// Test comparisons
1653+
assert!(a > b);
1654+
assert!(!(a < b));
1655+
assert!(a == a);
1656+
1657+
// Test utility methods
1658+
assert!((-a).abs().to_f32() - 3.5 < 0.1);
1659+
assert!(a.is_finite());
1660+
assert!(!a.is_nan());
1661+
assert!(!a.is_infinite());
1662+
}
13771663
}

0 commit comments

Comments
 (0)