Skip to content

Commit af5c63f

Browse files
authored
feat(pkg/scale): add Encoder with Encode method (#2741)
- Change `encodeState` to use `io.Writer` instead of `bytes.Buffer` - Define `Encoder` with `Encode(value interface{}) error` method - Define constructor `NewEncoder(writer io.Writer) *Encoder` - Add unit tests for encoder
1 parent 363c080 commit af5c63f

File tree

3 files changed

+180
-46
lines changed

3 files changed

+180
-46
lines changed

pkg/scale/encode.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,48 @@ import (
77
"bytes"
88
"encoding/binary"
99
"fmt"
10+
"io"
1011
"math/big"
1112
"reflect"
1213
)
1314

15+
// Encoder scale encodes to a given io.Writer.
16+
type Encoder struct {
17+
encodeState
18+
}
19+
20+
// NewEncoder creates a new encoder with the given writer.
21+
func NewEncoder(writer io.Writer) (encoder *Encoder) {
22+
return &Encoder{
23+
encodeState: encodeState{
24+
Writer: writer,
25+
fieldScaleIndicesCache: cache,
26+
},
27+
}
28+
}
29+
30+
// Encode scale encodes value to the encoder writer.
31+
func (e *Encoder) Encode(value interface{}) (err error) {
32+
return e.marshal(value)
33+
}
34+
1435
// Marshal takes in an interface{} and attempts to marshal into []byte
1536
func Marshal(v interface{}) (b []byte, err error) {
37+
buffer := bytes.NewBuffer(nil)
1638
es := encodeState{
39+
Writer: buffer,
1740
fieldScaleIndicesCache: cache,
1841
}
1942
err = es.marshal(v)
2043
if err != nil {
2144
return
2245
}
23-
b = es.Bytes()
46+
b = buffer.Bytes()
2447
return
2548
}
2649

2750
type encodeState struct {
28-
bytes.Buffer
51+
io.Writer
2952
*fieldScaleIndicesCache
3053
}
3154

@@ -64,9 +87,9 @@ func (es *encodeState) marshal(in interface{}) (err error) {
6487
elem := reflect.ValueOf(in).Elem()
6588
switch elem.IsValid() {
6689
case false:
67-
err = es.WriteByte(0)
90+
_, err = es.Write([]byte{0})
6891
default:
69-
err = es.WriteByte(1)
92+
_, err = es.Write([]byte{1})
7093
if err != nil {
7194
return
7295
}
@@ -133,13 +156,13 @@ func (es *encodeState) encodeResult(res Result) (err error) {
133156
var in interface{}
134157
switch res.mode {
135158
case OK:
136-
err = es.WriteByte(0)
159+
_, err = es.Write([]byte{0})
137160
if err != nil {
138161
return
139162
}
140163
in = res.ok
141164
case Err:
142-
err = es.WriteByte(1)
165+
_, err = es.Write([]byte{1})
143166
if err != nil {
144167
return
145168
}
@@ -159,7 +182,7 @@ func (es *encodeState) encodeCustomVaryingDataType(in interface{}) (err error) {
159182
}
160183

161184
func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) {
162-
err = es.WriteByte(byte(vdt.value.Index()))
185+
_, err = es.Write([]byte{byte(vdt.value.Index())})
163186
if err != nil {
164187
return
165188
}

pkg/scale/encode_test.go

Lines changed: 135 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,74 @@
44
package scale
55

66
import (
7+
"bytes"
78
"math/big"
89
"reflect"
910
"strings"
1011
"testing"
12+
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
1115
)
1216

17+
func Test_NewEncoder(t *testing.T) {
18+
t.Parallel()
19+
20+
cache.Lock()
21+
defer cache.Unlock()
22+
23+
writer := bytes.NewBuffer(nil)
24+
encoder := NewEncoder(writer)
25+
26+
expectedEncoder := &Encoder{
27+
encodeState: encodeState{
28+
Writer: writer,
29+
fieldScaleIndicesCache: cache,
30+
},
31+
}
32+
33+
assert.Equal(t, expectedEncoder, encoder)
34+
}
35+
36+
func Test_Encoder_Encode(t *testing.T) {
37+
t.Parallel()
38+
39+
buffer := bytes.NewBuffer(nil)
40+
encoder := NewEncoder(buffer)
41+
42+
err := encoder.Encode(uint16(1))
43+
require.NoError(t, err)
44+
45+
err = encoder.Encode(uint8(2))
46+
require.NoError(t, err)
47+
48+
array := [2]byte{4, 5}
49+
err = encoder.Encode(array)
50+
require.NoError(t, err)
51+
52+
type T struct {
53+
Array [2]byte
54+
}
55+
56+
someStruct := T{Array: [2]byte{6, 7}}
57+
err = encoder.Encode(someStruct)
58+
require.NoError(t, err)
59+
60+
structSlice := []T{{Array: [2]byte{8, 9}}}
61+
err = encoder.Encode(structSlice)
62+
require.NoError(t, err)
63+
64+
written := buffer.Bytes()
65+
expectedWritten := []byte{
66+
1, 0,
67+
2,
68+
4, 5,
69+
6, 7,
70+
4, 8, 9,
71+
}
72+
assert.Equal(t, expectedWritten, written)
73+
}
74+
1375
type test struct {
1476
name string
1577
in interface{}
@@ -869,12 +931,15 @@ type MyStructWithPrivate struct {
869931
func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
870932
for _, tt := range fixedWidthIntegerTests {
871933
t.Run(tt.name, func(t *testing.T) {
872-
es := &encodeState{}
934+
buffer := bytes.NewBuffer(nil)
935+
es := &encodeState{
936+
Writer: buffer,
937+
}
873938
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
874939
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
875940
}
876-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
877-
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
941+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
942+
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
878943
}
879944
})
880945
}
@@ -883,12 +948,15 @@ func Test_encodeState_encodeFixedWidthInteger(t *testing.T) {
883948
func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
884949
for _, tt := range variableWidthIntegerTests {
885950
t.Run(tt.name, func(t *testing.T) {
886-
es := &encodeState{}
951+
buffer := bytes.NewBuffer(nil)
952+
es := &encodeState{
953+
Writer: buffer,
954+
}
887955
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
888956
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
889957
}
890-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
891-
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
958+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
959+
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
892960
}
893961
})
894962
}
@@ -897,12 +965,15 @@ func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) {
897965
func Test_encodeState_encodeBigInt(t *testing.T) {
898966
for _, tt := range bigIntTests {
899967
t.Run(tt.name, func(t *testing.T) {
900-
es := &encodeState{}
968+
buffer := bytes.NewBuffer(nil)
969+
es := &encodeState{
970+
Writer: buffer,
971+
}
901972
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
902973
t.Errorf("encodeState.encodeBigInt() error = %v, wantErr %v", err, tt.wantErr)
903974
}
904-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
905-
t.Errorf("encodeState.encodeBigInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
975+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
976+
t.Errorf("encodeState.encodeBigInt() = %v, want %v", buffer.Bytes(), tt.want)
906977
}
907978
})
908979
}
@@ -911,12 +982,15 @@ func Test_encodeState_encodeBigInt(t *testing.T) {
911982
func Test_encodeState_encodeUint128(t *testing.T) {
912983
for _, tt := range uint128Tests {
913984
t.Run(tt.name, func(t *testing.T) {
914-
es := &encodeState{}
985+
buffer := bytes.NewBuffer(nil)
986+
es := &encodeState{
987+
Writer: buffer,
988+
}
915989
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
916990
t.Errorf("encodeState.encodeUin128() error = %v, wantErr %v", err, tt.wantErr)
917991
}
918-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
919-
t.Errorf("encodeState.encodeUin128() = %v, want %v", es.Buffer.Bytes(), tt.want)
992+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
993+
t.Errorf("encodeState.encodeUin128() = %v, want %v", buffer.Bytes(), tt.want)
920994
}
921995
})
922996
}
@@ -925,12 +999,16 @@ func Test_encodeState_encodeUint128(t *testing.T) {
925999
func Test_encodeState_encodeBytes(t *testing.T) {
9261000
for _, tt := range stringTests {
9271001
t.Run(tt.name, func(t *testing.T) {
928-
es := &encodeState{}
1002+
1003+
buffer := bytes.NewBuffer(nil)
1004+
es := &encodeState{
1005+
Writer: buffer,
1006+
}
9291007
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
9301008
t.Errorf("encodeState.encodeBytes() error = %v, wantErr %v", err, tt.wantErr)
9311009
}
932-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
933-
t.Errorf("encodeState.encodeBytes() = %v, want %v", es.Buffer.Bytes(), tt.want)
1010+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1011+
t.Errorf("encodeState.encodeBytes() = %v, want %v", buffer.Bytes(), tt.want)
9341012
}
9351013
})
9361014
}
@@ -939,12 +1017,16 @@ func Test_encodeState_encodeBytes(t *testing.T) {
9391017
func Test_encodeState_encodeBool(t *testing.T) {
9401018
for _, tt := range boolTests {
9411019
t.Run(tt.name, func(t *testing.T) {
942-
es := &encodeState{}
1020+
1021+
buffer := bytes.NewBuffer(nil)
1022+
es := &encodeState{
1023+
Writer: buffer,
1024+
}
9431025
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
9441026
t.Errorf("encodeState.encodeBool() error = %v, wantErr %v", err, tt.wantErr)
9451027
}
946-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
947-
t.Errorf("encodeState.encodeBool() = %v, want %v", es.Buffer.Bytes(), tt.want)
1028+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1029+
t.Errorf("encodeState.encodeBool() = %v, want %v", buffer.Bytes(), tt.want)
9481030
}
9491031
})
9501032
}
@@ -953,12 +1035,16 @@ func Test_encodeState_encodeBool(t *testing.T) {
9531035
func Test_encodeState_encodeStruct(t *testing.T) {
9541036
for _, tt := range structTests {
9551037
t.Run(tt.name, func(t *testing.T) {
956-
es := &encodeState{fieldScaleIndicesCache: cache}
1038+
buffer := bytes.NewBuffer(nil)
1039+
es := &encodeState{
1040+
Writer: buffer,
1041+
fieldScaleIndicesCache: cache,
1042+
}
9571043
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
9581044
t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr)
9591045
}
960-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
961-
t.Errorf("encodeState.encodeStruct() = %v, want %v", es.Buffer.Bytes(), tt.want)
1046+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1047+
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
9621048
}
9631049
})
9641050
}
@@ -967,12 +1053,16 @@ func Test_encodeState_encodeStruct(t *testing.T) {
9671053
func Test_encodeState_encodeSlice(t *testing.T) {
9681054
for _, tt := range sliceTests {
9691055
t.Run(tt.name, func(t *testing.T) {
970-
es := &encodeState{fieldScaleIndicesCache: cache}
1056+
buffer := bytes.NewBuffer(nil)
1057+
es := &encodeState{
1058+
Writer: buffer,
1059+
fieldScaleIndicesCache: cache,
1060+
}
9711061
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
9721062
t.Errorf("encodeState.encodeSlice() error = %v, wantErr %v", err, tt.wantErr)
9731063
}
974-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
975-
t.Errorf("encodeState.encodeSlice() = %v, want %v", es.Buffer.Bytes(), tt.want)
1064+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1065+
t.Errorf("encodeState.encodeSlice() = %v, want %v", buffer.Bytes(), tt.want)
9761066
}
9771067
})
9781068
}
@@ -981,12 +1071,16 @@ func Test_encodeState_encodeSlice(t *testing.T) {
9811071
func Test_encodeState_encodeArray(t *testing.T) {
9821072
for _, tt := range arrayTests {
9831073
t.Run(tt.name, func(t *testing.T) {
984-
es := &encodeState{fieldScaleIndicesCache: cache}
1074+
buffer := bytes.NewBuffer(nil)
1075+
es := &encodeState{
1076+
Writer: buffer,
1077+
fieldScaleIndicesCache: cache,
1078+
}
9851079
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
9861080
t.Errorf("encodeState.encodeArray() error = %v, wantErr %v", err, tt.wantErr)
9871081
}
988-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
989-
t.Errorf("encodeState.encodeArray() = %v, want %v", es.Buffer.Bytes(), tt.want)
1082+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1083+
t.Errorf("encodeState.encodeArray() = %v, want %v", buffer.Bytes(), tt.want)
9901084
}
9911085
})
9921086
}
@@ -1007,12 +1101,16 @@ func Test_marshal_optionality(t *testing.T) {
10071101
}
10081102
for _, tt := range ptrTests {
10091103
t.Run(tt.name, func(t *testing.T) {
1010-
es := &encodeState{fieldScaleIndicesCache: cache}
1104+
buffer := bytes.NewBuffer(nil)
1105+
es := &encodeState{
1106+
Writer: buffer,
1107+
fieldScaleIndicesCache: cache,
1108+
}
10111109
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
10121110
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
10131111
}
1014-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
1015-
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
1112+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1113+
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
10161114
}
10171115
})
10181116
}
@@ -1043,12 +1141,16 @@ func Test_marshal_optionality_nil_cases(t *testing.T) {
10431141
}
10441142
for _, tt := range ptrTests {
10451143
t.Run(tt.name, func(t *testing.T) {
1046-
es := &encodeState{fieldScaleIndicesCache: cache}
1144+
buffer := bytes.NewBuffer(nil)
1145+
es := &encodeState{
1146+
Writer: buffer,
1147+
fieldScaleIndicesCache: cache,
1148+
}
10471149
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
10481150
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
10491151
}
1050-
if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) {
1051-
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want)
1152+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1153+
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
10521154
}
10531155
})
10541156
}

0 commit comments

Comments
 (0)