Skip to content

Commit ac700f8

Browse files
authored
feat(scale): add range checks to decodeUint function (#2683)
1 parent 62d750d commit ac700f8

File tree

5 files changed

+179
-40
lines changed

5 files changed

+179
-40
lines changed

internal/trie/node/decode_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func Test_decodeBranch(t *testing.T) {
166166
variant: branchVariant.bits,
167167
partialKeyLength: 1,
168168
errWrapped: ErrDecodeChildHash,
169-
errMessage: "cannot decode child hash: at index 10: EOF",
169+
errMessage: "cannot decode child hash: at index 10: reading byte: EOF",
170170
},
171171
"success for branch variant": {
172172
reader: bytes.NewBuffer(
@@ -203,7 +203,7 @@ func Test_decodeBranch(t *testing.T) {
203203
variant: branchWithValueVariant.bits,
204204
partialKeyLength: 1,
205205
errWrapped: ErrDecodeValue,
206-
errMessage: "cannot decode value: EOF",
206+
errMessage: "cannot decode value: reading byte: EOF",
207207
},
208208
"success for branch with value": {
209209
reader: bytes.NewBuffer(concatByteSlices([][]byte{
@@ -333,7 +333,7 @@ func Test_decodeLeaf(t *testing.T) {
333333
variant: leafVariant.bits,
334334
partialKeyLength: 1,
335335
errWrapped: ErrDecodeValue,
336-
errMessage: "cannot decode value: could not decode invalid integer",
336+
errMessage: "cannot decode value: unknown prefix for compact uint: 255",
337337
},
338338
"zero value": {
339339
reader: bytes.NewBuffer([]byte{

lib/runtime/version_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func Test_DecodeVersion(t *testing.T) {
3939
{255, 255}, // error
4040
}),
4141
errWrapped: ErrDecodingVersionField,
42-
errMessage: "decoding version field impl name: could not decode invalid integer",
42+
errMessage: "decoding version field impl name: unknown prefix for compact uint: 255",
4343
},
4444
// TODO add transaction version decode error once
4545
// https://github.com/ChainSafe/gossamer/pull/2683

pkg/scale/decode.go

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error
335335
if err != nil {
336336
return
337337
}
338-
for i := 0; i < l; i++ {
338+
for i := uint(0); i < l; i++ {
339339
vdt := vdts.VaryingDataType
340340
vdtv := reflect.New(reflect.TypeOf(vdt))
341341
vdtv.Elem().Set(reflect.ValueOf(vdt))
@@ -397,7 +397,7 @@ func (ds *decodeState) decodeSlice(dstv reflect.Value) (err error) {
397397
}
398398
in := dstv.Interface()
399399
temp := reflect.New(reflect.ValueOf(in).Type())
400-
for i := 0; i < l; i++ {
400+
for i := uint(0); i < l; i++ {
401401
tempElemType := reflect.TypeOf(in).Elem()
402402
tempElem := reflect.New(tempElemType).Elem()
403403

@@ -478,59 +478,90 @@ func (ds *decodeState) decodeBool(dstv reflect.Value) (err error) {
478478

479479
// decodeUint will decode unsigned integer
480480
func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
481-
b, err := ds.ReadByte()
481+
const maxUint32 = ^uint32(0)
482+
const maxUint64 = ^uint64(0)
483+
prefix, err := ds.ReadByte()
482484
if err != nil {
483-
return
485+
return fmt.Errorf("reading byte: %w", err)
484486
}
485487

486488
in := dstv.Interface()
487489
temp := reflect.New(reflect.TypeOf(in))
488490
// check mode of encoding, stored at 2 least significant bits
489-
mode := b & 3
490-
switch {
491-
case mode <= 2:
492-
var val int64
493-
val, err = ds.decodeSmallInt(b, mode)
491+
mode := prefix % 4
492+
var value uint64
493+
switch mode {
494+
case 0:
495+
value = uint64(prefix >> 2)
496+
case 1:
497+
buf, err := ds.ReadByte()
494498
if err != nil {
495-
return
499+
return fmt.Errorf("reading byte: %w", err)
496500
}
497-
temp.Elem().Set(reflect.ValueOf(val).Convert(reflect.TypeOf(in)))
498-
dstv.Set(temp.Elem())
499-
default:
500-
// >4 byte mode
501-
topSixBits := b >> 2
502-
byteLen := uint(topSixBits) + 4
503-
501+
value = uint64(binary.LittleEndian.Uint16([]byte{prefix, buf}) >> 2)
502+
if value <= 0b0011_1111 || value > 0b0111_1111_1111_1111 {
503+
return fmt.Errorf("%w: %d (%b)", ErrU16OutOfRange, value, value)
504+
}
505+
case 2:
506+
buf := make([]byte, 3)
507+
_, err = ds.Read(buf)
508+
if err != nil {
509+
return fmt.Errorf("reading bytes: %w", err)
510+
}
511+
value = uint64(binary.LittleEndian.Uint32(append([]byte{prefix}, buf...)) >> 2)
512+
if value <= 0b0011_1111_1111_1111 || value > uint64(maxUint32>>2) {
513+
return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value)
514+
}
515+
case 3:
516+
byteLen := (prefix >> 2) + 4
504517
buf := make([]byte, byteLen)
505518
_, err = ds.Read(buf)
506519
if err != nil {
507-
return
520+
return fmt.Errorf("reading bytes: %w", err)
508521
}
509-
510-
var o uint64
511-
if byteLen == 4 {
512-
o = uint64(binary.LittleEndian.Uint32(buf))
513-
} else if byteLen > 4 && byteLen <= 8 {
522+
switch byteLen {
523+
case 4:
524+
value = uint64(binary.LittleEndian.Uint32(buf))
525+
if value <= uint64(maxUint32>>2) {
526+
return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value)
527+
}
528+
case 8:
529+
const uintSize = 32 << (^uint(0) >> 32 & 1)
530+
if uintSize == 32 {
531+
return ErrU64NotSupported
532+
}
514533
tmp := make([]byte, 8)
515534
copy(tmp, buf)
516-
o = binary.LittleEndian.Uint64(tmp)
517-
} else {
518-
err = errors.New("could not decode invalid integer")
519-
return
535+
value = binary.LittleEndian.Uint64(tmp)
536+
if value <= maxUint64>>8 {
537+
return fmt.Errorf("%w: %d (%b)", ErrU64OutOfRange, value, value)
538+
}
539+
default:
540+
return fmt.Errorf("%w: %d", ErrCompactUintPrefixUnknown, prefix)
541+
520542
}
521-
dstv.Set(reflect.ValueOf(o).Convert(reflect.TypeOf(in)))
522543
}
544+
temp.Elem().Set(reflect.ValueOf(value).Convert(reflect.TypeOf(in)))
545+
dstv.Set(temp.Elem())
523546
return
524547
}
525548

549+
var (
550+
ErrU16OutOfRange = errors.New("uint16 out of range")
551+
ErrU32OutOfRange = errors.New("uint32 out of range")
552+
ErrU64OutOfRange = errors.New("uint64 out of range")
553+
ErrU64NotSupported = errors.New("uint64 is not supported")
554+
ErrCompactUintPrefixUnknown = errors.New("unknown prefix for compact uint")
555+
)
556+
526557
// decodeLength is helper method which calls decodeUint and casts to int
527-
func (ds *decodeState) decodeLength() (l int, err error) {
558+
func (ds *decodeState) decodeLength() (l uint, err error) {
528559
dstv := reflect.New(reflect.TypeOf(l))
529560
err = ds.decodeUint(dstv.Elem())
530561
if err != nil {
531562
return
532563
}
533-
l = dstv.Elem().Interface().(int)
564+
l = dstv.Elem().Interface().(uint)
534565
return
535566
}
536567

pkg/scale/decode_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/google/go-cmp/cmp"
1313
"github.com/google/go-cmp/cmp/cmpopts"
14+
"github.com/stretchr/testify/assert"
1415
)
1516

1617
func Test_decodeState_decodeFixedWidthInt(t *testing.T) {
@@ -302,3 +303,101 @@ func Test_Decoder_Decode_MultipleCalls(t *testing.T) {
302303
})
303304
}
304305
}
306+
307+
func Test_decodeState_decodeUint(t *testing.T) {
308+
t.Parallel()
309+
decodeUint32Tests := tests{
310+
{
311+
name: "int(1) mode 0",
312+
in: uint32(1),
313+
want: []byte{0x04},
314+
},
315+
{
316+
name: "int(16383) mode 1",
317+
in: int(16383),
318+
want: []byte{0xfd, 0xff},
319+
},
320+
{
321+
name: "int(1073741823) mode 2",
322+
in: int(1073741823),
323+
want: []byte{0xfe, 0xff, 0xff, 0xff},
324+
},
325+
{
326+
name: "int(4294967295) mode 3",
327+
in: int(4294967295),
328+
want: []byte{0x3, 0xff, 0xff, 0xff, 0xff},
329+
},
330+
{
331+
name: "myCustomInt(9223372036854775807) mode 3, 64bit",
332+
in: myCustomInt(9223372036854775807),
333+
want: []byte{19, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
334+
},
335+
{
336+
name: "uint(overload)",
337+
in: int(0),
338+
want: []byte{0x07, 0x08, 0x09, 0x10, 0x0, 0x40},
339+
wantErr: true,
340+
},
341+
{
342+
name: "uint(16384) mode 2",
343+
in: int(16384),
344+
want: []byte{0x02, 0x00, 0x01, 0x0},
345+
},
346+
{
347+
name: "uint(0) mode 1, error",
348+
in: int(0),
349+
want: []byte{0x01, 0x00},
350+
wantErr: true,
351+
},
352+
{
353+
name: "uint(0) mode 2, error",
354+
in: int(0),
355+
want: []byte{0x02, 0x00, 0x00, 0x0},
356+
wantErr: true,
357+
},
358+
{
359+
name: "uint(0) mode 3, error",
360+
in: int(0),
361+
want: []byte{0x03, 0x00, 0x00, 0x0},
362+
wantErr: true,
363+
},
364+
{
365+
name: "mode 3, 64bit, error",
366+
in: int(0),
367+
want: []byte{19, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
368+
wantErr: true,
369+
},
370+
{
371+
name: "[]int{1 << 32, 2, 3, 1 << 32}",
372+
in: uint(4),
373+
want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
374+
},
375+
{
376+
name: "[4]int{1 << 32, 2, 3, 1 << 32}",
377+
in: [4]int{0, 0, 0, 0},
378+
want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
379+
wantErr: true,
380+
},
381+
}
382+
383+
for _, tt := range decodeUint32Tests {
384+
tt := tt
385+
t.Run(tt.name, func(t *testing.T) {
386+
t.Parallel()
387+
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
388+
dstv := reflect.ValueOf(&dst)
389+
elem := indirect(dstv)
390+
391+
ds := decodeState{
392+
Reader: bytes.NewBuffer(tt.want),
393+
}
394+
err := ds.decodeUint(elem)
395+
if tt.wantErr {
396+
assert.Error(t, err)
397+
} else {
398+
assert.NoError(t, err)
399+
}
400+
assert.Equal(t, tt.in, dst)
401+
})
402+
}
403+
}

pkg/scale/encode_test.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ var (
176176
in: int(1),
177177
want: []byte{0x04},
178178
},
179+
{
180+
name: "int(42)",
181+
in: int(42),
182+
want: []byte{0xa8},
183+
},
179184
{
180185
name: "int(16383)",
181186
in: int(16383),
@@ -821,9 +826,11 @@ var (
821826
want: []byte{0x10, 0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10},
822827
},
823828
{
824-
name: "[]int{1 << 32, 2, 3, 1 << 32}",
825-
in: []int{1 << 32, 2, 3, 1 << 32},
826-
want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
829+
name: "[]int64{1 << 32, 2, 3, 1 << 32}",
830+
in: []int64{1 << 32, 2, 3, 1 << 32},
831+
want: []byte{0x10, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
832+
0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
833+
0x00},
827834
},
828835
{
829836
name: "[]bool{true, false, true}",
@@ -864,9 +871,11 @@ var (
864871
want: []byte{0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10},
865872
},
866873
{
867-
name: "[4]int{1 << 32, 2, 3, 1 << 32}",
868-
in: [4]int{1 << 32, 2, 3, 1 << 32},
869-
want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
874+
name: "[4]int64{1 << 32, 2, 3, 1 << 32}",
875+
in: [4]int64{1 << 32, 2, 3, 1 << 32},
876+
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
877+
0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
878+
0x00},
870879
},
871880
{
872881
name: "[3]bool{true, false, true}",

0 commit comments

Comments
 (0)