4
4
use core:: ops:: Deref ;
5
5
6
6
use serde:: { Deserialize , Serialize } ;
7
+ use serde_byte_array:: ByteArray ;
7
8
use sha2:: { Digest as _, Sha256 } ;
8
9
use subtle:: ConstantTimeEq as _;
9
10
use trussed:: {
@@ -19,8 +20,8 @@ const SIZE: usize = 256;
19
20
const SALT_LEN : usize = 16 ;
20
21
const HASH_LEN : usize = 32 ;
21
22
22
- type Salt = [ u8 ; SALT_LEN ] ;
23
- type Hash = [ u8 ; HASH_LEN ] ;
23
+ type Salt = ByteArray < SALT_LEN > ;
24
+ type Hash = ByteArray < HASH_LEN > ;
24
25
25
26
#[ derive( Debug , Deserialize , Serialize ) ]
26
27
pub ( crate ) struct PinData {
@@ -37,7 +38,7 @@ impl PinData {
37
38
R : CryptoRng + RngCore ,
38
39
{
39
40
let mut salt = Salt :: default ( ) ;
40
- rng. fill_bytes ( & mut salt) ;
41
+ rng. fill_bytes ( salt. as_mut ( ) ) ;
41
42
let hash = hash ( id, pin, & salt) ;
42
43
Self {
43
44
id,
@@ -111,7 +112,10 @@ impl<'a> PinDataMut<'a> {
111
112
if self . is_blocked ( ) {
112
113
return false ;
113
114
}
114
- let success = hash ( self . id , pin, & self . salt ) . ct_eq ( & self . hash ) . into ( ) ;
115
+ let success = hash ( self . id , pin, & self . salt )
116
+ . as_ref ( )
117
+ . ct_eq ( self . hash . as_ref ( ) )
118
+ . into ( ) ;
115
119
if let Some ( retries) = & mut self . data . retries {
116
120
if success {
117
121
if retries. reset ( ) {
@@ -169,8 +173,8 @@ fn hash(id: PinId, pin: &Pin, salt: &Salt) -> Hash {
169
173
digest. update ( [ u8:: from ( id) ] ) ;
170
174
digest. update ( [ pin_len ( pin) ] ) ;
171
175
digest. update ( pin) ;
172
- digest. update ( salt) ;
173
- digest. finalize ( ) . into ( )
176
+ digest. update ( salt. as_ref ( ) ) ;
177
+ Hash :: new ( digest. finalize ( ) . into ( ) )
174
178
}
175
179
176
180
fn pin_len ( pin : & Pin ) -> u8 {
@@ -191,10 +195,19 @@ mod tests {
191
195
max : u8:: MAX ,
192
196
left : u8:: MAX ,
193
197
} ) ,
194
- salt : [ u8:: MAX ; SALT_LEN ] ,
195
- hash : [ u8:: MAX ; HASH_LEN ] ,
198
+ salt : [ u8:: MAX ; SALT_LEN ] . into ( ) ,
199
+ hash : [ u8:: MAX ; HASH_LEN ] . into ( ) ,
196
200
} ;
197
201
let serialized = trussed:: cbor_serialize_bytes :: < _ , 1024 > ( & data) . unwrap ( ) ;
198
202
assert ! ( serialized. len( ) <= SIZE ) ;
199
203
}
204
+
205
+ #[ test]
206
+ #[ allow( clippy:: unwrap_used) ]
207
+ fn test_salt_size ( ) {
208
+ // We allow one byte overhead for byte array serialization
209
+ let salt = Salt :: from ( [ u8:: MAX ; SALT_LEN ] ) ;
210
+ let serialized = trussed:: cbor_serialize_bytes :: < _ , 1024 > ( & salt) . unwrap ( ) ;
211
+ assert ! ( serialized. len( ) <= SALT_LEN + 1 , "{}" , serialized. len( ) ) ;
212
+ }
200
213
}
0 commit comments