Skip to content

Commit e9c9589

Browse files
committed
Vec specialization
1 parent ea592a8 commit e9c9589

File tree

2 files changed

+154
-63
lines changed

2 files changed

+154
-63
lines changed

src/decode.rs

+95-52
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::types::Header;
22
use arrayvec::ArrayVec;
33
use bytes::{Buf, Bytes, BytesMut};
4+
use core::any::Any;
45

56
pub trait Decodable: Sized {
67
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError>;
@@ -41,6 +42,38 @@ mod alloc_impl {
4142
Self::from_utf8(to).map_err(|_| DecodeError::Custom("invalid string"))
4243
}
4344
}
45+
46+
impl<T> Decodable for ::alloc::vec::Vec<T>
47+
where
48+
T: Decodable + 'static,
49+
{
50+
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
51+
let h = Header::decode(buf)?;
52+
53+
let mut to = ::alloc::vec::Vec::new();
54+
if let Some(to) = <dyn Any>::downcast_mut::<::alloc::vec::Vec<u8>>(&mut to) {
55+
if h.list {
56+
return Err(DecodeError::UnexpectedList);
57+
}
58+
to.extend_from_slice(&buf[..h.payload_length]);
59+
buf.advance(h.payload_length);
60+
} else {
61+
if !h.list {
62+
return Err(DecodeError::UnexpectedString);
63+
}
64+
65+
let payload_view = &mut &buf[..h.payload_length];
66+
67+
while !payload_view.is_empty() {
68+
to.push(T::decode(payload_view)?);
69+
}
70+
71+
buf.advance(h.payload_length);
72+
}
73+
74+
Ok(to)
75+
}
76+
}
4477
}
4578

4679
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -283,21 +316,17 @@ mod ethereum_types_support {
283316
fixed_uint_impl!(U512, 64);
284317
}
285318

286-
impl<const N: usize> Decodable for [u8; N] {
287-
fn decode(from: &mut &[u8]) -> Result<Self, DecodeError> {
288-
let h = Header::decode(from)?;
289-
if h.list {
290-
return Err(DecodeError::UnexpectedList);
291-
}
292-
if h.payload_length != N {
293-
return Err(DecodeError::UnexpectedLength);
294-
}
295-
296-
let mut to = [0_u8; N];
297-
to.copy_from_slice(&from[..N]);
298-
from.advance(N);
299-
300-
Ok(to)
319+
impl<T, const LEN: usize> Decodable for [T; LEN]
320+
where
321+
T: Decodable + 'static,
322+
{
323+
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
324+
ArrayVec::<T, LEN>::decode(buf)?
325+
.into_inner()
326+
.map_err(|arr| DecodeError::ListLengthMismatch {
327+
expected: LEN,
328+
got: arr.len(),
329+
})
301330
}
302331
}
303332

@@ -345,51 +374,44 @@ impl<'a> Rlp<'a> {
345374
}
346375
}
347376

348-
#[cfg(feature = "alloc")]
349-
impl<E> Decodable for alloc::vec::Vec<E>
377+
impl<T, const LEN: usize> Decodable for ArrayVec<T, LEN>
350378
where
351-
E: Decodable,
379+
T: Decodable + 'static,
352380
{
353381
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
354-
let h = Header::decode(buf)?;
355-
if !h.list {
356-
return Err(DecodeError::UnexpectedString);
357-
}
358-
359-
let payload_view = &mut &buf[..h.payload_length];
360-
361-
let mut to = alloc::vec::Vec::new();
362-
while !payload_view.is_empty() {
363-
to.push(E::decode(payload_view)?);
364-
}
382+
let mut arr: ArrayVec<T, LEN> = ArrayVec::new();
383+
if let Some(s) = <dyn Any>::downcast_mut::<ArrayVec<u8, LEN>>(&mut arr) {
384+
let h = Header::decode(buf)?;
385+
if h.list {
386+
return Err(DecodeError::UnexpectedList);
387+
}
388+
if h.payload_length != LEN {
389+
return Err(DecodeError::UnexpectedLength);
390+
}
365391

366-
buf.advance(h.payload_length);
392+
s.try_extend_from_slice(&buf[..LEN]).unwrap();
393+
buf.advance(LEN);
394+
} else {
395+
let h = Header::decode(buf)?;
396+
if !h.list {
397+
return Err(DecodeError::UnexpectedString);
398+
}
367399

368-
Ok(to)
369-
}
370-
}
400+
let payload_view = &mut &buf[..h.payload_length];
371401

372-
impl<E, const LEN: usize> Decodable for ArrayVec<E, LEN>
373-
where
374-
E: Decodable,
375-
{
376-
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
377-
let h = Header::decode(buf)?;
378-
if !h.list {
379-
return Err(DecodeError::UnexpectedString);
380-
}
381-
382-
let payload_view = &mut &buf[..h.payload_length];
402+
while !payload_view.is_empty() {
403+
if arr.try_push(T::decode(payload_view)?).is_err() {
404+
return Err(DecodeError::ListLengthMismatch {
405+
expected: LEN,
406+
got: LEN + 1,
407+
});
408+
}
409+
}
383410

384-
let mut to = ArrayVec::new();
385-
while !payload_view.is_empty() {
386-
to.try_push(E::decode(payload_view)?)
387-
.map_err(|_| DecodeError::Custom("arrayvec full"))?;
411+
buf.advance(h.payload_length);
388412
}
389413

390-
buf.advance(h.payload_length);
391-
392-
Ok(to)
414+
Ok(arr)
393415
}
394416
}
395417

@@ -419,7 +441,7 @@ mod tests {
419441

420442
fn check_decode_list<T, IT>(fixtures: IT)
421443
where
422-
T: Decodable + PartialEq + Debug,
444+
T: Decodable + PartialEq + Debug + 'static,
423445
IT: IntoIterator<Item = (Result<alloc::vec::Vec<T>, DecodeError>, &'static [u8])>,
424446
{
425447
for (expected, mut input) in fixtures {
@@ -640,4 +662,25 @@ mod tests {
640662
),
641663
])
642664
}
665+
666+
#[test]
667+
fn vec_specialization() {
668+
const SPECIALIZED: [u8; 2] = [0x42_u8, 0x43_u8];
669+
const GENERAL: [u64; 2] = [0xFFCCB5_u64, 0xFFC0B5_u64];
670+
671+
const SPECIALIZED_EXP: &[u8] = &hex!("824243");
672+
const GENERAL_EXP: &[u8] = &hex!("C883FFCCB583FFC0B5");
673+
674+
check_decode([(Ok(SPECIALIZED), SPECIALIZED_EXP)]);
675+
check_decode([(Ok(GENERAL), GENERAL_EXP)]);
676+
677+
check_decode([(Ok(ArrayVec::from(SPECIALIZED)), SPECIALIZED_EXP)]);
678+
check_decode([(Ok(ArrayVec::from(GENERAL)), GENERAL_EXP)]);
679+
680+
#[cfg(feature = "alloc")]
681+
{
682+
check_decode([(Ok(SPECIALIZED.to_vec()), SPECIALIZED_EXP)]);
683+
check_decode([(Ok(GENERAL.to_vec()), GENERAL_EXP)]);
684+
}
685+
}
643686
}

src/encode.rs

+59-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::types::*;
22
use arrayvec::ArrayVec;
33
use auto_impl::auto_impl;
44
use bytes::{BufMut, Bytes, BytesMut};
5-
use core::{borrow::Borrow, mem::size_of};
5+
use core::{any::Any, borrow::Borrow, mem::size_of};
66

77
pub fn zeroless_view(v: &impl AsRef<[u8]>) -> &[u8] {
88
let v = v.as_ref();
@@ -95,13 +95,24 @@ impl<'a> Encodable for &'a [u8] {
9595
}
9696
}
9797

98-
impl<const LEN: usize> Encodable for [u8; LEN] {
98+
impl<T, const LEN: usize> Encodable for [T; LEN]
99+
where
100+
T: Encodable + 'static,
101+
{
99102
fn length(&self) -> usize {
100-
(self as &[u8]).length()
103+
if let Some(s) = <dyn Any>::downcast_ref::<[u8; LEN]>(self) {
104+
(s as &[u8]).length()
105+
} else {
106+
list_length(self)
107+
}
101108
}
102109

103110
fn encode(&self, out: &mut dyn BufMut) {
104-
(self as &[u8]).encode(out)
111+
if let Some(s) = <dyn Any>::downcast_ref::<[u8; LEN]>(self) {
112+
(s as &[u8]).encode(out)
113+
} else {
114+
encode_list(self, out)
115+
}
105116
}
106117
}
107118

@@ -264,14 +275,22 @@ mod alloc_support {
264275

265276
impl<T> Encodable for ::alloc::vec::Vec<T>
266277
where
267-
T: Encodable,
278+
T: Encodable + 'static,
268279
{
269280
fn length(&self) -> usize {
270-
list_length(self)
281+
if let Some(s) = <dyn Any>::downcast_ref::<::alloc::vec::Vec<u8>>(self) {
282+
(s as &[u8]).length()
283+
} else {
284+
list_length(self)
285+
}
271286
}
272287

273288
fn encode(&self, out: &mut dyn BufMut) {
274-
encode_list(self, out)
289+
if let Some(s) = <dyn Any>::downcast_ref::<::alloc::vec::Vec<u8>>(self) {
290+
(s as &[u8]).encode(out)
291+
} else {
292+
encode_list(self, out)
293+
}
275294
}
276295
}
277296

@@ -287,14 +306,22 @@ mod alloc_support {
287306

288307
impl<T, const LEN: usize> Encodable for ArrayVec<T, LEN>
289308
where
290-
T: Encodable,
309+
T: Encodable + 'static,
291310
{
292311
fn length(&self) -> usize {
293-
list_length(self)
312+
if let Some(s) = <dyn Any>::downcast_ref::<ArrayVec<u8, LEN>>(self) {
313+
(s as &[u8]).length()
314+
} else {
315+
list_length(self)
316+
}
294317
}
295318

296319
fn encode(&self, out: &mut dyn BufMut) {
297-
encode_list(self, out)
320+
if let Some(s) = <dyn Any>::downcast_ref::<ArrayVec<u8, LEN>>(self) {
321+
(s as &[u8]).encode(out)
322+
} else {
323+
encode_list(self, out)
324+
}
298325
}
299326
}
300327
slice_impl!(Bytes);
@@ -364,7 +391,7 @@ mod tests {
364391
out
365392
}
366393

367-
fn encoded_list<T: Encodable + Clone>(t: &[T]) -> BytesMut {
394+
fn encoded_list<T: Encodable + Clone + 'static>(t: &[T]) -> BytesMut {
368395
let mut out1 = BytesMut::new();
369396
encode_list(t, &mut out1);
370397

@@ -530,4 +557,25 @@ mod tests {
530557
&hex!("c883ffccb583ffc0b5")[..]
531558
);
532559
}
560+
561+
#[test]
562+
fn vec_specialization() {
563+
const SPECIALIZED: [u8; 2] = [0x42_u8, 0x43_u8];
564+
const GENERAL: [u64; 2] = [0xFFCCB5_u64, 0xFFC0B5_u64];
565+
566+
const SPECIALIZED_EXP: &[u8] = &hex!("824243");
567+
const GENERAL_EXP: &[u8] = &hex!("C883FFCCB583FFC0B5");
568+
569+
assert_eq!(&*encoded(SPECIALIZED), SPECIALIZED_EXP);
570+
assert_eq!(&*encoded(GENERAL), GENERAL_EXP);
571+
572+
assert_eq!(&*encoded(ArrayVec::from(SPECIALIZED)), SPECIALIZED_EXP);
573+
assert_eq!(&*encoded(ArrayVec::from(GENERAL)), GENERAL_EXP);
574+
575+
#[cfg(feature = "alloc")]
576+
{
577+
assert_eq!(&*encoded(SPECIALIZED.to_vec()), SPECIALIZED_EXP);
578+
assert_eq!(&*encoded(GENERAL.to_vec()), GENERAL_EXP);
579+
}
580+
}
533581
}

0 commit comments

Comments
 (0)