Skip to content

Commit 58dbc47

Browse files
committed
WIP
1 parent 25e626b commit 58dbc47

File tree

6 files changed

+492
-180
lines changed

6 files changed

+492
-180
lines changed

internal/trie/node/children.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,13 @@ func (n *Node) NumChildren() (count int) {
3030
}
3131
return count
3232
}
33+
34+
// HasChild returns true if the node has at least one child.
35+
func (n *Node) HasChild() (has bool) {
36+
for _, child := range n.Children {
37+
if child != nil {
38+
return true
39+
}
40+
}
41+
return false
42+
}

internal/trie/node/children_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,42 @@ func Test_Node_NumChildren(t *testing.T) {
118118
})
119119
}
120120
}
121+
122+
func Test_Node_HasChild(t *testing.T) {
123+
t.Parallel()
124+
125+
testCases := map[string]struct {
126+
node Node
127+
has bool
128+
}{
129+
"no child": {},
130+
"one child at index 0": {
131+
node: Node{
132+
Children: []*Node{
133+
{},
134+
},
135+
},
136+
has: true,
137+
},
138+
"one child at index 1": {
139+
node: Node{
140+
Children: []*Node{
141+
nil,
142+
{},
143+
},
144+
},
145+
has: true,
146+
},
147+
}
148+
149+
for name, testCase := range testCases {
150+
testCase := testCase
151+
t.Run(name, func(t *testing.T) {
152+
t.Parallel()
153+
154+
has := testCase.node.HasChild()
155+
156+
assert.Equal(t, testCase.has, has)
157+
})
158+
}
159+
}

lib/trie/proof/generate_test.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,8 @@ func Test_walk(t *testing.T) {
153153
errWrapped: ErrKeyNotFound,
154154
errMessage: "key not found",
155155
},
156-
"parent encode and hash error": {
157-
parent: &node.Node{
158-
Key: make([]byte, int(^uint16(0))+63),
159-
Value: []byte{1},
160-
},
161-
errWrapped: node.ErrPartialKeyTooBig,
162-
errMessage: "encode node: " +
163-
"cannot encode header: partial key length cannot " +
164-
"be larger than or equal to 2^16: 65535",
165-
},
156+
// The parent encode error cannot be triggered here
157+
// since it can only be caused by a buffer.Write error.
166158
"parent leaf and empty full key": {
167159
parent: &node.Node{
168160
Key: []byte{1, 2},

lib/trie/proof/helpers_test.go

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ package proof
55

66
import (
77
"bytes"
8+
"math/rand"
89
"testing"
910

1011
"github.com/ChainSafe/gossamer/internal/trie/node"
1112
"github.com/ChainSafe/gossamer/lib/common"
13+
"github.com/ChainSafe/gossamer/pkg/scale"
1214
"github.com/stretchr/testify/require"
1315
)
1416

@@ -29,8 +31,72 @@ func encodeNode(t *testing.T, node node.Node) (encoded []byte) {
2931
func blake2bNode(t *testing.T, node node.Node) (digest []byte) {
3032
t.Helper()
3133
encoding := encodeNode(t, node)
32-
digestHash, err := common.Blake2bHash(encoding)
34+
return blake2b(t, encoding)
35+
}
36+
37+
func scaleEncode(t *testing.T, data []byte) (encoded []byte) {
38+
t.Helper()
39+
encoded, err := scale.Marshal(data)
40+
require.NoError(t, err)
41+
return encoded
42+
}
43+
44+
func blake2b(t *testing.T, data []byte) (digest []byte) {
45+
t.Helper()
46+
digestHash, err := common.Blake2bHash(data)
3347
require.NoError(t, err)
3448
digest = digestHash[:]
3549
return digest
3650
}
51+
52+
func concatBytes(slices [][]byte) (concatenated []byte) {
53+
for _, slice := range slices {
54+
concatenated = append(concatenated, slice...)
55+
}
56+
return concatenated
57+
}
58+
59+
// generateBytes generates a pseudo random byte slice
60+
// of the given length. It uses `0` as its seed so
61+
// calling it multiple times will generate the same
62+
// byte slice. This is designed as such in order to have
63+
// deterministic unit tests.
64+
func generateBytes(t *testing.T, length uint) (bytes []byte) {
65+
t.Helper()
66+
generator := rand.New(rand.NewSource(0))
67+
bytes = make([]byte, length)
68+
_, err := generator.Read(bytes)
69+
require.NoError(t, err)
70+
return bytes
71+
}
72+
73+
// getBadNodeEncoding returns a particular bad node encoding of 33 bytes.
74+
func getBadNodeEncoding() (badEncoding []byte) {
75+
return []byte{
76+
0x1, 0x94, 0xfd, 0xc2, 0xfa, 0x2f, 0xfc, 0xc0, 0x41, 0xd3,
77+
0xff, 0x12, 0x4, 0x5b, 0x73, 0xc8, 0x6e, 0x4f, 0xf9, 0x5f,
78+
0xf6, 0x62, 0xa5, 0xee, 0xe8, 0x2a, 0xbd, 0xf4, 0x4a, 0x2d,
79+
0xb, 0x75, 0xfb}
80+
}
81+
82+
func Test_getBadNodeEncoding(t *testing.T) {
83+
t.Parallel()
84+
85+
badEncoding := getBadNodeEncoding()
86+
_, err := node.Decode(bytes.NewBuffer(badEncoding))
87+
require.Error(t, err)
88+
}
89+
90+
func assertLongEncoding(t *testing.T, node node.Node) {
91+
t.Helper()
92+
93+
encoding := encodeNode(t, node)
94+
require.Greater(t, len(encoding), 32)
95+
}
96+
97+
func assertShortEncoding(t *testing.T, node node.Node) {
98+
t.Helper()
99+
100+
encoding := encodeNode(t, node)
101+
require.LessOrEqual(t, len(encoding), 32)
102+
}

lib/trie/proof/verify.go

Lines changed: 71 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -57,76 +57,72 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e
5757
ErrEmptyProof, rootHash)
5858
}
5959

60-
merkleValueToNode := make(map[string]*node.Node, len(encodedProofNodes))
60+
merkleValueToEncoding := make(map[string][]byte, len(encodedProofNodes))
6161

62+
// This loop finds the root node and decodes it.
63+
// The other nodes have their Merkle value (blake2b digest or the encoding itself)
64+
// inserted into a map from merkle value to encoding. They are only decoded
65+
// later if the root or one of its descendant node reference their Merkle value.
6266
var root *node.Node
63-
for i, encodedProofNode := range encodedProofNodes {
64-
decodedNode, err := node.Decode(bytes.NewReader(encodedProofNode))
65-
if err != nil {
66-
return nil, fmt.Errorf("decoding node at index %d: %w (node encoded is 0x%x)",
67-
i, err, encodedProofNode)
68-
}
69-
70-
decodedNode.Encoding = encodedProofNode
71-
// We compute the Merkle value of nodes treating them all
72-
// as non-root nodes, meaning nodes with encoding smaller
73-
// than 33 bytes will have their Merkle value set as their
74-
// encoding. The Blake2b hash of the encoding is computed
75-
// later if needed to compare with the root hash given to find
76-
// which node is the root node.
77-
const isRoot = false
78-
decodedNode.HashDigest, err = node.MerkleValue(encodedProofNode, isRoot)
79-
if err != nil {
80-
return nil, fmt.Errorf("merkle value of node at index %d: %w", i, err)
81-
}
82-
83-
proofHash := common.BytesToHex(decodedNode.HashDigest)
84-
merkleValueToNode[proofHash] = decodedNode
85-
86-
if root != nil {
87-
// Root node already found in proof
88-
continue
89-
}
90-
91-
possibleRootMerkleValue := decodedNode.HashDigest
92-
if len(possibleRootMerkleValue) <= 32 {
93-
// If the root merkle value is smaller than 33 bytes, it means
94-
// it is the encoding of the node. However, the root node merkle
95-
// value is always the blake2b digest of the node, and not its own
96-
// encoding. Therefore, in this case we force the computation of the
97-
// blake2b digest of the node to check if it matches the root hash given.
98-
const isRoot = true
99-
possibleRootMerkleValue, err = node.MerkleValue(encodedProofNode, isRoot)
67+
for _, encodedProofNode := range encodedProofNodes {
68+
var digest []byte
69+
if root == nil {
70+
// root node not found yet
71+
digestHash, err := common.Blake2bHash(encodedProofNode)
10072
if err != nil {
101-
return nil, fmt.Errorf("merkle value of possible root node: %w", err)
73+
return nil, fmt.Errorf("blake2b hash: %w", err)
74+
}
75+
digest = digestHash[:]
76+
77+
if bytes.Equal(digest, rootHash) {
78+
root, err = node.Decode(bytes.NewReader(encodedProofNode))
79+
if err != nil {
80+
return nil, fmt.Errorf("decoding root node: %w", err)
81+
}
82+
continue // no need to add root to map of hash to encoding
10283
}
10384
}
10485

105-
if bytes.Equal(rootHash, possibleRootMerkleValue) {
106-
decodedNode.HashDigest = rootHash
107-
root = decodedNode
86+
var merkleValue []byte
87+
if len(encodedProofNode) <= 32 {
88+
merkleValue = encodedProofNode
89+
} else {
90+
if digest == nil {
91+
digestHash, err := common.Blake2bHash(encodedProofNode)
92+
if err != nil {
93+
return nil, fmt.Errorf("blake2b hash: %w", err)
94+
}
95+
digest = digestHash[:]
96+
}
97+
merkleValue = digest
10898
}
99+
100+
merkleValueToEncoding[string(merkleValue)] = encodedProofNode
109101
}
110102

111103
if root == nil {
112-
proofMerkleValues := make([]string, 0, len(merkleValueToNode))
113-
for merkleValue := range merkleValueToNode {
114-
proofMerkleValues = append(proofMerkleValues, merkleValue)
104+
proofMerkleValues := make([]string, 0, len(merkleValueToEncoding))
105+
for merkleValueString := range merkleValueToEncoding {
106+
merkleValueHex := common.BytesToHex([]byte(merkleValueString))
107+
proofMerkleValues = append(proofMerkleValues, merkleValueHex)
115108
}
116109
return nil, fmt.Errorf("%w: for Merkle root hash 0x%x in proof Merkle value(s) %s",
117110
ErrRootNodeNotFound, rootHash, strings.Join(proofMerkleValues, ", "))
118111
}
119112

120-
loadProof(merkleValueToNode, root)
113+
err = loadProof(merkleValueToEncoding, root)
114+
if err != nil {
115+
return nil, fmt.Errorf("loading proof: %w", err)
116+
}
121117

122118
return trie.NewTrie(root), nil
123119
}
124120

125121
// loadProof is a recursive function that will create all the trie paths based
126122
// on the map from node hash to node starting at the root.
127-
func loadProof(merkleValueToNode map[string]*node.Node, n *node.Node) {
123+
func loadProof(merkleValueToEncoding map[string][]byte, n *node.Node) (err error) {
128124
if n.Type() != node.Branch {
129-
return
125+
return nil
130126
}
131127

132128
branch := n
@@ -135,15 +131,38 @@ func loadProof(merkleValueToNode map[string]*node.Node, n *node.Node) {
135131
continue
136132
}
137133

138-
merkleValueHex := common.BytesToHex(child.HashDigest)
139-
node, ok := merkleValueToNode[merkleValueHex]
134+
merkleValue := child.HashDigest
135+
encoding, ok := merkleValueToEncoding[string(merkleValue)]
140136
if !ok {
137+
inlinedChild := len(child.Value) > 0 || child.HasChild()
138+
if !inlinedChild {
139+
// hash not found and the child is not inlined,
140+
// so clear the child from the branch.
141+
branch.Descendants -= 1 + child.Descendants
142+
branch.Children[i] = nil
143+
if !branch.HasChild() {
144+
// Convert branch to a leaf if all its children are nil.
145+
branch.Children = nil
146+
}
147+
}
141148
continue
142149
}
143150

144-
branch.Children[i] = node
145-
loadProof(merkleValueToNode, node)
151+
child, err := node.Decode(bytes.NewReader(encoding))
152+
if err != nil {
153+
return fmt.Errorf("decoding child node for Merkle value 0x%x: %w",
154+
merkleValue, err)
155+
}
156+
157+
branch.Children[i] = child
158+
branch.Descendants += child.Descendants
159+
err = loadProof(merkleValueToEncoding, child)
160+
if err != nil {
161+
return err // do not wrap error since this is recursive
162+
}
146163
}
164+
165+
return nil
147166
}
148167

149168
func bytesToString(b []byte) (s string) {

0 commit comments

Comments
 (0)