1
1
use crate :: types:: Header ;
2
2
use arrayvec:: ArrayVec ;
3
3
use bytes:: { Buf , Bytes , BytesMut } ;
4
- use core:: any:: Any ;
4
+ use core:: { any:: Any , num :: NonZeroUsize } ;
5
5
6
6
pub trait Decodable : Sized {
7
7
fn decode ( buf : & mut & [ u8 ] ) -> Result < Self , DecodeError > ;
@@ -80,7 +80,7 @@ mod alloc_impl {
80
80
pub enum DecodeError {
81
81
Overflow ,
82
82
LeadingZero ,
83
- InputTooShort ,
83
+ InputTooShort { needed : Option < NonZeroUsize > } ,
84
84
NonCanonicalSingleByte ,
85
85
NonCanonicalSize ,
86
86
UnexpectedLength ,
@@ -98,7 +98,14 @@ impl core::fmt::Display for DecodeError {
98
98
match self {
99
99
DecodeError :: Overflow => write ! ( f, "overflow" ) ,
100
100
DecodeError :: LeadingZero => write ! ( f, "leading zero" ) ,
101
- DecodeError :: InputTooShort => write ! ( f, "input too short" ) ,
101
+ DecodeError :: InputTooShort { needed } => {
102
+ write ! ( f, "input too short" ) ?;
103
+ if let Some ( needed) = needed {
104
+ write ! ( f, ": need {needed} more bytes" ) ?;
105
+ }
106
+
107
+ Ok ( ( ) )
108
+ }
102
109
DecodeError :: NonCanonicalSingleByte => write ! ( f, "non-canonical single byte" ) ,
103
110
DecodeError :: NonCanonicalSize => write ! ( f, "non-canonical size" ) ,
104
111
DecodeError :: UnexpectedLength => write ! ( f, "unexpected length" ) ,
@@ -115,7 +122,7 @@ impl core::fmt::Display for DecodeError {
115
122
impl Header {
116
123
pub fn decode ( buf : & mut & [ u8 ] ) -> Result < Self , DecodeError > {
117
124
if !buf. has_remaining ( ) {
118
- return Err ( DecodeError :: InputTooShort ) ;
125
+ return Err ( DecodeError :: InputTooShort { needed : None } ) ;
119
126
}
120
127
121
128
let b = buf[ 0 ] ;
@@ -134,7 +141,7 @@ impl Header {
134
141
135
142
if h. payload_length == 1 {
136
143
if !buf. has_remaining ( ) {
137
- return Err ( DecodeError :: InputTooShort ) ;
144
+ return Err ( DecodeError :: InputTooShort { needed : None } ) ;
138
145
}
139
146
if buf[ 0 ] < 0x80 {
140
147
return Err ( DecodeError :: NonCanonicalSingleByte ) ;
@@ -145,8 +152,13 @@ impl Header {
145
152
} else if b < 0xC0 {
146
153
buf. advance ( 1 ) ;
147
154
let len_of_len = b as usize - 0xB7 ;
148
- if buf. len ( ) < len_of_len {
149
- return Err ( DecodeError :: InputTooShort ) ;
155
+ if let Some ( needed) = len_of_len
156
+ . checked_sub ( buf. len ( ) )
157
+ . and_then ( NonZeroUsize :: new)
158
+ {
159
+ return Err ( DecodeError :: InputTooShort {
160
+ needed : Some ( needed) ,
161
+ } ) ;
150
162
}
151
163
let payload_length = usize:: try_from ( u64:: from_be_bytes (
152
164
static_left_pad ( & buf[ ..len_of_len] ) . ok_or ( DecodeError :: LeadingZero ) ?,
@@ -171,8 +183,13 @@ impl Header {
171
183
buf. advance ( 1 ) ;
172
184
let list = true ;
173
185
let len_of_len = b as usize - 0xF7 ;
174
- if buf. len ( ) < len_of_len {
175
- return Err ( DecodeError :: InputTooShort ) ;
186
+ if let Some ( needed) = len_of_len
187
+ . checked_sub ( buf. len ( ) )
188
+ . and_then ( NonZeroUsize :: new)
189
+ {
190
+ return Err ( DecodeError :: InputTooShort {
191
+ needed : Some ( needed) ,
192
+ } ) ;
176
193
}
177
194
let payload_length = usize:: try_from ( u64:: from_be_bytes (
178
195
static_left_pad ( & buf[ ..len_of_len] ) . ok_or ( DecodeError :: LeadingZero ) ?,
@@ -190,8 +207,14 @@ impl Header {
190
207
}
191
208
} ;
192
209
193
- if buf. remaining ( ) < h. payload_length {
194
- return Err ( DecodeError :: InputTooShort ) ;
210
+ if let Some ( needed) = h
211
+ . payload_length
212
+ . checked_sub ( buf. remaining ( ) )
213
+ . and_then ( NonZeroUsize :: new)
214
+ {
215
+ return Err ( DecodeError :: InputTooShort {
216
+ needed : Some ( needed) ,
217
+ } ) ;
195
218
}
196
219
197
220
Ok ( h)
@@ -228,8 +251,14 @@ macro_rules! decode_integer {
228
251
if h. payload_length > ( <$t>:: BITS as usize / 8 ) {
229
252
return Err ( DecodeError :: Overflow ) ;
230
253
}
231
- if buf. remaining( ) < h. payload_length {
232
- return Err ( DecodeError :: InputTooShort ) ;
254
+ if let Some ( needed) = h
255
+ . payload_length
256
+ . checked_sub( buf. remaining( ) )
257
+ . and_then( NonZeroUsize :: new)
258
+ {
259
+ return Err ( DecodeError :: InputTooShort {
260
+ needed: Some ( needed) ,
261
+ } ) ;
233
262
}
234
263
let v = <$t>:: from_be_bytes(
235
264
static_left_pad( & buf[ ..h. payload_length] ) . ok_or( DecodeError :: LeadingZero ) ?,
@@ -296,8 +325,14 @@ mod ethereum_types_support {
296
325
if h. payload_length > $n_bytes {
297
326
return Err ( DecodeError :: Overflow ) ;
298
327
}
299
- if buf. remaining( ) < h. payload_length {
300
- return Err ( DecodeError :: InputTooShort ) ;
328
+ if let Some ( needed) = h
329
+ . payload_length
330
+ . checked_sub( buf. remaining( ) )
331
+ . and_then( NonZeroUsize :: new)
332
+ {
333
+ return Err ( DecodeError :: InputTooShort {
334
+ needed: Some ( needed) ,
335
+ } ) ;
301
336
}
302
337
let n = <$t>:: from_big_endian(
303
338
& static_left_pad:: <$n_bytes>( & buf[ ..h. payload_length] )
@@ -494,7 +529,9 @@ mod tests {
494
529
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
495
530
) ,
496
531
(
497
- Err ( DecodeError :: InputTooShort ) ,
532
+ Err ( DecodeError :: InputTooShort {
533
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
534
+ } ) ,
498
535
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
499
536
) ,
500
537
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
@@ -521,7 +558,9 @@ mod tests {
521
558
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
522
559
) ,
523
560
(
524
- Err ( DecodeError :: InputTooShort ) ,
561
+ Err ( DecodeError :: InputTooShort {
562
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
563
+ } ) ,
525
564
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
526
565
) ,
527
566
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
@@ -549,7 +588,9 @@ mod tests {
549
588
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
550
589
) ,
551
590
(
552
- Err ( DecodeError :: InputTooShort ) ,
591
+ Err ( DecodeError :: InputTooShort {
592
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
593
+ } ) ,
553
594
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
554
595
) ,
555
596
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
@@ -577,7 +618,9 @@ mod tests {
577
618
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
578
619
) ,
579
620
(
580
- Err ( DecodeError :: InputTooShort ) ,
621
+ Err ( DecodeError :: InputTooShort {
622
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
623
+ } ) ,
581
624
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
582
625
) ,
583
626
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
@@ -605,7 +648,9 @@ mod tests {
605
648
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
606
649
) ,
607
650
(
608
- Err ( DecodeError :: InputTooShort ) ,
651
+ Err ( DecodeError :: InputTooShort {
652
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
653
+ } ) ,
609
654
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
610
655
) ,
611
656
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
@@ -633,7 +678,9 @@ mod tests {
633
678
& hex!( "8AFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
634
679
) ,
635
680
(
636
- Err ( DecodeError :: InputTooShort ) ,
681
+ Err ( DecodeError :: InputTooShort {
682
+ needed: Some ( NonZeroUsize :: new( 1 ) . unwrap( ) ) ,
683
+ } ) ,
637
684
& hex!( "8BFFFFFFFFFFFFFFFFFF7C" ) [ ..] ,
638
685
) ,
639
686
( Err ( DecodeError :: UnexpectedList ) , & hex!( "C0" ) [ ..] ) ,
0 commit comments