Skip to content

Commit 9816e04

Browse files
committed
Replace sync.Map with mutex locked map
1 parent 7324c07 commit 9816e04

File tree

5 files changed

+279
-37
lines changed

5 files changed

+279
-37
lines changed

dot/state/service_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ func TestService_PruneStorage(t *testing.T) {
293293
time.Sleep(1 * time.Second)
294294

295295
for _, v := range prunedArr {
296-
_, has := serv.Storage.tries.Load(v.hash)
297-
require.Equal(t, false, has)
296+
tr := serv.Storage.tries.get(v.hash)
297+
require.Nil(t, tr)
298298
}
299299
}
300300

dot/state/storage.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func errTrieDoesNotExist(hash common.Hash) error {
3030
// StorageState is the struct that holds the trie, db and lock
3131
type StorageState struct {
3232
blockState *BlockState
33-
tries *sync.Map // map[common.Hash]*trie.Trie // map of root -> trie
33+
tries *tries
3434

3535
db chaindb.Database
3636
sync.RWMutex
@@ -52,8 +52,7 @@ func NewStorageState(db chaindb.Database, blockState *BlockState,
5252
return nil, fmt.Errorf("cannot have nil trie")
5353
}
5454

55-
tries := new(sync.Map)
56-
tries.Store(t.MustHash(), t)
55+
tries := newTries(t)
5756

5857
storageTable := chaindb.NewTable(db, storagePrefix)
5958

@@ -79,14 +78,14 @@ func NewStorageState(db chaindb.Database, blockState *BlockState,
7978

8079
func (s *StorageState) pruneKey(keyHeader *types.Header) {
8180
logger.Tracef("pruning trie, number=%d hash=%s", keyHeader.Number, keyHeader.Hash())
82-
s.tries.Delete(keyHeader.StateRoot)
81+
s.tries.delete(keyHeader.StateRoot)
8382
}
8483

8584
// StoreTrie stores the given trie in the StorageState and writes it to the database
8685
func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error {
8786
root := ts.MustRoot()
8887

89-
_, _ = s.tries.LoadOrStore(root, ts.Trie())
88+
s.tries.softSet(root, ts.Trie())
9089

9190
if _, ok := s.pruner.(*pruner.FullNode); header == nil && ok {
9291
return fmt.Errorf("block cannot be empty for Full node pruner")
@@ -127,20 +126,16 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error
127126
root = &sr
128127
}
129128

130-
st, has := s.tries.Load(*root)
131-
if !has {
129+
t := s.tries.get(*root)
130+
if t == nil {
132131
var err error
133-
st, err = s.LoadFromDB(*root)
132+
t, err = s.LoadFromDB(*root)
134133
if err != nil {
135134
return nil, err
136135
}
137136

138-
_, _ = s.tries.LoadOrStore(*root, st)
139-
}
140-
141-
t := st.(*trie.Trie)
142-
143-
if has && t.MustHash() != *root {
137+
s.tries.softSet(*root, t)
138+
} else if t.MustHash() != *root {
144139
panic("trie does not have expected root")
145140
}
146141

@@ -162,7 +157,7 @@ func (s *StorageState) LoadFromDB(root common.Hash) (*trie.Trie, error) {
162157
return nil, err
163158
}
164159

165-
_, _ = s.tries.LoadOrStore(t.MustHash(), t)
160+
s.tries.softSet(t.MustHash(), t)
166161
return t, nil
167162
}
168163

@@ -175,8 +170,9 @@ func (s *StorageState) loadTrie(root *common.Hash) (*trie.Trie, error) {
175170
root = &sr
176171
}
177172

178-
if t, has := s.tries.Load(*root); has && t != nil {
179-
return t.(*trie.Trie), nil
173+
t := s.tries.get(*root)
174+
if t != nil {
175+
return t, nil
180176
}
181177

182178
tr, err := s.LoadFromDB(*root)
@@ -205,8 +201,9 @@ func (s *StorageState) GetStorage(root *common.Hash, key []byte) ([]byte, error)
205201
root = &sr
206202
}
207203

208-
if t, has := s.tries.Load(*root); has {
209-
val := t.(*trie.Trie).Get(key)
204+
t := s.tries.get(*root)
205+
if t != nil {
206+
val := t.Get(key)
210207
return val, nil
211208
}
212209

dot/state/storage_test.go

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package state
55

66
import (
77
"math/big"
8-
"sync"
98
"testing"
109
"time"
1110

@@ -99,7 +98,7 @@ func TestStorage_TrieState(t *testing.T) {
9998
time.Sleep(time.Millisecond * 100)
10099

101100
// get trie from db
102-
storage.tries.Delete(root)
101+
storage.tries.delete(root)
103102
ts3, err := storage.TrieState(&root)
104103
require.NoError(t, err)
105104
require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash())
@@ -131,34 +130,25 @@ func TestStorage_LoadFromDB(t *testing.T) {
131130
require.NoError(t, err)
132131

133132
// Clear trie from cache and fetch data from disk.
134-
storage.tries.Delete(root)
133+
storage.tries.delete(root)
135134

136135
data, err := storage.GetStorage(&root, trieKV[0].key)
137136
require.NoError(t, err)
138137
require.Equal(t, trieKV[0].value, data)
139138

140-
storage.tries.Delete(root)
139+
storage.tries.delete(root)
141140

142141
prefixKeys, err := storage.GetKeysWithPrefix(&root, []byte("ke"))
143142
require.NoError(t, err)
144143
require.Equal(t, 2, len(prefixKeys))
145144

146-
storage.tries.Delete(root)
145+
storage.tries.delete(root)
147146

148147
entries, err := storage.Entries(&root)
149148
require.NoError(t, err)
150149
require.Equal(t, 3, len(entries))
151150
}
152151

153-
func syncMapLen(m *sync.Map) int {
154-
l := 0
155-
m.Range(func(_, _ interface{}) bool {
156-
l++
157-
return true
158-
})
159-
return l
160-
}
161-
162152
func TestStorage_StoreTrie_NotSyncing(t *testing.T) {
163153
storage := newTestStorageState(t)
164154
ts, err := storage.TrieState(&trie.EmptyHash)
@@ -170,7 +160,7 @@ func TestStorage_StoreTrie_NotSyncing(t *testing.T) {
170160

171161
err = storage.StoreTrie(ts, nil)
172162
require.NoError(t, err)
173-
require.Equal(t, 2, syncMapLen(storage.tries))
163+
require.Equal(t, 2, storage.tries.len())
174164
}
175165

176166
func TestGetStorageChildAndGetStorageFromChild(t *testing.T) {
@@ -217,7 +207,7 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) {
217207
require.NoError(t, err)
218208

219209
// Clear trie from cache and fetch data from disk.
220-
storage.tries.Delete(rootHash)
210+
storage.tries.delete(rootHash)
221211

222212
_, err = storage.GetStorageChild(&rootHash, []byte("keyToChild"))
223213
require.NoError(t, err)

dot/state/tries.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright 2022 ChainSafe Systems (ON)
2+
// SPDX-License-Identifier: LGPL-3.0-only
3+
4+
package state
5+
6+
import (
7+
"sync"
8+
9+
"github.com/ChainSafe/gossamer/lib/common"
10+
"github.com/ChainSafe/gossamer/lib/trie"
11+
)
12+
13+
type tries struct {
14+
rootToTrie map[common.Hash]*trie.Trie
15+
mapMutex sync.RWMutex
16+
}
17+
18+
func newTries(t *trie.Trie) *tries {
19+
return &tries{
20+
rootToTrie: map[common.Hash]*trie.Trie{
21+
t.MustHash(): t,
22+
},
23+
}
24+
}
25+
26+
// softSet sets the given trie at the given root hash
27+
// in the memory map only if it is not already set.
28+
func (t *tries) softSet(root common.Hash, trie *trie.Trie) {
29+
t.mapMutex.Lock()
30+
defer t.mapMutex.Unlock()
31+
32+
_, has := t.rootToTrie[root]
33+
if has {
34+
return
35+
}
36+
37+
t.rootToTrie[root] = trie
38+
}
39+
40+
func (t *tries) delete(root common.Hash) {
41+
t.mapMutex.Lock()
42+
defer t.mapMutex.Unlock()
43+
delete(t.rootToTrie, root)
44+
}
45+
46+
// get retrieves the trie corresponding to the root hash given
47+
// from the in-memory thread safe map.
48+
func (t *tries) get(root common.Hash) (tr *trie.Trie) {
49+
t.mapMutex.RLock()
50+
defer t.mapMutex.RUnlock()
51+
return t.rootToTrie[root]
52+
}
53+
54+
// len returns the current numbers of tries
55+
// stored in the in-memory map.
56+
func (t *tries) len() int {
57+
t.mapMutex.RLock()
58+
defer t.mapMutex.RUnlock()
59+
return len(t.rootToTrie)
60+
}

0 commit comments

Comments
 (0)