diff --git a/common/mclock/alarm.go b/common/mclock/alarm.go new file mode 100644 index 000000000..e83810a6a --- /dev/null +++ b/common/mclock/alarm.go @@ -0,0 +1,106 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mclock + +import ( + "time" +) + +// Alarm sends timed notifications on a channel. This is very similar to a regular timer, +// but is easier to use in code that needs to re-schedule the same timer over and over. +// +// When scheduling an Alarm, the channel returned by C() will receive a value no later +// than the scheduled time. An Alarm can be reused after it has fired and can also be +// canceled by calling Stop. +type Alarm struct { + ch chan struct{} + clock Clock + timer Timer + deadline AbsTime +} + +// NewAlarm creates an Alarm. +func NewAlarm(clock Clock) *Alarm { + if clock == nil { + panic("nil clock") + } + return &Alarm{ + ch: make(chan struct{}, 1), + clock: clock, + } +} + +// C returns the alarm notification channel. This channel remains identical for +// the entire lifetime of the alarm, and is never closed. +func (e *Alarm) C() <-chan struct{} { + return e.ch +} + +// Stop cancels the alarm and drains the channel. +// This method is not safe for concurrent use. +func (e *Alarm) Stop() { + // Clear timer. + if e.timer != nil { + e.timer.Stop() + } + e.deadline = 0 + + // Drain the channel. + select { + case <-e.ch: + default: + } +} + +// Schedule sets the alarm to fire no later than the given time. If the alarm was already +// scheduled but has not fired yet, it may fire earlier than the newly-scheduled time. +func (e *Alarm) Schedule(time AbsTime) { + now := e.clock.Now() + e.schedule(now, time) +} + +func (e *Alarm) schedule(now, newDeadline AbsTime) { + if e.timer != nil { + if e.deadline > now && e.deadline <= newDeadline { + // Here, the current timer can be reused because it is already scheduled to + // occur earlier than the new deadline. + // + // The e.deadline > now part of the condition is important. If the old + // deadline lies in the past, we assume the timer has already fired and needs + // to be rescheduled. + return + } + e.timer.Stop() + } + + // Set the timer. + d := time.Duration(0) + if newDeadline < now { + newDeadline = now + } else { + d = newDeadline.Sub(now) + } + e.timer = e.clock.AfterFunc(d, e.send) + e.deadline = newDeadline +} + +func (e *Alarm) send() { + select { + case e.ch <- struct{}{}: + default: + } +} diff --git a/common/mclock/alarm_test.go b/common/mclock/alarm_test.go new file mode 100644 index 000000000..d2ad9913f --- /dev/null +++ b/common/mclock/alarm_test.go @@ -0,0 +1,116 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. 
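The Alarm added above is easiest to read from its intended call pattern: re-arm it on every pass of a loop, then wait on C(). Below is a minimal sketch of that pattern, assuming the mclock package as it exists after this change; processNext is a hypothetical stand-in for the caller's own scheduling logic, not part of the diff.

```go
package main

import (
	"fmt"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
)

// processNext is a hypothetical stand-in for the caller's work; it returns the
// deadline of the next pending item.
func processNext(now mclock.AbsTime) mclock.AbsTime {
	return now.Add(100 * time.Millisecond)
}

func main() {
	clock := mclock.System{}
	alarm := mclock.NewAlarm(clock)

	for i := 0; i < 3; i++ {
		// Re-arm the alarm on every pass. Calling Schedule repeatedly is cheap:
		// the underlying timer is reused when it is already set to fire no later
		// than the new deadline.
		alarm.Schedule(processNext(clock.Now()))

		<-alarm.C() // fires no later than the scheduled deadline
		fmt.Println("alarm fired", i)
	}
}
```

This is exactly the shape the Table main loop below uses for revalidation scheduling, with a Simulated clock substituted in tests.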
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mclock + +import "testing" + +// This test checks basic functionality of Alarm. +func TestAlarm(t *testing.T) { + clk := new(Simulated) + clk.Run(20) + a := NewAlarm(clk) + + a.Schedule(clk.Now() + 10) + if recv(a.C()) { + t.Fatal("Alarm fired before scheduled deadline") + } + if ntimers := clk.ActiveTimers(); ntimers != 1 { + t.Fatal("clock has", ntimers, "active timers, want", 1) + } + clk.Run(5) + if recv(a.C()) { + t.Fatal("Alarm fired too early") + } + + clk.Run(5) + if !recv(a.C()) { + t.Fatal("Alarm did not fire") + } + if recv(a.C()) { + t.Fatal("Alarm fired twice") + } + if ntimers := clk.ActiveTimers(); ntimers != 0 { + t.Fatal("clock has", ntimers, "active timers, want", 0) + } + + a.Schedule(clk.Now() + 5) + if recv(a.C()) { + t.Fatal("Alarm fired before scheduled deadline when scheduling the second event") + } + + clk.Run(5) + if !recv(a.C()) { + t.Fatal("Alarm did not fire when scheduling the second event") + } + if recv(a.C()) { + t.Fatal("Alarm fired twice when scheduling the second event") + } +} + +// This test checks that scheduling an Alarm to an earlier time than the +// one already scheduled works properly. +func TestAlarmScheduleEarlier(t *testing.T) { + clk := new(Simulated) + clk.Run(20) + a := NewAlarm(clk) + + a.Schedule(clk.Now() + 50) + clk.Run(5) + a.Schedule(clk.Now() + 1) + clk.Run(3) + if !recv(a.C()) { + t.Fatal("Alarm did not fire") + } +} + +// This test checks that scheduling an Alarm to a later time than the +// one already scheduled works properly. +func TestAlarmScheduleLater(t *testing.T) { + clk := new(Simulated) + clk.Run(20) + a := NewAlarm(clk) + + a.Schedule(clk.Now() + 50) + clk.Run(5) + a.Schedule(clk.Now() + 100) + clk.Run(50) + if !recv(a.C()) { + t.Fatal("Alarm did not fire") + } +} + +// This test checks that scheduling an Alarm in the past makes it fire immediately. 
+func TestAlarmNegative(t *testing.T) { + clk := new(Simulated) + clk.Run(50) + a := NewAlarm(clk) + + a.Schedule(-1) + clk.Run(1) // needed to process timers + if !recv(a.C()) { + t.Fatal("Alarm did not fire for negative time") + } +} + +func recv(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} diff --git a/p2p/discover/common.go b/p2p/discover/common.go index 0872b1fa2..617b670a4 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -18,8 +18,13 @@ package discover import ( "crypto/ecdsa" + crand "crypto/rand" + "encoding/binary" "fmt" + "math/rand" "net" + "sync" + "time" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core/forkid" @@ -72,15 +77,28 @@ type Config struct { // These settings are optional: NetRestrict *netutil.Netlist // list of allowed IP networks - Bootnodes []*enode.Node // list of bootstrap nodes Unhandled chan<- ReadPacket // unhandled packets are sent on this channel Log log.Logger // if set, log messages go here ValidSchemes enr.IdentityScheme // allowed identity schemes Clock mclock.Clock FilterFunction NodeFilterFunc // function for filtering ENR entries + + // Node table configuration: + Bootnodes []*enode.Node // list of bootstrap nodes + PingInterval time.Duration // speed of node liveness check + RefreshInterval time.Duration // used in bucket refresh } func (cfg Config) withDefaults() Config { + // Node table configuration: + if cfg.PingInterval == 0 { + cfg.PingInterval = 3 * time.Second + } + if cfg.RefreshInterval == 0 { + cfg.RefreshInterval = 30 * time.Minute + } + + // Debug/test settings: if cfg.Log == nil { cfg.Log = log.Root() } @@ -105,9 +123,43 @@ type ReadPacket struct { Addr *net.UDPAddr } -func min(x, y int) int { - if x > y { - return y - } - return x +type randomSource interface { + Intn(int) int + Int63n(int64) int64 + Shuffle(int, func(int, int)) +} + +// reseedingRandom is a random number generator that tracks when it was last re-seeded. +type reseedingRandom struct { + mu sync.Mutex + cur *rand.Rand +} + +func (r *reseedingRandom) seed() { + var b [8]byte + crand.Read(b[:]) + seed := binary.BigEndian.Uint64(b[:]) + new := rand.New(rand.NewSource(int64(seed))) + + r.mu.Lock() + r.cur = new + r.mu.Unlock() +} + +func (r *reseedingRandom) Intn(n int) int { + r.mu.Lock() + defer r.mu.Unlock() + return r.cur.Intn(n) +} + +func (r *reseedingRandom) Int63n(n int64) int64 { + r.mu.Lock() + defer r.mu.Unlock() + return r.cur.Int63n(n) +} + +func (r *reseedingRandom) Shuffle(n int, swap func(i, j int)) { + r.mu.Lock() + defer r.mu.Unlock() + r.cur.Shuffle(n, swap) } diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index 9ab4a71ce..6c725ae57 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -18,6 +18,7 @@ package discover import ( "context" + "errors" "time" "github.com/ethereum/go-ethereum/p2p/enode" @@ -28,7 +29,7 @@ import ( // not need to be an actual node identifier. type lookup struct { tab *Table - queryfunc func(*node) ([]*node, error) + queryfunc queryFunc replyCh chan []*node cancelCh <-chan struct{} asked, seen map[enode.ID]bool @@ -139,32 +140,13 @@ func (it *lookup) slowdown() { } func (it *lookup) query(n *node, reply chan<- []*node) { - fails := it.tab.db.FindFails(n.ID(), n.IP()) r, err := it.queryfunc(n) - if err == errClosed { - // Avoid recording failures on shutdown. 
- reply <- nil - return - } else if len(r) == 0 { - fails++ - it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) - // Remove the node from the local table if it fails to return anything useful too - // many times, but only if there are enough other nodes in the bucket. - dropped := false - if fails >= maxFindnodeFailures && it.tab.bucketLen(n.ID()) >= bucketSize/2 { - dropped = true - it.tab.delete(n) + if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. + success := len(r) > 0 + it.tab.trackRequest(n, success, r) + if err != nil { + it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "err", err) } - it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "failcount", fails, "dropped", dropped, "err", err) - } else if fails > 0 { - // Reset failure counter because it counts _consecutive_ failures. - it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0) - } - - // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll - // just remove those again during revalidation. - for _, n := range r { - it.tab.addSeenNode(n) } reply <- r } diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 9ffe101cc..47788248f 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -29,12 +29,23 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) +type BucketNode struct { + Node *enode.Node `json:"node"` + AddedToTable time.Time `json:"addedToTable"` + AddedToBucket time.Time `json:"addedToBucket"` + Checks int `json:"checks"` + Live bool `json:"live"` +} + // node represents a host on the network. // The fields of Node may not be modified. type node struct { - enode.Node - addedAt time.Time // time when the node was added to the table - livenessChecks uint // how often liveness was checked + *enode.Node + revalList *revalidationList + addedToTable time.Time // first time node was added to bucket or replacement list + addedToBucket time.Time // time it was added in the actual bucket + livenessChecks uint // how often liveness was checked + isValidatedLive bool // true if existence of node is considered validated right now } type encPubkey [64]byte @@ -65,7 +76,7 @@ func (e encPubkey) id() enode.ID { } func wrapNode(n *enode.Node) *node { - return &node{Node: *n} + return &node{Node: n} } func wrapNodes(ns []*enode.Node) []*node { @@ -77,7 +88,7 @@ func wrapNodes(ns []*enode.Node) []*node { } func unwrapNode(n *node) *enode.Node { - return &n.Node + return n.Node } func unwrapNodes(ns []*node) []*enode.Node { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 1b3c1cac9..b5db78176 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -23,16 +23,15 @@ package discover import ( - crand "crypto/rand" - "encoding/binary" "fmt" - mrand "math/rand" "net" + "slices" "sort" "sync" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/enode" @@ -55,7 +54,6 @@ const ( bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 tableIPLimit, tableSubnet = 10, 24 - refreshInterval = 30 * time.Minute revalidateInterval = 10 * time.Second copyNodesInterval = 30 * time.Second seedMinTableTime = 5 * time.Minute @@ -67,20 +65,28 @@ const ( // itself up-to-date by verifying the liveness of neighbors and requesting their node // records when announcements of a new record version are received. 
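The new PingInterval and RefreshInterval fields in Config (common.go above) are mainly intended for tests and small private networks; anything left at its zero value is filled in by withDefaults (3s ping interval, 30m refresh interval). A sketch of a test-style configuration, assuming the surrounding discover package; the helper name and values are illustrative only:

```go
// newTestConfig is a hypothetical helper showing how the new table settings
// compose with the existing optional fields. Fields left at their zero value
// are filled by withDefaults().
func newTestConfig(clock mclock.Clock) Config {
	return Config{
		Clock:           clock,
		PingInterval:    500 * time.Millisecond, // faster liveness checks than the 3s default
		RefreshInterval: time.Minute,            // refresh buckets more often than the 30m default
	}
}
```

With a simulated clock, newTable(transport, db, newTestConfig(new(mclock.Simulated))) runs the revalidation loop entirely on simulated time, which is what the reworked tests later in this diff rely on.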
type Table struct { - mutex sync.Mutex // protects buckets, bucket content, nursery, rand - buckets [nBuckets]*bucket // index of known nodes by distance - nursery []*node // bootstrap nodes - rand *mrand.Rand // source of randomness, periodically reseeded - ips netutil.DistinctNetSet - - log log.Logger - db *enode.DB // database of known nodes - net transport - refreshReq chan chan struct{} - initDone chan struct{} - closeReq chan struct{} - closed chan struct{} - workerPoolTask chan struct{} + mutex sync.Mutex // protects buckets, bucket content, nursery, rand + buckets [nBuckets]*bucket // index of known nodes by distance + nursery []*node // bootstrap nodes + rand reseedingRandom // source of randomness, periodically reseeded + ips netutil.DistinctNetSet + revalidation tableRevalidation + + log log.Logger + db *enode.DB // database of known nodes + net transport + cfg Config + + // loop channels + refreshReq chan chan struct{} + revalResponseCh chan revalidationResponse + addNodeCh chan addNodeOp + addNodeHandled chan bool + trackRequestCh chan trackRequestOp + initDone chan struct{} + closeReq chan struct{} + closed chan struct{} + workerPoolTask chan struct{} nodeAddedHook func(*bucket, *node) nodeRemovedHook func(*bucket, *node) @@ -105,22 +111,36 @@ type bucket struct { index int } -func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger, filter NodeFilterFunc) (*Table, error) { +type addNodeOp struct { + node *node + isInbound bool + syncExecution bool // if true, the operation is executed synchronously, only for Testing. +} + +type trackRequestOp struct { + node *node + foundNodes []*node + success bool +} + +func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { + cfg = cfg.withDefaults() tab := &Table{ - net: t, - db: db, - refreshReq: make(chan chan struct{}), - initDone: make(chan struct{}), - closeReq: make(chan struct{}), - closed: make(chan struct{}), - workerPoolTask: make(chan struct{}, maxWorkerTask), - rand: mrand.New(mrand.NewSource(0)), - ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, - log: log, - enrFilter: filter, - } - if err := tab.setFallbackNodes(bootnodes); err != nil { - return nil, err + net: t, + db: db, + cfg: cfg, + log: cfg.Log, + refreshReq: make(chan chan struct{}), + revalResponseCh: make(chan revalidationResponse), + addNodeCh: make(chan addNodeOp), + addNodeHandled: make(chan bool), + trackRequestCh: make(chan trackRequestOp), + initDone: make(chan struct{}), + closeReq: make(chan struct{}), + closed: make(chan struct{}), + workerPoolTask: make(chan struct{}, maxWorkerTask), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, + enrFilter: cfg.FilterFunction, } for i := range tab.buckets { tab.buckets[i] = &bucket{ @@ -132,14 +152,20 @@ func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger tab.workerPoolTask <- struct{}{} } - tab.seedRand() + tab.rand.seed() + tab.revalidation.init(&cfg) + + // initial table content + if err := tab.setFallbackNodes(cfg.Bootnodes); err != nil { + return nil, err + } tab.loadSeedNodes() return tab, nil } -func newMeteredTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger, filter NodeFilterFunc) (*Table, error) { - tab, err := newTable(t, db, bootnodes, log, filter) +func newMeteredTable(t transport, db *enode.DB, cfg Config) (*Table, error) { + tab, err := newTable(t, db, cfg) if err != nil { return nil, err } @@ -158,38 +184,6 @@ func (tab *Table) self() *enode.Node { return 
tab.net.Self() } -func (tab *Table) seedRand() { - var b [8]byte - crand.Read(b[:]) - - tab.mutex.Lock() - tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:]))) - tab.mutex.Unlock() -} - -// ReadRandomNodes fills the given slice with random nodes from the table. The results -// are guaranteed to be unique for a single invocation, no node will appear twice. -func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { - if !tab.isInitDone() { - return 0 - } - tab.mutex.Lock() - defer tab.mutex.Unlock() - - var nodes []*enode.Node - for _, b := range &tab.buckets { - for _, n := range b.entries { - nodes = append(nodes, unwrapNode(n)) - } - } - // Shuffle. - for i := 0; i < len(nodes); i++ { - j := tab.rand.Intn(len(nodes)) - nodes[i], nodes[j] = nodes[j], nodes[i] - } - return copy(buf, nodes) -} - // getNode returns the node with the given ID or nil if it isn't in the table. func (tab *Table) getNode(id enode.ID) *enode.Node { tab.mutex.Lock() @@ -198,7 +192,7 @@ func (tab *Table) getNode(id enode.ID) *enode.Node { b := tab.bucket(id) for _, e := range b.entries { if e.ID() == id { - return unwrapNode(e) + return e.Node } } return nil @@ -233,12 +227,18 @@ func (tab *Table) close() { // are used to connect to the network if the table is empty and there // are no known nodes in the database. func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { + nursery := make([]*node, 0, len(nodes)) for _, n := range nodes { if err := n.ValidateComplete(); err != nil { return fmt.Errorf("bad bootstrap node %q: %v", n, err) } + if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) { + tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) + continue + } + nursery = append(nursery, wrapNode(n)) } - tab.nursery = wrapNodes(nodes) + tab.nursery = nursery return nil } @@ -262,51 +262,79 @@ func (tab *Table) refresh() <-chan struct{} { return done } +func (tab *Table) trackRequest(n *node, success bool, foundNodes []*node) { + op := trackRequestOp{n, foundNodes, success} + select { + case tab.trackRequestCh <- op: + case <-tab.closeReq: + } +} + // loop schedules runs of doRefresh, doRevalidate and copyLiveNodes. func (tab *Table) loop() { var ( - revalidate = time.NewTimer(tab.nextRevalidateTime()) - refresh = time.NewTicker(refreshInterval) - copyNodes = time.NewTicker(copyNodesInterval) - refreshDone = make(chan struct{}) // where doRefresh reports completion - revalidateDone chan struct{} // where doRevalidate reports completion - waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs + refresh = time.NewTimer(tab.nextRefreshTime()) + refreshDone = make(chan struct{}) // where doRefresh reports completion + waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs + revalTimer = mclock.NewAlarm(tab.cfg.Clock) + reseedRandTimer = time.NewTicker(10 * time.Minute) ) defer refresh.Stop() - defer revalidate.Stop() - defer copyNodes.Stop() // Start initial refresh. go tab.doRefresh(refreshDone) loop: for { + nextTime := tab.revalidation.run(tab, tab.cfg.Clock.Now()) + revalTimer.Schedule(nextTime) + select { + case <-reseedRandTimer.C: + tab.rand.seed() + + case <-revalTimer.C(): + + case r := <-tab.revalResponseCh: + tab.revalidation.handleResponse(tab, r) + + case op := <-tab.addNodeCh: + // only for testing, syncExecution is true. 
+ if op.syncExecution { + ok := tab.handleAddNode(op) + tab.addNodeHandled <- ok + } else { + select { + case <-tab.workerPoolTask: + go tab.handleAddNode(op) + default: + tab.log.Debug("Worker pool task is full, dropping node", "id", op.node.ID(), "addr", op.node.addr()) + } + } + + case op := <-tab.trackRequestCh: + tab.handleTrackRequest(op) + case <-refresh.C: - tab.seedRand() if refreshDone == nil { refreshDone = make(chan struct{}) go tab.doRefresh(refreshDone) } + case req := <-tab.refreshReq: waiting = append(waiting, req) if refreshDone == nil { refreshDone = make(chan struct{}) go tab.doRefresh(refreshDone) } + case <-refreshDone: for _, ch := range waiting { close(ch) } waiting, refreshDone = nil, nil - case <-revalidate.C: - revalidateDone = make(chan struct{}) - go tab.doRevalidate(revalidateDone) - case <-revalidateDone: - revalidate.Reset(tab.nextRevalidateTime()) - revalidateDone = nil - case <-copyNodes.C: - go tab.copyLiveNodes() + refresh.Reset(tab.nextRefreshTime()) + case <-tab.closeReq: break loop } @@ -318,9 +346,6 @@ loop: for _, ch := range waiting { close(ch) } - if revalidateDone != nil { - <-revalidateDone - } tab.closeWorkerTask() close(tab.closed) @@ -355,100 +380,14 @@ func (tab *Table) loadSeedNodes() { seeds = append(seeds, tab.nursery...) for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }} + age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) - tab.addSeenNode(seed) - } -} - -// doRevalidate checks that the last node in a random bucket is still live and replaces or -// deletes the node if it isn't. -func (tab *Table) doRevalidate(done chan<- struct{}) { - defer func() { done <- struct{}{} }() - - last, bi := tab.nodeToRevalidate() - if last == nil { - // No non-empty bucket found. - return - } - var errHandle error - // Ping the selected node and wait for a pong. - remoteSeq, err := tab.net.ping(unwrapNode(last)) - - if err != nil { - errHandle = err - } - - // Also fetch record if the node replied and returned a higher sequence number. - if last.Seq() < remoteSeq { - n, err := tab.net.RequestENR(unwrapNode(last)) - if err != nil { - tab.log.Debug("ENR request failed", "id", last.ID(), "addr", last.addr(), "err", err) - errHandle = err - } else { - if tab.enrFilter != nil { - if !tab.enrFilter(n.Record()) { - tab.log.Trace("ENR record filter out", "id", last.ID(), "addr", last.addr()) - errHandle = fmt.Errorf("filtered node") - } - } - last = &node{Node: *n, addedAt: last.addedAt, livenessChecks: last.livenessChecks} - } - } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.buckets[bi] - if errHandle == nil { - // The node responded, move it to the front. - last.livenessChecks++ - tab.log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks) - tab.bumpInBucket(b, last) - return - } - // No reply received, pick a replacement or delete the node if there aren't - // any replacements. - if r := tab.replace(b, last); r != nil { - tab.log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) - } else { - tab.log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks) - } -} - -// nodeToRevalidate returns the last node in a random, non-empty bucket. 
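Both loadSeedNodes and the addNodeCh branch above gate handleAddNode behind workerPoolTask, which is a plain token pool built from a buffered channel: acquire a token or drop the work, and release the token when the goroutine finishes. A standalone sketch of that shape; the names and pool size are illustrative, not part of this change:

```go
package main

import (
	"fmt"
	"sync"
)

const maxWorkers = 16 // stands in for maxWorkerTask

func main() {
	// Fill the pool with one token per allowed concurrent task, mirroring the
	// workerPoolTask initialisation in newTable.
	pool := make(chan struct{}, maxWorkers)
	for i := 0; i < maxWorkers; i++ {
		pool <- struct{}{}
	}

	var wg sync.WaitGroup
	submit := func(job func()) bool {
		select {
		case <-pool: // acquire a token
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer func() { pool <- struct{}{} }() // release the token when done
				job()
			}()
			return true
		default:
			return false // pool exhausted: the table drops the node instead of blocking
		}
	}

	for i := 0; i < 100; i++ {
		i := i
		if !submit(func() { fmt.Println("handled", i) }) {
			fmt.Println("dropped", i)
		}
	}
	wg.Wait()
}
```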
-func (tab *Table) nodeToRevalidate() (n *node, bi int) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - for _, bi = range tab.rand.Perm(len(tab.buckets)) { - b := tab.buckets[bi] - if len(b.entries) > 0 { - last := b.entries[len(b.entries)-1] - return last, bi - } - } - return nil, 0 -} - -func (tab *Table) nextRevalidateTime() time.Duration { - tab.mutex.Lock() - defer tab.mutex.Unlock() - return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) -} - -// copyLiveNodes adds nodes from the table to the database if they have been in the table -// longer than seedMinTableTime. -func (tab *Table) copyLiveNodes() { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - now := time.Now() - for _, b := range &tab.buckets { - for _, n := range b.entries { - if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime { - tab.db.UpdateNode(unwrapNode(n)) - } + select { + case <-tab.workerPoolTask: + go tab.handleAddNode(addNodeOp{node: seed, isInbound: false}) + default: + tab.log.Debug("Worker pool task is full, dropping node", "id", seed.ID(), "addr", seed.addr()) } } } @@ -516,188 +455,156 @@ func (tab *Table) bucketAtDistance(d int) *bucket { return tab.buckets[d-bucketMinDistance-1] } -// addSeenNode adds a node which may or may not be live to the end of a bucket. If the -// bucket has space available, adding the node succeeds immediately. Otherwise, the node is -// added to the replacements list. -// -// The caller must not hold tab.mutex. -func (tab *Table) addSeenNode(n *node) { - select { - case <-tab.workerPoolTask: - go tab.addSeenNodeSync(n) - default: - tab.log.Debug("workerPoolTask is not filled yet, dropping node", "id", n.ID(), "addr", n.addr()) +func (tab *Table) filterNode(n *node) bool { + if tab.enrFilter == nil { + return false + } + if node, err := tab.net.RequestENR(unwrapNode(n)); err != nil { + tab.log.Debug("ENR request failed", "id", n.ID(), "addr", n.addr(), "err", err) + return true + } else if !tab.enrFilter(node.Record()) { + tab.log.Trace("ENR record filter out", "id", n.ID(), "addr", n.addr()) + return true } + return false } -func (tab *Table) addSeenNodeSync(n *node) { - defer func() { - select { - case tab.workerPoolTask <- struct{}{}: - default: - tab.log.Debug("workerPoolTask full, dropping node", "id", n.ID(), "addr", n.addr()) - } - }() - - if n.ID() == tab.self().ID() { - return +func (tab *Table) addIP(b *bucket, ip net.IP) bool { + if len(ip) == 0 { + return false // Nodes without IP cannot be added. } - - filterEnrTotalCounter.Inc(1) - - if tab.filterNode(n) { - filterEnrSuccessCounter.Inc(1) - return + if netutil.IsLAN(ip) { + return true } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.bucket(n.ID()) - if contains(b.entries, n.ID()) { - // Already in bucket, don't add. - return + if !tab.ips.Add(ip) { + tab.log.Debug("IP exceeds table limit", "ip", ip) + return false } - if len(b.entries) >= bucketSize { - // Bucket full, maybe add as replacement. - tab.addReplacement(b, n) - return + if !b.ips.Add(ip) { + tab.log.Debug("IP exceeds bucket limit", "ip", ip) + tab.ips.Remove(ip) + return false } - if !tab.addIP(b, n.IP()) { - // Can't add: IP limit reached. 
+ return true +} + +func (tab *Table) removeIP(b *bucket, ip net.IP) { + if netutil.IsLAN(ip) { return } - // Add to end of bucket: - b.entries = append(b.entries, n) - b.replacements = deleteNode(b.replacements, n) - n.addedAt = time.Now() - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(b, n) - } + tab.ips.Remove(ip) + b.ips.Remove(ip) } -func (tab *Table) filterNode(n *node) bool { - if tab.enrFilter == nil { +// addFoundNode adds a node which may not be live. If the bucket has space available, +// adding the node succeeds immediately. Otherwise, the node is added to the replacements +// list. +// +// The caller must not hold tab.mutex. For Testing purpose mostly. +func (tab *Table) addFoundNode(n *node) bool { + op := addNodeOp{node: n, isInbound: false, syncExecution: true} + select { + case tab.addNodeCh <- op: + return <-tab.addNodeHandled + case <-tab.closeReq: return false } - if node, err := tab.net.RequestENR(unwrapNode(n)); err != nil { - tab.log.Debug("ENR request failed", "id", n.ID(), "addr", n.addr(), "err", err) - return true - } else if !tab.enrFilter(node.Record()) { - tab.log.Trace("ENR record filter out", "id", n.ID(), "addr", n.addr()) - return true - } - return false } -// addVerifiedNode adds a node whose existence has been verified recently to the front of a -// bucket. If the node is already in the bucket, it is moved to the front. If the bucket -// has no space, the node is added to the replacements list. +// addInboundNode adds a node from an inbound contact. If the bucket has no space, the +// node is added to the replacements list. // -// There is an additional safety measure: if the table is still initializing the node -// is not added. This prevents an attack where the table could be filled by just sending -// ping repeatedly. +// There is an additional safety measure: if the table is still initializing the node is +// not added. This prevents an attack where the table could be filled by just sending ping +// repeatedly. // // The caller must not hold tab.mutex. +func (tab *Table) addInboundNode(n *node) { + op := addNodeOp{node: n, isInbound: true} + select { + case tab.addNodeCh <- op: + return + case <-tab.closeReq: + return + } +} -func (tab *Table) addVerifiedNode(n *node) { +// Only for Testing purpose with syncExecution. +func (tab *Table) addInboundNodeSync(n *node) bool { + op := addNodeOp{node: n, isInbound: true, syncExecution: true} select { - case <-tab.workerPoolTask: - go tab.addVerifiedNodeSync(n) - default: - tab.log.Debug("workerPoolTask is not filled yet, dropping node", "id", n.ID(), "addr", n.addr()) + case tab.addNodeCh <- op: + return <-tab.addNodeHandled + case <-tab.closeReq: + return false } } -func (tab *Table) addVerifiedNodeSync(n *node) { +// handleAddNode adds the node in the request to the table, if there is space. +// The caller must hold tab.mutex. +func (tab *Table) handleAddNode(req addNodeOp) bool { defer func() { select { case tab.workerPoolTask <- struct{}{}: default: - tab.log.Debug("workerPoolTask task queue full, dropping node", "id", n.ID(), "addr", n.addr()) + tab.log.Debug("Worker pool task is full, no need to release worker task.") } }() - if !tab.isInitDone() { - return - } - if n.ID() == tab.self().ID() { - return + + if req.node.ID() == tab.self().ID() { + return false } filterEnrTotalCounter.Inc(1) - - if tab.filterNode(n) { + // Filter node if ENR not match. 
+ if tab.filterNode(req.node) { filterEnrSuccessCounter.Inc(1) - return - } - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.bucket(n.ID()) - if tab.bumpInBucket(b, n) { - // Already in bucket, moved to front. - return - } - if len(b.entries) >= bucketSize { - // Bucket full, maybe add as replacement. - tab.addReplacement(b, n) - return - } - if !tab.addIP(b, n.IP()) { - // Can't add: IP limit reached. - return - } - // Add to front of bucket. - b.entries, _ = pushNode(b.entries, n, bucketSize) - b.replacements = deleteNode(b.replacements, n) - n.addedAt = time.Now() - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(b, n) + return false } -} -// delete removes an entry from the node table. It is used to evacuate dead nodes. -func (tab *Table) delete(node *node) { tab.mutex.Lock() defer tab.mutex.Unlock() - tab.deleteInBucket(tab.bucket(node.ID()), node) -} - -func (tab *Table) addIP(b *bucket, ip net.IP) bool { - if len(ip) == 0 { - return false // Nodes without IP cannot be added. + // For nodes from inbound contact, there is an additional safety measure: if the table + // is still initializing the node is not added. + if req.isInbound && !tab.isInitDone() { + return false } - if netutil.IsLAN(ip) { - return true + + b := tab.bucket(req.node.ID()) + n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound) + if n != nil { + // Already in bucket. + return false } - if !tab.ips.Add(ip) { - tab.log.Debug("IP exceeds table limit", "ip", ip) + if len(b.entries) >= bucketSize { + // Bucket full, maybe add as replacement. + tab.addReplacement(b, req.node) return false } - if !b.ips.Add(ip) { - tab.log.Debug("IP exceeds bucket limit", "ip", ip) - tab.ips.Remove(ip) + if !tab.addIP(b, req.node.IP()) { + // Can't add: IP limit reached. return false } - return true -} -func (tab *Table) removeIP(b *bucket, ip net.IP) { - if netutil.IsLAN(ip) { - return - } - tab.ips.Remove(ip) - b.ips.Remove(ip) + // Add to bucket. + b.entries = append(b.entries, req.node) + b.replacements = deleteNode(b.replacements, req.node) + tab.nodeAdded(b, req.node) + return true } +// addReplacement adds n to the replacement cache of bucket b. func (tab *Table) addReplacement(b *bucket, n *node) { - for _, e := range b.replacements { - if e.ID() == n.ID() { - return // already in list - } + if contains(b.replacements, n.ID()) { + // TODO: update ENR + return } if !tab.addIP(b, n.IP()) { return } + + n.addedToTable = time.Now() var removed *node b.replacements, removed = pushNode(b.replacements, n, maxReplacements) if removed != nil { @@ -705,63 +612,141 @@ func (tab *Table) addReplacement(b *bucket, n *node) { } } -// replace removes n from the replacement list and replaces 'last' with it if it is the -// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced -// with someone else or became active. -func (tab *Table) replace(b *bucket, last *node) *node { - if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() { - // Entry has moved, don't replace it. 
+func (tab *Table) nodeAdded(b *bucket, n *node) { + if n.addedToTable == (time.Time{}) { + n.addedToTable = time.Now() + } + n.addedToBucket = time.Now() + tab.revalidation.nodeAdded(tab, n) + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(b, n) + } + if metrics.Enabled { + bucketsCounter[b.index].Inc(1) + } +} + +func (tab *Table) nodeRemoved(b *bucket, n *node) { + tab.revalidation.nodeRemoved(n) + if tab.nodeRemovedHook != nil { + tab.nodeRemovedHook(b, n) + } + if metrics.Enabled { + bucketsCounter[b.index].Dec(1) + } +} + +// deleteInBucket removes node n from the table. +// If there are replacement nodes in the bucket, the node is replaced. +func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node { + index := slices.IndexFunc(b.entries, func(e *node) bool { return e.ID() == id }) + if index == -1 { + // Entry has been removed already. return nil } - // Still the last entry. + + // Remove the node. + n := b.entries[index] + b.entries = slices.Delete(b.entries, index, index+1) + tab.removeIP(b, n.IP()) + tab.nodeRemoved(b, n) + + // Add replacement. if len(b.replacements) == 0 { - tab.deleteInBucket(b, last) + tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP()) return nil } - r := b.replacements[tab.rand.Intn(len(b.replacements))] - b.replacements = deleteNode(b.replacements, r) - b.entries[len(b.entries)-1] = r - tab.removeIP(b, last.IP()) - return r -} - -// bumpInBucket moves the given node to the front of the bucket entry list -// if it is contained in that list. -func (tab *Table) bumpInBucket(b *bucket, n *node) bool { - for i := range b.entries { - if b.entries[i].ID() == n.ID() { - if !n.IP().Equal(b.entries[i].IP()) { - // Endpoint has changed, ensure that the new IP fits into table limits. - tab.removeIP(b, b.entries[i].IP()) - if !tab.addIP(b, n.IP()) { - // It doesn't, put the previous one back. - tab.addIP(b, b.entries[i].IP()) - return false - } - } - // Move it to the front. - copy(b.entries[1:], b.entries[:i]) - b.entries[0] = n - return true + rindex := tab.rand.Intn(len(b.replacements)) + rep := b.replacements[rindex] + b.replacements = slices.Delete(b.replacements, rindex, rindex+1) + b.entries = append(b.entries, rep) + tab.nodeAdded(b, rep) + tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP()) + return rep +} + +// bumpInBucket updates a node record if it exists in the bucket. +// The second return value reports whether the node's endpoint (IP/port) was updated. +func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) { + i := slices.IndexFunc(b.entries, func(elem *node) bool { + return elem.ID() == newRecord.ID() + }) + if i == -1 { + return nil, false // not in bucket + } + n = b.entries[i] + + // For inbound updates (from the node itself) we accept any change, even if it sets + // back the sequence number. For found nodes (!isInbound), seq has to advance. Note + // this check also ensures found discv4 nodes (which always have seq=0) can't be + // updated. + if newRecord.Seq() <= n.Seq() && !isInbound { + return n, false + } + + // Check endpoint update against IP limits. + ipchanged := !(net.IP.Equal(newRecord.IP(), n.IP())) + portchanged := newRecord.UDP() != n.UDP() + if ipchanged { + tab.removeIP(b, n.IP()) + if !tab.addIP(b, newRecord.IP()) { + // It doesn't fit with the limit, put the previous record back. + tab.addIP(b, n.IP()) + return n, false } } - return false + + // Apply update. 
+ n.Node = newRecord + if ipchanged || portchanged { + // Ensure node is revalidated quickly for endpoint changes. + tab.revalidation.nodeEndpointChanged(tab, n) + return n, true + } + return n, false } -func (tab *Table) deleteInBucket(b *bucket, n *node) { - // Check if node is actually in the bucket so the removed hook - // isn't called multipled for the same node. - if !contains(b.entries, n.ID()) { - return +func (tab *Table) handleTrackRequest(op trackRequestOp) { + var fails int + if op.success { + // Reset failure counter because it counts _consecutive_ failures. + tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0) + } else { + fails = tab.db.FindFails(op.node.ID(), op.node.IP()) + fails++ + tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails) } - b.entries = deleteNode(b.entries, n) - tab.removeIP(b, n.IP()) - if tab.nodeRemovedHook != nil { - tab.nodeRemovedHook(b, n) + tab.mutex.Lock() + + b := tab.bucket(op.node.ID()) + // Remove the node from the local table if it fails to return anything useful too + // many times, but only if there are enough other nodes in the bucket. This latter + // condition specifically exists to make bootstrapping in smaller test networks more + // reliable. + if fails >= maxFindnodeFailures && len(b.entries) >= bucketSize/4 { + tab.deleteInBucket(b, op.node.ID()) + } + + // We already hold lock in handleAddNode, so need to unlock it here for avoiding hold lock so much when running handleAddNode. + tab.mutex.Unlock() + + // Add found nodes. + for _, n := range op.foundNodes { + select { + case <-tab.workerPoolTask: + go tab.handleAddNode(addNodeOp{n, false, false}) + default: + tab.log.Debug("Worker pool task is full, dropping node", "id", n.ID(), "addr", n.addr()) + } } } +func (tab *Table) nextRefreshTime() time.Duration { + half := tab.cfg.RefreshInterval / 2 + return half + time.Duration(tab.rand.Int63n(int64(half))) +} + func contains(ns []*node, id enode.ID) bool { for _, n := range ns { if n.ID() == id { @@ -803,15 +788,14 @@ func (h *nodesByDistance) push(n *node, maxElems int) { ix := sort.Search(len(h.entries), func(i int) bool { return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 }) + + end := len(h.entries) if len(h.entries) < maxElems { h.entries = append(h.entries, n) } - if ix == len(h.entries) { - // farther away than all nodes we already have. - // if there was room for it, the node is now the last element. - } else { - // slide existing entries down to make room - // this will overwrite the entry we just appended. + if ix < end { + // Slide existing entries down to make room. + // This will overwrite the entry we just appended. copy(h.entries[ix+1:], h.entries[ix:]) h.entries[ix] = n } diff --git a/p2p/discover/table_reval.go b/p2p/discover/table_reval.go new file mode 100644 index 000000000..01e9cf7ca --- /dev/null +++ b/p2p/discover/table_reval.go @@ -0,0 +1,253 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
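The rewritten nodesByDistance.push above is a bounded sorted insert: append when there is room, then slide entries to place the new element at its sorted index, silently discarding it when it sorts after a full list. The same logic on plain ints, as a self-contained illustration:

```go
package main

import (
	"fmt"
	"sort"
)

// pushSorted inserts v into the ascending-sorted list, keeping at most max
// elements. When the list is full and v sorts after every element, it is dropped.
func pushSorted(list []int, v, max int) []int {
	ix := sort.Search(len(list), func(i int) bool { return list[i] > v })

	end := len(list)
	if len(list) < max {
		list = append(list, v)
	}
	if ix < end {
		// Slide existing entries down to make room. This overwrites the element
		// appended above (or discards the former last element when full).
		copy(list[ix+1:], list[ix:])
		list[ix] = v
	}
	return list
}

func main() {
	l := []int{}
	for _, v := range []int{5, 1, 9, 3, 7, 2} {
		l = pushSorted(l, v, 4)
	}
	fmt.Println(l) // [1 2 3 5]
}
```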
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "fmt" + "math" + "slices" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const never = mclock.AbsTime(math.MaxInt64) + +const slowRevalidationFactor = 3 + +// tableRevalidation implements the node revalidation process. +// It tracks all nodes contained in Table, and schedules sending PING to them. +type tableRevalidation struct { + fast revalidationList + slow revalidationList + activeReq map[enode.ID]struct{} +} + +type revalidationResponse struct { + n *node + newRecord *enode.Node + didRespond bool +} + +func (tr *tableRevalidation) init(cfg *Config) { + tr.activeReq = make(map[enode.ID]struct{}) + tr.fast.nextTime = never + tr.fast.interval = cfg.PingInterval + tr.fast.name = "fast" + tr.slow.nextTime = never + tr.slow.interval = cfg.PingInterval * slowRevalidationFactor + tr.slow.name = "slow" +} + +// nodeAdded is called when the table receives a new node. +func (tr *tableRevalidation) nodeAdded(tab *Table, n *node) { + tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand) +} + +// nodeRemoved is called when a node was removed from the table. +func (tr *tableRevalidation) nodeRemoved(n *node) { + if n.revalList == nil { + panic(fmt.Errorf("removed node %v has nil revalList", n.ID())) + } + n.revalList.remove(n) +} + +// nodeEndpointChanged is called when a change in IP or port is detected. +func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) { + n.isValidatedLive = false + tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand) +} + +// run performs node revalidation. +// It returns the next time it should be invoked, which is used in the Table main loop +// to schedule a timer. However, run can be called at any time. +func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mclock.AbsTime) { + reval := func(list *revalidationList) { + if list.nextTime <= now { + if n := list.get(&tab.rand, tr.activeReq); n != nil { + tr.startRequest(tab, n) + } + // Update nextTime regardless if any requests were started because + // current value has passed. + list.schedule(now, &tab.rand) + } + } + reval(&tr.fast) + reval(&tr.slow) + + return min(tr.fast.nextTime, tr.slow.nextTime) +} + +// startRequest spawns a revalidation request for node n. +func (tr *tableRevalidation) startRequest(tab *Table, n *node) { + if _, ok := tr.activeReq[n.ID()]; ok { + panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID())) + } + tr.activeReq[n.ID()] = struct{}{} + resp := revalidationResponse{n: n} + + // Fetch the node while holding lock. + tab.mutex.Lock() + node := n.Node + tab.mutex.Unlock() + + go tab.doRevalidate(resp, node) +} + +func (tab *Table) doRevalidate(resp revalidationResponse, node *enode.Node) { + // Ping the selected node and wait for a pong response. + remoteSeq, err := tab.net.ping(node) + resp.didRespond = err == nil + + // Also fetch record if the node replied and returned a higher sequence number. + if remoteSeq > node.Seq() { + newrec, err := tab.net.RequestENR(node) + if err != nil { + tab.log.Debug("ENR request failed", "id", node.ID(), "err", err) + } else { + resp.newRecord = newrec + } + } + + select { + case tab.revalResponseCh <- resp: + case <-tab.closed: + } +} + +// handleResponse processes the result of a revalidation request. 
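startRequest and doRevalidate above keep a strict ownership rule: only the table's main loop mutates table state, while worker goroutines just ping and report back on revalResponseCh, falling back to the closed channel so they cannot leak on shutdown. A stripped-down sketch of that shape, with illustrative types that are not part of the diff:

```go
package main

import (
	"fmt"
	"time"
)

type result struct {
	id string
	ok bool
}

// check pings a node (simulated here) and reports the outcome to the owner
// loop. The select on closed mirrors doRevalidate: if the table shuts down
// before the result is consumed, the goroutine exits instead of blocking.
func check(id string, results chan<- result, closed <-chan struct{}) {
	go func() {
		ok := len(id)%2 == 0 // stand-in for tab.net.ping
		select {
		case results <- result{id, ok}:
		case <-closed:
		}
	}()
}

func main() {
	results := make(chan result)
	closed := make(chan struct{})

	check("node-a", results, closed)
	check("node-bb", results, closed)

	for i := 0; i < 2; i++ {
		select {
		case r := <-results:
			fmt.Println(r.id, "responded:", r.ok) // the owner loop updates state here
		case <-time.After(time.Second):
			fmt.Println("timeout")
		}
	}
	close(closed)
}
```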
+func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationResponse) { + var ( + now = tab.cfg.Clock.Now() + n = resp.n + b = tab.bucket(n.ID()) + ) + delete(tr.activeReq, n.ID()) + + // If the node was removed from the table while getting checked, we need to stop + // processing here to avoid re-adding it. + if n.revalList == nil { + return + } + + // Store potential seeds in database. + // This is done via defer to avoid holding Table lock while writing to DB. + defer func() { + if n.isValidatedLive && n.livenessChecks > 5 { + tab.db.UpdateNode(resp.n.Node) + } + }() + + // Remaining logic needs access to Table internals. + tab.mutex.Lock() + defer tab.mutex.Unlock() + + if !resp.didRespond { + n.livenessChecks /= 3 + if n.livenessChecks <= 0 { + tab.deleteInBucket(b, n.ID()) + } else { + tab.log.Debug("Node revalidation failed", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name) + tr.moveToList(&tr.fast, n, now, &tab.rand) + } + return + } + + // The node responded. + n.livenessChecks++ + n.isValidatedLive = true + tab.log.Debug("Node revalidated", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name) + var endpointChanged bool + if resp.newRecord != nil { + if tab.enrFilter != nil && !tab.enrFilter(resp.newRecord.Record()) { + tab.log.Trace("ENR record filter out", "id", n.ID()) + tab.deleteInBucket(b, n.ID()) + return + } + _, endpointChanged = tab.bumpInBucket(b, resp.newRecord, false) + } + + // Node moves to slow list if it passed and hasn't changed. + if !endpointChanged { + tr.moveToList(&tr.slow, n, now, &tab.rand) + } +} + +// moveToList ensures n is in the 'dest' list. +func (tr *tableRevalidation) moveToList(dest *revalidationList, n *node, now mclock.AbsTime, rand randomSource) { + if n.revalList == dest { + return + } + if n.revalList != nil { + n.revalList.remove(n) + } + dest.push(n, now, rand) +} + +// revalidationList holds a list nodes and the next revalidation time. +type revalidationList struct { + nodes []*node + nextTime mclock.AbsTime + interval time.Duration + name string +} + +// get returns a random node from the queue. Nodes in the 'exclude' map are not returned. 
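The livenessChecks /= 3 backoff in handleResponse above means long-validated nodes tolerate a few missed revalidations while fresh ones are dropped quickly. A small illustration of how many consecutive failures a node survives before it is removed from its bucket, using a hypothetical helper:

```go
package main

import "fmt"

// failuresUntilRemoval mirrors the "/= 3" backoff in handleResponse: each failed
// revalidation divides livenessChecks by three, and the node is dropped from its
// bucket once the counter reaches zero.
func failuresUntilRemoval(livenessChecks uint) int {
	fails := 0
	for {
		livenessChecks /= 3
		fails++
		if livenessChecks == 0 {
			return fails
		}
	}
}

func main() {
	for _, c := range []uint{0, 1, 5, 9, 100} {
		fmt.Printf("checks=%d -> removed after %d consecutive failures\n", c, failuresUntilRemoval(c))
	}
	// checks=0 -> 1, checks=1 -> 1, checks=5 -> 2, checks=9 -> 3, checks=100 -> 5
}
```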
+func (list *revalidationList) get(rand randomSource, exclude map[enode.ID]struct{}) *node { + if len(list.nodes) == 0 { + return nil + } + for i := 0; i < len(list.nodes)*3; i++ { + n := list.nodes[rand.Intn(len(list.nodes))] + _, excluded := exclude[n.ID()] + if !excluded { + return n + } + } + return nil +} + +func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) { + list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval)))) +} + +func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSource) { + list.nodes = append(list.nodes, n) + if list.nextTime == never { + list.schedule(now, rand) + } + n.revalList = list +} + +func (list *revalidationList) remove(n *node) { + i := slices.Index(list.nodes, n) + if i == -1 { + panic(fmt.Errorf("node %v not found in list", n.ID())) + } + list.nodes = slices.Delete(list.nodes, i, i+1) + if len(list.nodes) == 0 { + list.nextTime = never + } + n.revalList = nil +} + +func (list *revalidationList) contains(id enode.ID) bool { + return slices.ContainsFunc(list.nodes, func(n *node) bool { + return n.ID() == id + }) +} diff --git a/p2p/discover/table_reval_test.go b/p2p/discover/table_reval_test.go new file mode 100644 index 000000000..d168767e0 --- /dev/null +++ b/p2p/discover/table_reval_test.go @@ -0,0 +1,119 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "net" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// This test checks that revalidation can handle a node disappearing while +// a request is active. +func TestRevalidation_nodeRemoved(t *testing.T) { + var ( + clock mclock.Simulated + transport = newPingRecorder() + tab, db = newInactiveTestTable(transport, Config{Clock: &clock}) + tr = &tab.revalidation + ) + defer db.Close() + + // Add a node to the table. + node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1}) + tab.handleAddNode(addNodeOp{node: node}) + + // Start a revalidation request. Schedule once to get the next start time, + // then advance the clock to that point and schedule again to start. + next := tr.run(tab, clock.Now()) + clock.Run(time.Duration(next + 1)) + tr.run(tab, clock.Now()) + if len(tr.activeReq) != 1 { + t.Fatal("revalidation request did not start:", tr.activeReq) + } + + // Delete the node. + tab.deleteInBucket(tab.bucket(node.ID()), node.ID()) + + // Now finish the revalidation request. + var resp revalidationResponse + select { + case resp = <-tab.revalResponseCh: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for revalidation") + } + tr.handleResponse(tab, resp) + + // Ensure the node was not re-added to the table. 
+ if tab.getNode(node.ID()) != nil { + t.Fatal("node was re-added to Table") + } + if tr.fast.contains(node.ID()) || tr.slow.contains(node.ID()) { + t.Fatal("removed node contained in revalidation list") + } +} + +// This test checks that nodes with an updated endpoint remain in the fast revalidation list. +func TestRevalidation_endpointUpdate(t *testing.T) { + var ( + clock mclock.Simulated + transport = newPingRecorder() + tab, db = newInactiveTestTable(transport, Config{Clock: &clock}) + tr = &tab.revalidation + ) + defer db.Close() + + // Add node to table. + node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1}) + tab.handleAddNode(addNodeOp{node: node}) + + // Update the record in transport, including endpoint update. + record := node.Record() + record.Set(enr.IP{100, 100, 100, 100}) + record.Set(enr.UDP(9999)) + nodev2 := enode.SignNull(record, node.ID()) + transport.updateRecord(nodev2) + + // Start a revalidation request. Schedule once to get the next start time, + // then advance the clock to that point and schedule again to start. + next := tr.run(tab, clock.Now()) + clock.Run(time.Duration(next + 1)) + tr.run(tab, clock.Now()) + if len(tr.activeReq) != 1 { + t.Fatal("revalidation request did not start:", tr.activeReq) + } + + // Now finish the revalidation request. + var resp revalidationResponse + select { + case resp = <-tab.revalResponseCh: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for revalidation") + } + tr.handleResponse(tab, resp) + + if !tr.fast.contains(node.ID()) { + t.Fatal("node not contained in fast revalidation list") + } + if node.isValidatedLive { + t.Fatal("node is marked live after endpoint change") + } +} diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index d6e965377..f1ba110ab 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -20,20 +20,19 @@ import ( "crypto/ecdsa" "fmt" "math/rand" - "net" "reflect" "testing" "testing/quick" "time" - "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" - "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rlp" ) func TestTable_pingReplace(t *testing.T) { @@ -52,106 +51,109 @@ func TestTable_pingReplace(t *testing.T) { } func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { + simclock := new(mclock.Simulated) transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{ + Clock: simclock, + Log: testlog.Logger(t, log.LvlTrace), + }) defer db.Close() defer tab.close() <-tab.initDone // Fill up the sender's bucket. 
- pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") - pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) - last := fillBucket(tab, pingSender) + replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") + replacementNode := wrapNode(enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) + last := fillBucket(tab, replacementNode.ID()) + tab.mutex.Lock() + nodeEvents := newNodeEventRecorder(128) + tab.nodeAddedHook = nodeEvents.nodeAdded + tab.nodeRemovedHook = nodeEvents.nodeRemoved + tab.mutex.Unlock() - // Add the sender as if it just pinged us. Revalidate should replace the last node in - // its bucket if it is unresponsive. Revalidate again to ensure that + // The revalidation process should replace + // this node in the bucket if it is unresponsive. transport.dead[last.ID()] = !lastInBucketIsResponding - transport.dead[pingSender.ID()] = !newNodeIsResponding - tab.addSeenNodeSync(pingSender) - tab.doRevalidate(make(chan struct{}, 1)) - tab.doRevalidate(make(chan struct{}, 1)) - - if !transport.pinged[last.ID()] { - // Oldest node in bucket is pinged to see whether it is still alive. - t.Error("table did not ping last node in bucket") + transport.dead[replacementNode.ID()] = !newNodeIsResponding + + // Add replacement node to table. + tab.addFoundNode(replacementNode) + + t.Log("last:", last.ID()) + t.Log("replacement:", replacementNode.ID()) + + // Wait until the last node was pinged. + waitForRevalidationPing(t, transport, tab, last.ID()) + + if !lastInBucketIsResponding { + if !nodeEvents.waitNodeAbsent(last.ID(), 2*time.Second) { + t.Error("last node was not removed") + } + if !nodeEvents.waitNodePresent(replacementNode.ID(), 2*time.Second) { + t.Error("replacement node was not added") + } + + // If a replacement is expected, we also need to wait until the replacement node + // was pinged and added/removed. + waitForRevalidationPing(t, transport, tab, replacementNode.ID()) + if !newNodeIsResponding { + if !nodeEvents.waitNodeAbsent(replacementNode.ID(), 2*time.Second) { + t.Error("replacement node was not removed") + } + } } + // Check bucket content. 
tab.mutex.Lock() defer tab.mutex.Unlock() wantSize := bucketSize if !lastInBucketIsResponding && !newNodeIsResponding { wantSize-- } - if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + bucket := tab.bucket(replacementNode.ID()) + if l := len(bucket.entries); l != wantSize { + t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize) } - if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding { - t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) + if ok := contains(bucket.entries, last.ID()); ok != lastInBucketIsResponding { + t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding) } wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding - if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry { - t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) + if ok := contains(bucket.entries, replacementNode.ID()); ok != wantNewEntry { + t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry) } } -func TestBucket_bumpNoDuplicates(t *testing.T) { - t.Parallel() - cfg := &quick.Config{ - MaxCount: 1000, - Rand: rand.New(rand.NewSource(time.Now().Unix())), - Values: func(args []reflect.Value, rand *rand.Rand) { - // generate a random list of nodes. this will be the content of the bucket. - n := rand.Intn(bucketSize-1) + 1 - nodes := make([]*node, n) - for i := range nodes { - nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200)) - } - args[0] = reflect.ValueOf(nodes) - // generate random bump positions. - bumps := make([]int, rand.Intn(100)) - for i := range bumps { - bumps[i] = rand.Intn(len(nodes)) - } - args[1] = reflect.ValueOf(bumps) - }, - } - - prop := func(nodes []*node, bumps []int) (ok bool) { - tab, db := newTestTable(newPingRecorder()) - defer db.Close() - defer tab.close() +// waitForRevalidationPing waits until a PING message is sent to a node with the given id. +func waitForRevalidationPing(t *testing.T, transport *pingRecorder, tab *Table, id enode.ID) *enode.Node { + t.Helper() - b := &bucket{entries: make([]*node, len(nodes))} - copy(b.entries, nodes) - for i, pos := range bumps { - tab.bumpInBucket(b, b.entries[pos]) - if hasDuplicates(b.entries) { - t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps)) - for _, n := range b.entries { - t.Logf(" %p", n) - } - return false - } + simclock := tab.cfg.Clock.(*mclock.Simulated) + maxAttempts := tab.len() * 8 + for i := 0; i < maxAttempts; i++ { + simclock.Run(tab.cfg.PingInterval * slowRevalidationFactor) + p := transport.waitPing(2 * time.Second) + if p == nil { + t.Fatal("Table did not send revalidation ping") + } + if id == (enode.ID{}) || p.ID() == id { + return p } - checkIPLimitInvariant(t, tab) - return true - } - if err := quick.Check(prop, cfg); err != nil { - t.Error(err) } + t.Fatalf("Table did not ping node %v (%d attempts)", id, maxAttempts) + return nil } // This checks that the table-wide IP limit is applied correctly. 
func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) - tab.addSeenNodeSync(n) + tab.addFoundNode(n) } if tab.len() > tableIPLimit { t.Errorf("too many nodes in table") @@ -162,14 +164,14 @@ func TestTable_IPLimit(t *testing.T) { // This checks that the per-bucket IP limit is applied correctly. func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) - tab.addSeenNodeSync(n) + tab.addFoundNode(n) } if tab.len() > bucketIPLimit { t.Errorf("too many nodes in table") @@ -199,10 +201,10 @@ func TestTable_findnodeByID(t *testing.T) { test := func(test *closeTest) bool { // for any node table, Target and N transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() - fillTable(tab, test.All) + fillTable(tab, test.All, true) // check that closest(Target, N) returns nodes result := tab.findnodeByID(test.Target, test.N, false).entries @@ -250,41 +252,6 @@ func TestTable_findnodeByID(t *testing.T) { } } -func TestTable_ReadRandomNodesGetAll(t *testing.T) { - cfg := &quick.Config{ - MaxCount: 200, - Rand: rand.New(rand.NewSource(time.Now().Unix())), - Values: func(args []reflect.Value, rand *rand.Rand) { - args[0] = reflect.ValueOf(make([]*enode.Node, rand.Intn(1000))) - }, - } - test := func(buf []*enode.Node) bool { - transport := newPingRecorder() - tab, db := newTestTable(transport) - defer db.Close() - defer tab.close() - <-tab.initDone - - for i := 0; i < len(buf); i++ { - ld := cfg.Rand.Intn(len(tab.buckets)) - fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))}) - } - gotN := tab.ReadRandomNodes(buf) - if gotN != tab.len() { - t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len()) - return false - } - if hasDuplicates(wrapNodes(buf[:gotN])) { - t.Errorf("result contains duplicates") - return false - } - return true - } - if err := quick.Check(test, cfg); err != nil { - t.Error(err) - } -} - type closeTest struct { Self enode.ID Target enode.ID @@ -308,8 +275,8 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(t) } -func TestTable_addVerifiedNode(t *testing.T) { - tab, db := newTestTable(newPingRecorder()) +func TestTable_addInboundNode(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() defer tab.close() @@ -317,31 +284,29 @@ func TestTable_addVerifiedNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addSeenNodeSync(n1) - tab.addSeenNodeSync(n2) - - // Verify bucket content: - bcontent := []*node{n1, n2} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) - } + tab.addFoundNode(n1) + tab.addFoundNode(n2) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) - // Add a changed version of n2. + // Add a changed version of n2. The bucket should be updated. 
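The two IP-limit tests above exercise the table-wide and per-bucket address caps, which the table builds on netutil.DistinctNetSet. The sketch below shows that primitive in isolation; the subnet size and limit are illustrative values, not the table's constants.

package main

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/p2p/netutil"
)

func main() {
	// DistinctNetSet counts addresses per subnet and rejects additions once a
	// subnet holds Limit entries.
	set := netutil.DistinctNetSet{Subnet: 24, Limit: 2}

	ips := []net.IP{
		{172, 0, 1, 1},
		{172, 0, 1, 2},
		{172, 0, 1, 3}, // third address in the same /24: rejected
		{10, 0, 0, 1},  // different subnet: accepted
	}
	for _, ip := range ips {
		fmt.Printf("add %v: %v\n", ip, set.Add(ip))
	}
	fmt.Println("tracked addresses:", set.Len())
}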
newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) - newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) - tab.addVerifiedNodeSync(newn2) - - // Check that bucket is updated correctly. - newBcontent := []*node{newn2, n1} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, newBcontent) { - t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries) - } - checkIPLimitInvariant(t, tab) + n2v2 := enode.SignNull(newrec, n2.ID()) + tab.addInboundNodeSync(wrapNode(n2v2)) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) + + // Try updating n2 without sequence number change. The update is accepted + // because it's inbound. + newrec = n2.Record() + newrec.Set(enr.IP{100, 100, 100, 100}) + newrec.SetSeq(n2.Seq()) + n2v3 := enode.SignNull(newrec, n2.ID()) + tab.addInboundNodeSync(wrapNode(n2v3)) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3}) } -func TestTable_addSeenNode(t *testing.T) { - tab, db := newTestTable(newPingRecorder()) +func TestTable_addFoundNode(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() defer tab.close() @@ -349,25 +314,86 @@ func TestTable_addSeenNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addSeenNodeSync(n1) - tab.addSeenNodeSync(n2) - - // Verify bucket content: - bcontent := []*node{n1, n2} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) - } + tab.addFoundNode(n1) + tab.addFoundNode(n2) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node}) - // Add a changed version of n2. + // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) - newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) - tab.addSeenNodeSync(newn2) + n2v2 := enode.SignNull(newrec, n2.ID()) + tab.addFoundNode(wrapNode(n2v2)) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) + + // Try updating n2 without a sequence number change. + // The update should not be accepted. + newrec = n2.Record() + newrec.Set(enr.IP{100, 100, 100, 100}) + newrec.SetSeq(n2.Seq()) + n2v3 := enode.SignNull(newrec, n2.ID()) + tab.addFoundNode(wrapNode(n2v3)) + checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2}) +} + +// This test checks that discv4 nodes can update their own endpoint via PING. +func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) + <-tab.initDone + defer db.Close() + defer tab.close() + + // Add a v4 node. + key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") + n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) + tab.addInboundNodeSync(wrapNode(n1)) + checkBucketContent(t, tab, []*enode.Node{n1}) + + // Add an updated version with changed IP. + // The update will be accepted because it is inbound. + n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) + tab.addInboundNodeSync(wrapNode(n1v2)) + checkBucketContent(t, tab, []*enode.Node{n1v2}) +} + +// This test checks that discv4 node entries will NOT be updated when a +// changed record is found. +func TestTable_addFoundNodeV4UpdateReject(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) + <-tab.initDone + defer db.Close() + defer tab.close() + + // Add a v4 node. 
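Taken together, TestTable_addInboundNode and TestTable_addFoundNode pin down an update policy: an inbound record is always accepted, while a found (outbound-discovered) record only replaces the stored one when its ENR sequence number increases. The helper below is an illustrative restatement of that rule, not the table's actual code, using the same enr/enode signing helpers as the tests.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
)

// acceptUpdate sketches the rule the two tests pin down: inbound records are
// always taken, found records only on a higher sequence number.
func acceptUpdate(stored, incoming *enode.Node, inbound bool) bool {
	if inbound {
		return true
	}
	return incoming.Seq() > stored.Seq()
}

func main() {
	id := enode.ID{1}

	var r1 enr.Record
	r1.Set(enr.IP{88, 77, 66, 1})
	stored := enode.SignNull(&r1, id)

	// Same sequence number, different IP.
	var r2 enr.Record
	r2.Set(enr.IP{100, 100, 100, 100})
	r2.SetSeq(stored.Seq())
	update := enode.SignNull(&r2, id)

	fmt.Println("found, same seq:  ", acceptUpdate(stored, update, false)) // false
	fmt.Println("inbound, same seq:", acceptUpdate(stored, update, true))  // true
}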
+ key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") + n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) + tab.addFoundNode(wrapNode(n1)) + checkBucketContent(t, tab, []*enode.Node{n1}) + + // Add an updated version with changed IP. + // The update won't be accepted because it isn't inbound. + n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) + tab.addFoundNode(wrapNode(n1v2)) + checkBucketContent(t, tab, []*enode.Node{n1}) +} + +func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) { + t.Helper() - // Check that bucket content is unchanged. - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries) + b := tab.bucket(nodes[0].ID()) + if reflect.DeepEqual(unwrapNodes(b.entries), nodes) { + return } + t.Log("wrong bucket content. have nodes:") + for _, n := range b.entries { + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + } + t.Log("want nodes:") + for _, n := range nodes { + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + } + t.FailNow() + + // Also check IP limits. checkIPLimitInvariant(t, tab) } @@ -375,7 +401,10 @@ func TestTable_addSeenNode(t *testing.T) { // announces a new sequence number, the new record should be pulled. func TestTable_revalidateSyncRecord(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{ + Clock: new(mclock.Simulated), + Log: testlog.Logger(t, log.LvlTrace), + }) <-tab.initDone defer db.Close() defer tab.close() @@ -385,53 +414,75 @@ func TestTable_revalidateSyncRecord(t *testing.T) { r.Set(enr.IP(net.IP{127, 0, 0, 1})) id := enode.ID{1} n1 := wrapNode(enode.SignNull(&r, id)) - tab.addSeenNodeSync(n1) + tab.addFoundNode(n1) // Update the node record. r.Set(enr.WithEntry("foo", "bar")) n2 := enode.SignNull(&r, id) transport.updateRecord(n2) - tab.doRevalidate(make(chan struct{}, 1)) + // Wait for revalidation. We wait for the node to be revalidated two times + // in order to synchronize with the update in the able. + waitForRevalidationPing(t, transport, tab, n2.ID()) + waitForRevalidationPing(t, transport, tab, n2.ID()) + intable := tab.getNode(id) if !reflect.DeepEqual(intable, n2) { t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq()) } } -// This test checks that ENR filtering is working properly -func TestTable_filterNode(t *testing.T) { - // Create ENR filter - type eth struct { - ForkID forkid.ID - Tail []rlp.RawValue `rlp:"tail"` +func TestNodesPush(t *testing.T) { + var target enode.ID + n1 := nodeAtDistance(target, 255, intIP(1)) + n2 := nodeAtDistance(target, 254, intIP(2)) + n3 := nodeAtDistance(target, 253, intIP(3)) + perm := [][]*node{ + {n3, n2, n1}, + {n3, n1, n2}, + {n2, n3, n1}, + {n2, n1, n3}, + {n1, n3, n2}, + {n1, n2, n3}, + } + + // Insert all permutations into lists with size limit 3. + for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 3) + } + if !slicesEqual(list.entries, perm[0], nodeIDEqual) { + t.Fatal("not equal") + } } - enrFilter, _ := ParseEthFilter("ronin-mainnet") - - // Check test ENR record - var r1 enr.Record - r1.Set(enr.WithEntry("foo", "bar")) - if enrFilter(&r1) { - t.Fatalf("filterNode doesn't work correctly for entry") + // Insert all permutations into lists with size limit 2. 
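The v4 accept/reject pair above hinges on an assumption worth spelling out: enode.NewV4 builds a v4-compatible node without a signed ENR, so its sequence number stays zero and a seq-based comparison can never accept the changed endpoint; the new address has to come in via an inbound PING. A quick standalone check of that assumption (key handling is illustrative):

package main

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	// Two endpoint variants of the same discv4 identity.
	n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
	n2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)

	fmt.Println("n1 seq:", n1.Seq(), "ip:", n1.IP())
	fmt.Println("n2 seq:", n2.Seq(), "ip:", n2.IP())
	fmt.Println("same identity:", n1.ID() == n2.ID())
}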
+ for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 2) + } + if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { + t.Fatal("not equal") + } } - t.Logf("Check test ENR record - passed") +} - // Check wrong genesis ENR record - var r2 enr.Record - r2.Set(enr.WithEntry("eth", eth{ForkID: forkid.NewID(params.RoninMainnetChainConfig, params.RoninTestnetGenesisHash, uint64(0))})) - if enrFilter(&r2) { - t.Fatalf("filterNode doesn't work correctly for wrong genesis entry") - } - t.Logf("Check wrong genesis ENR record - passed") +func nodeIDEqual(n1, n2 *node) bool { + return n1.ID() == n2.ID() +} - // Check correct genesis ENR record - var r3 enr.Record - r3.Set(enr.WithEntry("eth", eth{ForkID: forkid.NewID(params.RoninMainnetChainConfig, params.RoninMainnetGenesisHash, uint64(0))})) - if !enrFilter(&r3) { - t.Fatalf("filterNode doesn't work correctly for correct genesis entry") +func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if !check(s1[i], s2[i]) { + return false + } } - t.Logf("Check correct genesis ENR record - passed") + return true } // gen wraps quick.Value so it's easier to use. diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 5da68e72e..59045bf2a 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -24,11 +24,12 @@ import ( "fmt" "math/rand" "net" - "sort" + "slices" "sync" + "sync/atomic" + "time" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" ) @@ -41,17 +42,24 @@ func init() { nullNode = enode.SignNull(&r, enode.ID{}) } -func newTestTable(t transport) (*Table, *enode.DB) { - db, _ := enode.OpenDB("") - tab, _ := newTable(t, db, nil, log.Root(), nil) +func newTestTable(t transport, cfg Config) (*Table, *enode.DB) { + tab, db := newInactiveTestTable(t, cfg) go tab.loop() return tab, db } +// newInactiveTestTable creates a Table without running the main loop. +func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) { + db, _ := enode.OpenDB("") + tab, _ := newTable(t, db, cfg) + return tab, db +} + // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld. func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { var r enr.Record r.Set(enr.IP(ip)) + r.Set(enr.UDP(30303)) return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) } @@ -97,28 +105,37 @@ func intIP(i int) net.IP { } // fillBucket inserts nodes into the given bucket until it is full. -func fillBucket(tab *Table, n *node) (last *node) { - ld := enode.LogDist(tab.self().ID(), n.ID()) - b := tab.bucket(n.ID()) +func fillBucket(tab *Table, id enode.ID) (last *node) { + ld := enode.LogDist(tab.self().ID(), id) + b := tab.bucket(id) for len(b.entries) < bucketSize { - b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld))) + node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) + if !tab.addFoundNode(node) { + panic("node not added") + } } return b.entries[bucketSize-1] } // fillTable adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. 
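TestNodesPush relies on nodesByDistance.push keeping the candidate list sorted by distance to the target and capped at a maximum length, which is why every insertion order in the permutation table converges to the same slice. Below is a deliberately simplified stand-in: plain ints replace nodes and the absolute difference replaces the Kademlia log-distance.

package main

import (
	"fmt"
	"slices"
)

// push inserts n into list, keeping it sorted by distance to target and
// trimmed to maxElems. Simplified stand-in for nodesByDistance.push.
func push(list []int, n, target, maxElems int) []int {
	dist := func(a int) int {
		if a > target {
			return a - target
		}
		return target - a
	}
	ix, _ := slices.BinarySearchFunc(list, n, func(elem, val int) int {
		return dist(elem) - dist(val)
	})
	list = slices.Insert(list, ix, n)
	if len(list) > maxElems {
		list = list[:maxElems]
	}
	return list
}

func main() {
	const target = 50
	perms := [][]int{{10, 40, 70}, {70, 10, 40}, {40, 70, 10}}
	for _, p := range perms {
		var list []int
		for _, n := range p {
			list = push(list, n, target, 2)
		}
		fmt.Println(p, "->", list) // every order converges to [40 70]
	}
}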
-func fillTable(tab *Table, nodes []*node) { +func fillTable(tab *Table, nodes []*node, setLive bool) { for _, n := range nodes { - tab.addSeenNodeSync(n) + if setLive { + n.livenessChecks = 1 + n.isValidatedLive = true + } + tab.addFoundNode(n) } } type pingRecorder struct { - mu sync.Mutex - dead, pinged map[enode.ID]bool - records map[enode.ID]*enode.Node - n *enode.Node + mu sync.Mutex + cond *sync.Cond + dead map[enode.ID]bool + records map[enode.ID]*enode.Node + pinged []*enode.Node + n *enode.Node } func newPingRecorder() *pingRecorder { @@ -126,16 +143,17 @@ func newPingRecorder() *pingRecorder { r.Set(enr.IP{0, 0, 0, 0}) n := enode.SignNull(&r, enode.ID{}) - return &pingRecorder{ + t := &pingRecorder{ dead: make(map[enode.ID]bool), - pinged: make(map[enode.ID]bool), records: make(map[enode.ID]*enode.Node), n: n, } + t.cond = sync.NewCond(&t.mu) + return t } -// setRecord updates a node record. Future calls to ping and -// requestENR will return this record. +// updateRecord updates a node record. Future calls to ping and +// RequestENR will return this record. func (t *pingRecorder) updateRecord(n *enode.Node) { t.mu.Lock() defer t.mu.Unlock() @@ -147,12 +165,40 @@ func (t *pingRecorder) Self() *enode.Node { return nullNode } func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } +func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { + t.mu.Lock() + defer t.mu.Unlock() + + // Wake up the loop on timeout. + var timedout atomic.Bool + timer := time.AfterFunc(timeout, func() { + timedout.Store(true) + t.cond.Broadcast() + }) + defer timer.Stop() + + // Wait for a ping. + for { + if timedout.Load() { + return nil + } + if len(t.pinged) > 0 { + n := t.pinged[0] + t.pinged = append(t.pinged[:0], t.pinged[1:]...) + return n + } + t.cond.Wait() + } +} + // ping simulates a ping request. func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { t.mu.Lock() defer t.mu.Unlock() - t.pinged[n.ID()] = true + t.pinged = append(t.pinged, n) + t.cond.Broadcast() + if t.dead[n.ID()] { return 0, errTimeout } @@ -162,7 +208,7 @@ func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { return seq, nil } -// requestENR simulates an ENR request. +// RequestENR simulates an ENR request. 
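The new pingRecorder.waitPing uses a pattern worth calling out: sync.Cond has no timed wait, so a time.AfterFunc flips an atomic flag and broadcasts to wake the waiter when the timeout expires. The same mechanism in isolation, with a string queue standing in for the recorded pings:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// queue is a minimal condition-variable-backed queue with a timed wait.
type queue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items []string
}

func newQueue() *queue {
	q := &queue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *queue) put(s string) {
	q.mu.Lock()
	q.items = append(q.items, s)
	q.mu.Unlock()
	q.cond.Broadcast()
}

// waitOne blocks until an item arrives or the timeout expires. The timer
// callback sets the flag and broadcasts so the waiting goroutine re-checks.
func (q *queue) waitOne(timeout time.Duration) (string, bool) {
	q.mu.Lock()
	defer q.mu.Unlock()

	var timedOut atomic.Bool
	timer := time.AfterFunc(timeout, func() {
		timedOut.Store(true)
		q.cond.Broadcast()
	})
	defer timer.Stop()

	for {
		if timedOut.Load() {
			return "", false
		}
		if len(q.items) > 0 {
			item := q.items[0]
			q.items = q.items[1:]
			return item, true
		}
		q.cond.Wait()
	}
}

func main() {
	q := newQueue()
	go func() {
		time.Sleep(50 * time.Millisecond)
		q.put("ping")
	}()
	fmt.Println(q.waitOne(time.Second))           // ping true
	fmt.Println(q.waitOne(100 * time.Millisecond)) // "" false
}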
func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) { t.mu.Lock() defer t.mu.Unlock() @@ -174,7 +220,7 @@ func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) { } func hasDuplicates(slice []*node) bool { - seen := make(map[enode.ID]bool) + seen := make(map[enode.ID]bool, len(slice)) for i, e := range slice { if e == nil { panic(fmt.Sprintf("nil *Node at %d", i)) @@ -216,14 +262,14 @@ func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { } func sortByID(nodes []*enode.Node) { - sort.Slice(nodes, func(i, j int) bool { - return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) + slices.SortFunc(nodes, func(a, b *enode.Node) int { + return bytes.Compare(a.ID().Bytes(), b.ID().Bytes()) }) } func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { - return sort.SliceIsSorted(slice, func(i, j int) bool { - return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0 + return slices.IsSortedFunc(slice, func(a, b *node) int { + return enode.DistCmp(distbase, a.ID(), b.ID()) }) } @@ -252,3 +298,57 @@ func hexEncPubkey(h string) (ret encPubkey) { copy(ret[:], b) return ret } + +type nodeEventRecorder struct { + evc chan recordedNodeEvent +} + +type recordedNodeEvent struct { + node *node + added bool +} + +func newNodeEventRecorder(buffer int) *nodeEventRecorder { + return &nodeEventRecorder{ + evc: make(chan recordedNodeEvent, buffer), + } +} + +func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) { + select { + case set.evc <- recordedNodeEvent{n, true}: + default: + panic("no space in event buffer") + } +} + +func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *node) { + select { + case set.evc <- recordedNodeEvent{n, false}: + default: + panic("no space in event buffer") + } +} + +func (set *nodeEventRecorder) waitNodePresent(id enode.ID, timeout time.Duration) bool { + return set.waitNodeEvent(id, timeout, true) +} + +func (set *nodeEventRecorder) waitNodeAbsent(id enode.ID, timeout time.Duration) bool { + return set.waitNodeEvent(id, timeout, false) +} + +func (set *nodeEventRecorder) waitNodeEvent(id enode.ID, timeout time.Duration, added bool) bool { + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case ev := <-set.evc: + if ev.node.ID() == id && ev.added == added { + return true + } + case <-timer.C: + return false + } + } +} diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index a00de9ca1..7a04fa6ec 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -40,7 +40,7 @@ func TestUDPv4_Lookup(t *testing.T) { } // Seed table with initial node. - fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}) + fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -74,7 +74,7 @@ func TestUDPv4_LookupIterator(t *testing.T) { for i := range lookupTestnet.dists[256] { bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) } - fillTable(test.table, bootnodes) + fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) // Create the iterator and collect the nodes it yields. 
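The sorting helpers above move from sort.Slice, which takes a boolean "less" predicate, to slices.SortFunc / slices.IsSortedFunc, which take a three-way comparator. bytes.Compare and enode.DistCmp already return that form, so they plug in without a wrapper. A small standalone illustration of the same migration on raw byte IDs (not part of the diff):

package main

import (
	"bytes"
	"fmt"
	"slices"
	"sort"
)

func main() {
	ids := [][]byte{{0x0c}, {0x01}, {0x7f}, {0x03}}

	// Old style: sort.Slice takes a "less" predicate over indices.
	sort.Slice(ids, func(i, j int) bool {
		return bytes.Compare(ids[i], ids[j]) < 0
	})

	// New style: slices.SortFunc takes a three-way comparator directly.
	slices.SortFunc(ids, bytes.Compare)
	fmt.Println("sorted:   ", ids)
	fmt.Println("is sorted:", slices.IsSortedFunc(ids, bytes.Compare))
}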
@@ -109,7 +109,7 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { for i := range lookupTestnet.dists[256] { bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) } - fillTable(test.table, bootnodes) + fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) it := test.udp.RandomNodes() diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 3fb61fd83..1abe93999 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -142,7 +142,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, } - tab, err := newMeteredTable(t, ln.Database(), cfg.Bootnodes, t.log, cfg.FilterFunction) + tab, err := newMeteredTable(t, ln.Database(), cfg) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (t *UDPv4) NodesInDHT() [][]enode.Node { for i, bucket := range t.tab.buckets { nodes[i] = make([]enode.Node, len(bucket.entries)) for j, entry := range bucket.entries { - nodes[i][j] = entry.Node + nodes[i][j] = *entry.Node } } return nodes @@ -685,10 +685,10 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { t.sendPing(fromID, from, func() { - t.tab.addVerifiedNode(n) + t.tab.addInboundNode(n) }) } else { - t.tab.addVerifiedNode(n) + t.tab.addInboundNode(n) } // Update node database and endpoint predictor. diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 6a51fc563..ad48e0187 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -270,7 +270,7 @@ func TestUDPv4_findnode(t *testing.T) { } nodes.push(n, numCandidates) } - fillTable(test.table, nodes.entries) + fillTable(test.table, nodes.entries, false) // ensure there's a bond with the test node, // findnode won't be accepted otherwise. diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 4d88fb614..71b44ce3b 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -164,7 +164,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { closeCtx: closeCtx, cancelCloseCtx: cancelCloseCtx, } - tab, err := newMeteredTable(t, t.db, cfg.Bootnodes, cfg.Log, cfg.FilterFunction) + tab, err := newMeteredTable(t, t.db, cfg) if err != nil { return nil, err } @@ -652,7 +652,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { } if fromNode != nil { // Handshake succeeded, add to table. - t.tab.addSeenNode(wrapNode(fromNode)) + t.tab.addInboundNode(wrapNode(fromNode)) } if packet.Kind() != v5wire.WhoareyouPacket { // WHOAREYOU logged separately to report errors. diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 0290ab4e2..141eb343d 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -145,7 +145,7 @@ func TestUDPv5_unknownPacket(t *testing.T) { // Make node known. 
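In the handlePing hunk above, the sender is added to the table directly only when a recent bond (pong) is on record; otherwise the node pings back first and adds the sender from the callback once that exchange completes. A schematic version of that decision follows; the constant and helper names are illustrative, not the protocol's values.

package main

import (
	"fmt"
	"time"
)

// bondExpiration stands in for the discv4 bond window; illustrative value.
const bondExpiration = 24 * time.Hour

// addAfterBondCheck sketches the decision in handlePing: with a fresh bond
// the sender goes straight into the table, otherwise we ping it back and add
// it from the response callback.
func addAfterBondCheck(lastPong time.Time, sendPing func(onResponse func()), add func()) {
	if time.Since(lastPong) > bondExpiration {
		sendPing(add)
	} else {
		add()
	}
}

func main() {
	add := func() { fmt.Println("node added to table") }
	ping := func(onResponse func()) {
		fmt.Println("ping sent, waiting for pong")
		onResponse() // pretend the pong arrived
	}

	addAfterBondCheck(time.Now().Add(-time.Minute), ping, add) // fresh bond
	addAfterBondCheck(time.Time{}, ping, add)                  // no bond on record
}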
n := test.getNode(test.remotekey, test.remoteaddr).Node() - test.table.addSeenNodeSync(wrapNode(n)) + test.table.addFoundNode(wrapNode(n)) test.packetIn(&v5wire.Unknown{Nonce: nonce}) test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { @@ -163,9 +163,9 @@ func TestUDPv5_findnodeHandling(t *testing.T) { nodes253 := nodesAtDistance(test.table.self().ID(), 253, 10) nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4) nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10) - fillTable(test.table, wrapNodes(nodes253)) - fillTable(test.table, wrapNodes(nodes249)) - fillTable(test.table, wrapNodes(nodes248)) + fillTable(test.table, wrapNodes(nodes253), true) + fillTable(test.table, wrapNodes(nodes249), true) + fillTable(test.table, wrapNodes(nodes248), true) // Requesting with distance zero should return the node's own record. test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}}) @@ -539,7 +539,7 @@ func TestUDPv5_lookup(t *testing.T) { // Seed table with initial node. initialNode := lookupTestnet.node(256, 0) - fillTable(test.table, []*node{wrapNode(initialNode)}) + fillTable(test.table, []*node{wrapNode(initialNode)}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1)
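The v5 tests now seed the table with setLive=true because, after the refactor, only nodes that have passed at least one liveness check are handed out in FINDNODE responses and lookups. A toy version of that filtering rule is sketched below; the struct and field names mirror the ones used by fillTable above but are otherwise illustrative.

package main

import "fmt"

// testNode stands in for a table entry; field names are illustrative.
type testNode struct {
	id              string
	livenessChecks  uint
	isValidatedLive bool
}

// liveNodes keeps only entries that have answered at least one liveness
// check - the property the tests obtain by seeding nodes with setLive=true.
func liveNodes(entries []testNode) []testNode {
	var out []testNode
	for _, n := range entries {
		if n.livenessChecks > 0 && n.isValidatedLive {
			out = append(out, n)
		}
	}
	return out
}

func main() {
	bucket := []testNode{
		{id: "a", livenessChecks: 1, isValidatedLive: true},
		{id: "b"}, // never validated: not served to FINDNODE queries
	}
	for _, n := range liveNodes(bucket) {
		fmt.Println("served:", n.id)
	}
}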