Skip to content

Commit be69926

Browse files
committed
Add cached MapReduce operation
``` === RUN TestMapReduceSimple mapreduce_test.go:69: tree size: 600000, cache size: 1000 mapreduce_test.go:76: fresh readCount: 36891 mapreduce_test.go:83: fresh re-readCount: 0 mapreduce_test.go:103: new key readCount: 38 mapreduce_test.go:114: repeat readCount: 0 mapreduce_test.go:121: repeat readCount: 0 mapreduce_test.go:141: new two keys readCount: 76 --- PASS: TestMapReduceSimple (6.89s) ``` Signed-off-by: Jakub Sztandera <[email protected]>
1 parent 9f2472e commit be69926

File tree

3 files changed

+354
-0
lines changed

3 files changed

+354
-0
lines changed

hamt_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,35 @@ func TestSha256(t *testing.T) {
531531
}))
532532
}
533533

534+
func TestForEach(t *testing.T) {
535+
ctx := context.Background()
536+
cs := cbor.NewCborStore(newMockBlocks())
537+
begn, err := NewNode(cs)
538+
require.NoError(t, err)
539+
540+
golden := make(map[string]*CborByteArray)
541+
for range 1000 {
542+
k := randKey()
543+
v := randValue()
544+
golden[k] = v
545+
err = begn.Set(ctx, k, v)
546+
require.NoError(t, err)
547+
}
548+
err = begn.Flush(ctx)
549+
require.NoError(t, err)
550+
err = begn.ForEach(ctx, func(k string, val *cbg.Deferred) error {
551+
v, ok := golden[k]
552+
if !ok {
553+
t.Fatalf("unexpected key in ForEach: %s", k)
554+
}
555+
var val2 CborByteArray
556+
val2.UnmarshalCBOR(bytes.NewReader(val.Raw))
557+
require.Equal(t, []byte(*v), []byte(val2))
558+
return nil
559+
})
560+
require.NoError(t, err)
561+
}
562+
534563
func testBasic(t *testing.T, options ...Option) {
535564
ctx := context.Background()
536565
cs := cbor.NewCborStore(newMockBlocks())

mapreduce.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package hamt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"fmt"
8+
"math/rand/v2"
9+
"sync"
10+
11+
cid "github.com/ipfs/go-cid"
12+
cbg "github.com/whyrusleeping/cbor-gen"
13+
)
14+
15+
// cacheEntry is a cached intermediate result together with the weight the
// cache uses when deciding what to keep.
type cacheEntry[T any] struct {
	// value is the cached result.
	value T
	// weight is compared against a minimum in Add and biases eviction;
	// mapReduceInternal sets it to 1 plus the weights of all child shards.
	weight int
}
19+
type weigthted2RCache[T any] struct {
20+
lk sync.Mutex
21+
cache map[cid.Cid]cacheEntry[T]
22+
cacheSize int
23+
}
24+
25+
func newWeighted2RCache[T any](cacheSize int) *weigthted2RCache[T] {
26+
return &weigthted2RCache[T]{
27+
cache: make(map[cid.Cid]cacheEntry[T]),
28+
cacheSize: cacheSize,
29+
}
30+
}
31+
func (c *weigthted2RCache[T]) Get(k cid.Cid) (cacheEntry[T], bool) {
32+
c.lk.Lock()
33+
defer c.lk.Unlock()
34+
v, ok := c.cache[k]
35+
if !ok {
36+
return v, false
37+
}
38+
return v, true
39+
}
40+
41+
func (c *weigthted2RCache[T]) Add(k cid.Cid, v cacheEntry[T]) {
42+
// dont cache nodes that require less than 6 reads
43+
if v.weight <= 5 {
44+
return
45+
}
46+
c.lk.Lock()
47+
defer c.lk.Unlock()
48+
if _, ok := c.cache[k]; ok {
49+
c.cache[k] = v
50+
return
51+
}
52+
53+
c.cache[k] = v
54+
if len(c.cache) > c.cacheSize {
55+
// pick two random entris using map iteration
56+
// work well for cacheSize > 8
57+
var k1, k2 cid.Cid
58+
var v1, v2 cacheEntry[T]
59+
for k, v := range c.cache {
60+
k1 = k
61+
v1 = v
62+
break
63+
}
64+
for k, v := range c.cache {
65+
k2 = k
66+
v2 = v
67+
break
68+
}
69+
// pick random one based on weight
70+
r1 := rand.Float64()
71+
if r1 < float64(v1.weight)/float64(v1.weight+v2.weight) {
72+
delete(c.cache, k2)
73+
} else {
74+
delete(c.cache, k1)
75+
}
76+
}
77+
}
78+
79+
// CachedMapReduce is a map reduce implementation that caches intermediate results
80+
// to reduce the number of reads from the underlying store.
81+
type CachedMapReduce[T any, PT interface {
82+
*T
83+
cbg.CBORUnmarshaler
84+
}, U any] struct {
85+
mapper func(string, T) (U, error)
86+
reducer func([]U) (U, error)
87+
cache *weigthted2RCache[U]
88+
}
89+
90+
// NewCachedMapReduce creates a new CachedMapReduce instance.
91+
// The mapper translates a key-value pair stored in the HAMT into a chosen U value.
92+
// The reducer reduces the U values into a single U value.
93+
// The cacheSize parameter specifies the maximum number of intermediate results to cache.
94+
func NewCachedMapReduce[T any, PT interface {
95+
*T
96+
cbg.CBORUnmarshaler
97+
}, U any](
98+
mapper func(string, T) (U, error),
99+
reducer func([]U) (U, error),
100+
cacheSize int,
101+
) (*CachedMapReduce[T, PT, U], error) {
102+
return &CachedMapReduce[T, PT, U]{
103+
mapper: mapper,
104+
reducer: reducer,
105+
cache: newWeighted2RCache[U](cacheSize),
106+
}, nil
107+
}
108+
109+
// MapReduce applies the map reduce function to the given root node.
110+
func (cmr *CachedMapReduce[T, PT, U]) MapReduce(ctx context.Context, root *Node) (U, error) {
111+
var res U
112+
if root == nil {
113+
return res, errors.New("root is nil")
114+
}
115+
ce, err := cmr.mapReduceInternal(ctx, root)
116+
if err != nil {
117+
return res, err
118+
}
119+
return ce.value, nil
120+
}
121+
122+
func (cmr *CachedMapReduce[T, PT, U]) mapReduceInternal(ctx context.Context, node *Node) (cacheEntry[U], error) {
123+
var res cacheEntry[U]
124+
125+
Us := make([]U, 0)
126+
weight := 1
127+
for _, p := range node.Pointers {
128+
if p.cache != nil && p.dirty {
129+
return res, errors.New("cannot iterate over a dirty node")
130+
}
131+
if p.isShard() {
132+
if p.cache != nil && p.dirty {
133+
return res, errors.New("cannot iterate over a dirty node")
134+
}
135+
linkU, ok := cmr.cache.Get(p.Link)
136+
if !ok {
137+
chnd, err := p.loadChild(ctx, node.store, node.bitWidth, node.hash)
138+
if err != nil {
139+
return res, fmt.Errorf("loading child: %w", err)
140+
}
141+
142+
linkU, err = cmr.mapReduceInternal(ctx, chnd)
143+
if err != nil {
144+
return res, fmt.Errorf("map reduce child: %w", err)
145+
}
146+
cmr.cache.Add(p.Link, linkU)
147+
}
148+
Us = append(Us, linkU.value)
149+
weight += linkU.weight
150+
} else {
151+
for _, v := range p.KVs {
152+
var pt = PT(new(T))
153+
err := pt.UnmarshalCBOR(bytes.NewReader(v.Value.Raw))
154+
if err != nil {
155+
return res, fmt.Errorf("failed to unmarshal value: %w", err)
156+
}
157+
u, err := cmr.mapper(string(v.Key), *pt)
158+
if err != nil {
159+
return res, fmt.Errorf("failed to map value: %w", err)
160+
}
161+
162+
Us = append(Us, u)
163+
}
164+
}
165+
}
166+
167+
resU, err := cmr.reducer(Us)
168+
if err != nil {
169+
return res, fmt.Errorf("failed to reduce self values: %w", err)
170+
}
171+
172+
return cacheEntry[U]{
173+
value: resU,
174+
weight: weight,
175+
}, nil
176+
}

mapreduce_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package hamt
2+
3+
import (
4+
"context"
5+
"slices"
6+
"strings"
7+
"testing"
8+
9+
cid "github.com/ipfs/go-cid"
10+
cbor "github.com/ipfs/go-ipld-cbor"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
type readCounterStore struct {
15+
cbor.IpldStore
16+
readCount int
17+
}
18+
19+
func (rcs *readCounterStore) Get(ctx context.Context, c cid.Cid, out any) error {
20+
rcs.readCount++
21+
return rcs.IpldStore.Get(ctx, c, out)
22+
}
23+
24+
func TestMapReduceSimple(t *testing.T) {
25+
ctx := context.Background()
26+
opts := []Option{UseTreeBitWidth(5)}
27+
cs := &readCounterStore{cbor.NewCborStore(newMockBlocks()), 0}
28+
begn, err := NewNode(cs, opts...)
29+
require.NoError(t, err)
30+
31+
golden := make(map[string]string)
32+
N := 50000
33+
for range N {
34+
k := randKey()
35+
v := randValue()
36+
golden[k] = string([]byte(*v))
37+
begn.Set(ctx, k, v)
38+
}
39+
40+
reLoadNode := func(node *Node) *Node {
41+
c, err := node.Write(ctx)
42+
require.NoError(t, err)
43+
res, err := LoadNode(ctx, cs, c, opts...)
44+
require.NoError(t, err)
45+
return res
46+
}
47+
begn = reLoadNode(begn)
48+
49+
type kv struct {
50+
k string
51+
v string
52+
}
53+
54+
mapper := func(k string, v CborByteArray) ([]kv, error) {
55+
return []kv{{k, string([]byte(v))}}, nil
56+
}
57+
reducer := func(kvs [][]kv) ([]kv, error) {
58+
var kvsConcat []kv
59+
for _, kvs := range kvs {
60+
kvsConcat = append(kvsConcat, kvs...)
61+
}
62+
slices.SortFunc(kvsConcat, func(a, b kv) int {
63+
return strings.Compare(a.k, b.k)
64+
})
65+
return kvsConcat, nil
66+
}
67+
68+
cmr, err := NewCachedMapReduce(mapper, reducer, 200)
69+
t.Logf("tree size: %d, cache size: %d", N, cmr.cache.cacheSize)
70+
require.NoError(t, err)
71+
72+
cs.readCount = 0
73+
res, err := cmr.MapReduce(ctx, begn)
74+
require.NoError(t, err)
75+
require.Equal(t, len(golden), len(res))
76+
t.Logf("fresh readCount: %d", cs.readCount)
77+
78+
begn = reLoadNode(begn)
79+
cs.readCount = 0
80+
res, err = cmr.MapReduce(ctx, begn)
81+
require.NoError(t, err)
82+
t.Logf("fresh re-readCount: %d", cs.readCount)
83+
require.Less(t, cs.readCount, 200)
84+
85+
verifyConsistency := func(res []kv) {
86+
t.Helper()
87+
mappedRes := make(map[string]string)
88+
for _, kv := range res {
89+
mappedRes[kv.k] = kv.v
90+
}
91+
require.Equal(t, len(golden), len(mappedRes))
92+
require.Equal(t, golden, mappedRes)
93+
}
94+
verifyConsistency(res)
95+
96+
{
97+
// add new key
98+
k := randKey()
99+
v := randValue()
100+
golden[k] = string([]byte(*v))
101+
begn.Set(ctx, k, v)
102+
103+
begn = reLoadNode(begn)
104+
}
105+
106+
cs.readCount = 0
107+
res, err = cmr.MapReduce(ctx, begn)
108+
require.NoError(t, err)
109+
verifyConsistency(res)
110+
t.Logf("new key readCount: %d", cs.readCount)
111+
require.Less(t, cs.readCount, 200)
112+
113+
begn = reLoadNode(begn)
114+
cs.readCount = 0
115+
res, err = cmr.MapReduce(ctx, begn)
116+
require.NoError(t, err)
117+
verifyConsistency(res)
118+
t.Logf("repeat readCount: %d", cs.readCount)
119+
require.Less(t, cs.readCount, 200)
120+
121+
begn = reLoadNode(begn)
122+
cs.readCount = 0
123+
res, err = cmr.MapReduce(ctx, begn)
124+
require.NoError(t, err)
125+
verifyConsistency(res)
126+
t.Logf("repeat readCount: %d", cs.readCount)
127+
require.Less(t, cs.readCount, 200)
128+
129+
{
130+
// add new key
131+
k := randKey()
132+
v := randValue()
133+
golden[k] = string([]byte(*v))
134+
begn.Set(ctx, k, v)
135+
k = randKey()
136+
v = randValue()
137+
golden[k] = string([]byte(*v))
138+
begn.Set(ctx, k, v)
139+
140+
begn = reLoadNode(begn)
141+
}
142+
143+
cs.readCount = 0
144+
res, err = cmr.MapReduce(ctx, begn)
145+
require.NoError(t, err)
146+
verifyConsistency(res)
147+
t.Logf("new two keys readCount: %d", cs.readCount)
148+
require.Less(t, cs.readCount, 300)
149+
}

0 commit comments

Comments
 (0)