Skip to content

Integrate state diffs into KV/State() #15439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: state-diff
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions beacon-chain/db/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ go_library(
"migration_state_validators.go",
"schema.go",
"state.go",
"state_diff.go",
"state_diff_cache.go",
"state_diff_helpers.go",
"state_summary.go",
"state_summary_cache.go",
"utils.go",
Expand All @@ -44,13 +47,15 @@ go_library(
"//config/fieldparams:go_default_library",
"//config/params:go_default_library",
"//consensus-types/blocks:go_default_library",
"//consensus-types/hdiff:go_default_library",
"//consensus-types/interfaces:go_default_library",
"//consensus-types/light-client:go_default_library",
"//consensus-types/primitives:go_default_library",
"//container/slice:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz/detect:go_default_library",
"//io/file:go_default_library",
"//math:go_default_library",
"//monitoring/progress:go_default_library",
"//monitoring/tracing:go_default_library",
"//monitoring/tracing/trace:go_default_library",
Expand Down Expand Up @@ -94,6 +99,7 @@ go_test(
"migration_archived_index_test.go",
"migration_block_slot_index_test.go",
"migration_state_validators_test.go",
"state_diff_test.go",
"state_summary_test.go",
"state_test.go",
"utils_test.go",
Expand All @@ -116,6 +122,7 @@ go_test(
"//consensus-types/light-client:go_default_library",
"//consensus-types/primitives:go_default_library",
"//encoding/bytesutil:go_default_library",
"//math:go_default_library",
"//proto/dbval:go_default_library",
"//proto/engine/v1:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
Expand Down
11 changes: 11 additions & 0 deletions beacon-chain/db/kv/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ type Store struct {
blockCache *ristretto.Cache[string, interfaces.ReadOnlySignedBeaconBlock]
validatorEntryCache *ristretto.Cache[[]byte, *ethpb.Validator]
stateSummaryCache *stateSummaryCache
stateDiffCache *stateDiffCache
ctx context.Context
}

Expand All @@ -112,6 +113,7 @@ var Buckets = [][]byte{
lightClientUpdatesBucket,
lightClientBootstrapBucket,
lightClientSyncCommitteeBucket,
stateDiffBucket,
// Indices buckets.
blockSlotIndicesBucket,
stateSlotIndicesBucket,
Expand Down Expand Up @@ -182,6 +184,7 @@ func NewKVStore(ctx context.Context, dirPath string, opts ...KVStoreOption) (*St
blockCache: blockCache,
validatorEntryCache: validatorCache,
stateSummaryCache: newStateSummaryCache(),
stateDiffCache: nil,
ctx: ctx,
}
for _, o := range opts {
Expand All @@ -200,6 +203,14 @@ func NewKVStore(ctx context.Context, dirPath string, opts ...KVStoreOption) (*St
return nil, err
}

if features.Get().EnableStateDiff {
sdCache, err := newStateDiffCache(kv)
if err != nil {
return nil, err
}
kv.stateDiffCache = sdCache
}

return kv, nil
}

Expand Down
1 change: 1 addition & 0 deletions beacon-chain/db/kv/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var (
stateValidatorsBucket = []byte("state-validators")
feeRecipientBucket = []byte("fee-recipient")
registrationBucket = []byte("registration")
stateDiffBucket = []byte("state-diff")

// Light Client Updates Bucket
lightClientUpdatesBucket = []byte("light-client-updates")
Expand Down
33 changes: 33 additions & 0 deletions beacon-chain/db/kv/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,39 @@ func (s *Store) State(ctx context.Context, blockRoot [32]byte) (state.BeaconStat
ctx, span := trace.StartSpan(ctx, "BeaconDB.State")
defer span.End()
startTime := time.Now()

// If state diff is enabled, we get the state from the state-diff db.
if features.Get().EnableStateDiff {
var slot primitives.Slot

stateSummary, err := s.StateSummary(ctx, blockRoot)
if err != nil {
return nil, err
}
if stateSummary == nil {
blk, err := s.Block(ctx, blockRoot)
if err != nil {
return nil, err
}
if blk == nil || blk.IsNil() {
return nil, errors.New("neither state summary nor block found")
}
slot = blk.Block().Slot()
} else {
slot = stateSummary.Slot
}

st, err := s.stateByDiff(ctx, slot)
if err != nil {
return nil, err
}
if st == nil || st.IsNil() {
return nil, errors.New("state not found")
}
stateReadingTime.Observe(float64(time.Since(startTime).Milliseconds()))
return st, nil
}

enc, err := s.stateBytes(ctx, blockRoot)
if err != nil {
return nil, err
Expand Down
226 changes: 226 additions & 0 deletions beacon-chain/db/kv/state_diff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package kv

import (
"context"

"github.com/OffchainLabs/prysm/v6/beacon-chain/state"
"github.com/OffchainLabs/prysm/v6/config/params"
"github.com/OffchainLabs/prysm/v6/consensus-types/hdiff"
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
"github.com/OffchainLabs/prysm/v6/monitoring/tracing/trace"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)

/*
We use a level-based approach to save state diffs. The levels are 0-6, where each level corresponds to an exponent of 2 (exponents[lvl]).
The data at level 0 is saved every 2**exponent[0] slots and always contains a full state snapshot that is used as a base for the delta saved at other levels.
*/

// saveStateByDiff takes a state and decides between saving a full state snapshot or a diff.
func (s *Store) saveStateByDiff(ctx context.Context, st state.ReadOnlyBeaconState) error {
_, span := trace.StartSpan(ctx, "BeaconDB.saveStateByDiff")
defer span.End()

if st == nil {
return errors.New("state is nil")
}

slot := st.Slot()
offset := s.getOffset()
if uint64(slot) < offset {
return ErrSlotBeforeOffset
}

// Find the level to save the state.
lvl := computeLevel(offset, slot)
if lvl == -1 {
return nil
}

// Save full state if level is 0.
if lvl == 0 {
return s.saveFullSnapshot(lvl, st)
}

// Get anchor state to compute the diff from.
anchorState, err := s.getAnchorState(offset, lvl, slot)
if err != nil {
return err
}

err = s.saveHdiff(lvl, anchorState, st)
if err != nil {
return err
}

return nil
}

// stateByDiff retrieves the full state for a given slot.
func (s *Store) stateByDiff(ctx context.Context, slot primitives.Slot) (state.BeaconState, error) {
offset := s.getOffset()
if uint64(slot) < offset {
return nil, ErrSlotBeforeOffset
}

snapshot, diffChain, err := s.getBaseAndDiffChain(offset, slot)
if err != nil {
return nil, err
}

for _, diff := range diffChain {
snapshot, err = hdiff.ApplyDiff(ctx, snapshot, diff)
if err != nil {
return nil, err
}
}

return snapshot, nil
}

// SaveHdiff computes the diff between the anchor state and the current state and saves it to the database.
func (s *Store) saveHdiff(lvl int, anchor, st state.ReadOnlyBeaconState) error {
slot := uint64(st.Slot())
key := makeKey(lvl, slot)

diff, err := hdiff.Diff(anchor, st)
if err != nil {
return err
}

err = s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(stateDiffBucket)
if bucket == nil {
return bolt.ErrBucketNotFound
}
buf := append(key, "_s"...)
if err := bucket.Put(buf, diff.StateDiff); err != nil {
return err
}
buf = append(key, "_v"...)
if err := bucket.Put(buf, diff.ValidatorDiffs); err != nil {
return err
}
buf = append(key, "_b"...)
if err := bucket.Put(buf, diff.BalancesDiff); err != nil {
return err
}
return nil
})
if err != nil {
return err
}

// Save the full state to the cache (if not the last level).
if lvl != len(params.StateHierarchyExponents())-1 {
err = s.stateDiffCache.setAnchor(lvl, st)
if err != nil {
return err
}
}

return nil
}

// SaveFullSnapshot saves the full level 0 state snapshot to the database.
func (s *Store) saveFullSnapshot(lvl int, st state.ReadOnlyBeaconState) error {
slot := uint64(st.Slot())
key := makeKey(lvl, slot)
stateBytes, err := st.MarshalSSZ()
if err != nil {
return err
}
// add version key to value
enc, err := addKey(st.Version(), stateBytes)
if err != nil {
return err
}

err = s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(stateDiffBucket)
if bucket == nil {
return bolt.ErrBucketNotFound
}

if err := bucket.Put(key, enc); err != nil {
return err
}

return nil
})
if err != nil {
return err
}
// Save the full state to the cache, and invalidate other levels.
s.stateDiffCache.clearAnchors()
err = s.stateDiffCache.setAnchor(lvl, st)
if err != nil {
return err
}

return nil
}

func (s *Store) getDiff(lvl int, slot uint64) (hdiff.HdiffBytes, error) {
key := makeKey(lvl, slot)
var stateDiff []byte
var validatorDiff []byte
var balancesDiff []byte

err := s.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(stateDiffBucket)
if bucket == nil {
return bolt.ErrBucketNotFound
}
buf := append(key, "_s"...)
stateDiff = bucket.Get(buf)
if stateDiff == nil {
return errors.New("state diff not found")
}
buf = append(key, "_v"...)
validatorDiff = bucket.Get(buf)
if validatorDiff == nil {
return errors.New("validator diff not found")
}
buf = append(key, "_b"...)
balancesDiff = bucket.Get(buf)
if balancesDiff == nil {
return errors.New("balances diff not found")
}
return nil
})

if err != nil {
return hdiff.HdiffBytes{}, err
}

return hdiff.HdiffBytes{
StateDiff: stateDiff,
ValidatorDiffs: validatorDiff,
BalancesDiff: balancesDiff,
}, nil
}

func (s *Store) getFullSnapshot(lvl int, slot uint64) (state.BeaconState, error) {
key := makeKey(lvl, slot)
var enc []byte

err := s.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(stateDiffBucket)
if bucket == nil {
return bolt.ErrBucketNotFound
}
enc = bucket.Get(key)
if enc == nil {
return errors.New("state not found")
}
return nil
})

if err != nil {
return nil, err
}

return s.decodeStateSnapshot(enc)
}
Loading
Loading