@@ -4,6 +4,8 @@ use std::panic::RefUnwindSafe;
4
4
5
5
use bytemuck:: { Pod , Zeroable } ;
6
6
7
+ pub use half:: f16;
8
+
7
9
use super :: PrimitiveType ;
8
10
9
11
/// Sealed trait implemented by all physical types that can be allocated,
@@ -331,186 +333,27 @@ impl Neg for months_days_ns {
331
333
}
332
334
}
333
335
334
- /// Type representation of the Float16 physical type
335
- #[ derive( Copy , Clone , Default , Zeroable , Pod ) ]
336
- #[ allow( non_camel_case_types) ]
337
- #[ repr( C ) ]
338
- pub struct f16 ( pub u16 ) ;
339
-
340
- impl PartialEq for f16 {
341
- #[ inline]
342
- fn eq ( & self , other : & f16 ) -> bool {
343
- if self . is_nan ( ) || other. is_nan ( ) {
344
- false
345
- } else {
346
- ( self . 0 == other. 0 ) || ( ( self . 0 | other. 0 ) & 0x7FFFu16 == 0 )
347
- }
348
- }
349
- }
350
-
351
- // see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs
352
- impl f16 {
353
- /// The difference between 1.0 and the next largest representable number.
354
- pub const EPSILON : f16 = f16 ( 0x1400u16 ) ;
355
-
356
- #[ inline]
357
- #[ must_use]
358
- pub ( crate ) const fn is_nan ( self ) -> bool {
359
- self . 0 & 0x7FFFu16 > 0x7C00u16
360
- }
361
-
362
- /// Casts from u16.
363
- #[ inline]
364
- pub const fn from_bits ( bits : u16 ) -> f16 {
365
- f16 ( bits)
366
- }
367
-
368
- /// Casts to u16.
369
- #[ inline]
370
- pub const fn to_bits ( self ) -> u16 {
371
- self . 0
372
- }
373
-
374
- /// Casts this `f16` to `f32`
375
- pub fn to_f32 ( self ) -> f32 {
376
- let i = self . 0 ;
377
- // Check for signed zero
378
- if i & 0x7FFFu16 == 0 {
379
- return f32:: from_bits ( ( i as u32 ) << 16 ) ;
380
- }
381
-
382
- let half_sign = ( i & 0x8000u16 ) as u32 ;
383
- let half_exp = ( i & 0x7C00u16 ) as u32 ;
384
- let half_man = ( i & 0x03FFu16 ) as u32 ;
385
-
386
- // Check for an infinity or NaN when all exponent bits set
387
- if half_exp == 0x7C00u32 {
388
- // Check for signed infinity if mantissa is zero
389
- if half_man == 0 {
390
- let number = ( half_sign << 16 ) | 0x7F80_0000u32 ;
391
- return f32:: from_bits ( number) ;
392
- } else {
393
- // NaN, keep current mantissa but also set most significant mantissa bit
394
- let number = ( half_sign << 16 ) | 0x7FC0_0000u32 | ( half_man << 13 ) ;
395
- return f32:: from_bits ( number) ;
396
- }
397
- }
398
-
399
- // Calculate single-precision components with adjusted exponent
400
- let sign = half_sign << 16 ;
401
- // Unbias exponent
402
- let unbiased_exp = ( ( half_exp as i32 ) >> 10 ) - 15 ;
403
-
404
- // Check for subnormals, which will be normalized by adjusting exponent
405
- if half_exp == 0 {
406
- // Calculate how much to adjust the exponent by
407
- let e = ( half_man as u16 ) . leading_zeros ( ) - 6 ;
408
-
409
- // Rebias and adjust exponent
410
- let exp = ( 127 - 15 - e) << 23 ;
411
- let man = ( half_man << ( 14 + e) ) & 0x7F_FF_FFu32 ;
412
- return f32:: from_bits ( sign | exp | man) ;
413
- }
414
-
415
- // Rebias exponent for a normalized normal
416
- let exp = ( ( unbiased_exp + 127 ) as u32 ) << 23 ;
417
- let man = ( half_man & 0x03FFu32 ) << 13 ;
418
- f32:: from_bits ( sign | exp | man)
419
- }
420
-
421
- /// Casts an `f32` into `f16`
422
- pub fn from_f32 ( value : f32 ) -> Self {
423
- let x: u32 = value. to_bits ( ) ;
424
-
425
- // Extract IEEE754 components
426
- let sign = x & 0x8000_0000u32 ;
427
- let exp = x & 0x7F80_0000u32 ;
428
- let man = x & 0x007F_FFFFu32 ;
429
-
430
- // Check for all exponent bits being set, which is Infinity or NaN
431
- if exp == 0x7F80_0000u32 {
432
- // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
433
- let nan_bit = if man == 0 { 0 } else { 0x0200u32 } ;
434
- return f16 ( ( ( sign >> 16 ) | 0x7C00u32 | nan_bit | ( man >> 13 ) ) as u16 ) ;
435
- }
436
-
437
- // The number is normalized, start assembling half precision version
438
- let half_sign = sign >> 16 ;
439
- // Unbias the exponent, then bias for half precision
440
- let unbiased_exp = ( ( exp >> 23 ) as i32 ) - 127 ;
441
- let half_exp = unbiased_exp + 15 ;
442
-
443
- // Check for exponent overflow, return signed infinity
444
- if half_exp >= 0x1F {
445
- return f16 ( ( half_sign | 0x7C00u32 ) as u16 ) ;
446
- }
447
-
448
- // Check for underflow
449
- if half_exp <= 0 {
450
- // Check mantissa for what we can do
451
- if 14 - half_exp > 24 {
452
- // No rounding possibility, so this is a full underflow, return signed zero
453
- return f16 ( half_sign as u16 ) ;
454
- }
455
- // Don't forget about hidden leading mantissa bit when assembling mantissa
456
- let man = man | 0x0080_0000u32 ;
457
- let mut half_man = man >> ( 14 - half_exp) ;
458
- // Check for rounding (see comment above functions)
459
- let round_bit = 1 << ( 13 - half_exp) ;
460
- if ( man & round_bit) != 0 && ( man & ( 3 * round_bit - 1 ) ) != 0 {
461
- half_man += 1 ;
462
- }
463
- // No exponent for subnormals
464
- return f16 ( ( half_sign | half_man) as u16 ) ;
465
- }
466
-
467
- // Rebias the exponent
468
- let half_exp = ( half_exp as u32 ) << 10 ;
469
- let half_man = man >> 13 ;
470
- // Check for rounding (see comment above functions)
471
- let round_bit = 0x0000_1000u32 ;
472
- if ( man & round_bit) != 0 && ( man & ( 3 * round_bit - 1 ) ) != 0 {
473
- // Round it
474
- f16 ( ( ( half_sign | half_exp | half_man) + 1 ) as u16 )
475
- } else {
476
- f16 ( ( half_sign | half_exp | half_man) as u16 )
477
- }
478
- }
479
- }
480
-
481
- impl std:: fmt:: Debug for f16 {
482
- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
483
- write ! ( f, "{:?}" , self . to_f32( ) )
484
- }
485
- }
486
-
487
- impl std:: fmt:: Display for f16 {
488
- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
489
- write ! ( f, "{}" , self . to_f32( ) )
490
- }
491
- }
492
-
493
336
impl NativeType for f16 {
494
337
const PRIMITIVE : PrimitiveType = PrimitiveType :: Float16 ;
495
338
type Bytes = [ u8 ; 2 ] ;
496
339
#[ inline]
497
340
fn to_le_bytes ( & self ) -> Self :: Bytes {
498
- self . 0 . to_le_bytes ( )
341
+ f16 :: to_le_bytes ( * self )
499
342
}
500
343
501
344
#[ inline]
502
345
fn to_be_bytes ( & self ) -> Self :: Bytes {
503
- self . 0 . to_be_bytes ( )
346
+ f16 :: to_be_bytes ( * self )
504
347
}
505
348
506
349
#[ inline]
507
350
fn from_be_bytes ( bytes : Self :: Bytes ) -> Self {
508
- Self ( u16 :: from_be_bytes ( bytes) )
351
+ f16 :: from_be_bytes ( bytes)
509
352
}
510
353
511
354
#[ inline]
512
355
fn from_le_bytes ( bytes : Self :: Bytes ) -> Self {
513
- Self ( u16 :: from_le_bytes ( bytes) )
356
+ f16 :: from_le_bytes ( bytes)
514
357
}
515
358
}
516
359
@@ -627,11 +470,14 @@ mod test {
627
470
// diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1
628
471
assert ! ( diff <= 4.0 * f16:: EPSILON . to_f32( ) ) ;
629
472
630
- assert_eq ! ( f16( 0x0000_0001 ) . to_f32( ) , 2.0f32 . powi( -24 ) ) ;
631
- assert_eq ! ( f16( 0x0000_0005 ) . to_f32( ) , 5.0 * 2.0f32 . powi( -24 ) ) ;
473
+ assert_eq ! ( f16:: from_bits ( 0x0000_0001 ) . to_f32( ) , 2.0f32 . powi( -24 ) ) ;
474
+ assert_eq ! ( f16:: from_bits ( 0x0000_0005 ) . to_f32( ) , 5.0 * 2.0f32 . powi( -24 ) ) ;
632
475
633
- assert_eq ! ( f16( 0x0000_0001 ) , f16:: from_f32( 2.0f32 . powi( -24 ) ) ) ;
634
- assert_eq ! ( f16( 0x0000_0005 ) , f16:: from_f32( 5.0 * 2.0f32 . powi( -24 ) ) ) ;
476
+ assert_eq ! ( f16:: from_bits( 0x0000_0001 ) , f16:: from_f32( 2.0f32 . powi( -24 ) ) ) ;
477
+ assert_eq ! (
478
+ f16:: from_bits( 0x0000_0005 ) ,
479
+ f16:: from_f32( 5.0 * 2.0f32 . powi( -24 ) )
480
+ ) ;
635
481
636
482
assert_eq ! ( format!( "{}" , f16:: from_f32( 7.0 ) ) , "7" . to_string( ) ) ;
637
483
assert_eq ! ( format!( "{:?}" , f16:: from_f32( 7.0 ) ) , "7.0" . to_string( ) ) ;
0 commit comments