
Commit b918f42

Extract the messaging components from IpfsDHT into their own struct: create a new struct that manages sending DHT messages and can be used independently of the DHT.
1 parent 98c5089 commit b918f42
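
As the dht.go diff below shows, IpfsDHT now holds a *ProtocolMessenger built via NewProtocolMessenger(dht.host, dht.protocols, dht.Validator), and methods such as Ping delegate to it. The following is a minimal sketch of the "usable independently from the DHT" idea, assuming the constructor keeps that three-argument shape and that Ping is the messenger method the diff delegates to; the protocol ID, the validator value, and the pingWithoutDHT helper are placeholders, and the import paths simply mirror the ones in the diff.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-core/protocol"

	dht "github.com/libp2p/go-libp2p-kad-dht"
	record "github.com/libp2p/go-libp2p-record"
)

// pingWithoutDHT (hypothetical helper) sends a DHT ping to p using only a
// libp2p host and the new messenger, without constructing an IpfsDHT.
func pingWithoutDHT(ctx context.Context, h host.Host, p peer.ID) error {
	protocols := []protocol.ID{"/ipfs/kad/1.0.0"}                  // placeholder protocol list
	var validator record.Validator = record.NamespacedValidator{} // placeholder validator

	messenger := dht.NewProtocolMessenger(h, protocols, validator)

	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	if err := messenger.Ping(ctx, p); err != nil {
		return fmt.Errorf("DHT ping failed: %w", err)
	}
	fmt.Println("peer answered the DHT ping")
	return nil
}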

File tree: 9 files changed (+233, -168 lines)


dht.go

Lines changed: 3 additions & 84 deletions
@@ -1,7 +1,6 @@
 package dht
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -32,7 +31,6 @@ import (
 	goprocessctx "github.com/jbenet/goprocess/context"
 	"github.com/multiformats/go-base32"
 	ma "github.com/multiformats/go-multiaddr"
-	"github.com/multiformats/go-multihash"
 	"go.opencensus.io/tag"
 	"go.uber.org/zap"
 )
@@ -101,8 +99,7 @@ type IpfsDHT struct {
 	ctx  context.Context
 	proc goprocess.Process
 
-	strmap map[peer.ID]*messageSender
-	smlk   sync.Mutex
+	protoMessenger *ProtocolMessenger
 
 	plk sync.Mutex
 
@@ -183,6 +180,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error)
 	dht.enableValues = cfg.enableValues
 
 	dht.Validator = cfg.validator
+	dht.protoMessenger = NewProtocolMessenger(dht.host, dht.protocols, dht.Validator)
 
 	dht.auto = cfg.mode
 	switch cfg.mode {
@@ -273,7 +271,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) {
 		selfKey:       kb.ConvertPeerID(h.ID()),
 		peerstore:     h.Peerstore(),
 		host:          h,
-		strmap:        make(map[peer.ID]*messageSender),
 		birth:         time.Now(),
 		protocols:     protocols,
 		protocolsStrs: protocol.ConvertToStrings(protocols),
@@ -477,67 +474,8 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() {
 	}
 }
 
-// putValueToPeer stores the given key/value pair at the peer 'p'
-func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, rec *recpb.Record) error {
-	pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0)
-	pmes.Record = rec
-	rpmes, err := dht.sendRequest(ctx, p, pmes)
-	if err != nil {
-		logger.Debugw("failed to put value to peer", "to", p, "key", loggableKeyBytes(rec.Key), "error", err)
-		return err
-	}
-
-	if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) {
-		logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes)
-		return errors.New("value not put correctly")
-	}
-
-	return nil
-}
-
 var errInvalidRecord = errors.New("received invalid record")
 
-// getValueOrPeers queries a particular peer p for the value for
-// key. It returns either the value or a list of closer peers.
-// NOTE: It will update the dht's peerstore with any new addresses
-// it finds for the given peer.
-func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
-	pmes, err := dht.getValueSingle(ctx, p, key)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// Perhaps we were given closer peers
-	peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())
-
-	if record := pmes.GetRecord(); record != nil {
-		// Success! We were given the value
-		logger.Debug("got value")
-
-		// make sure record is valid.
-		err = dht.Validator.Validate(string(record.GetKey()), record.GetValue())
-		if err != nil {
-			logger.Debug("received invalid record (discarded)")
-			// return a sentinal to signify an invalid record was received
-			err = errInvalidRecord
-			record = new(recpb.Record)
-		}
-		return record, peers, err
-	}
-
-	if len(peers) > 0 {
-		return nil, peers, nil
-	}
-
-	return nil, nil, routing.ErrNotFound
-}
-
-// getValueSingle simply performs the get value RPC with the given parameters
-func (dht *IpfsDHT) getValueSingle(ctx context.Context, p peer.ID, key string) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
 // getLocal attempts to retrieve the value from the datastore
 func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
 	logger.Debugw("finding value in datastore", "key", loggableKeyString(key))
@@ -627,17 +565,6 @@ func (dht *IpfsDHT) FindLocal(id peer.ID) peer.AddrInfo {
 	}
 }
 
-// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is
-func (dht *IpfsDHT) findPeerSingle(ctx context.Context, p peer.ID, id peer.ID) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
-func (dht *IpfsDHT) findProvidersSingle(ctx context.Context, p peer.ID, key multihash.Multihash) (*pb.Message, error) {
-	pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0)
-	return dht.sendRequest(ctx, p, pmes)
-}
-
 // nearestPeersToQuery returns the routing tables closest peers.
 func (dht *IpfsDHT) nearestPeersToQuery(pmes *pb.Message, count int) []peer.ID {
 	closer := dht.routingTable.NearestPeers(kb.ConvertKey(string(pmes.GetKey())), count)
@@ -778,15 +705,7 @@ func (dht *IpfsDHT) Host() host.Host {
 
 // Ping sends a ping message to the passed peer and waits for a response.
 func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error {
-	req := pb.NewMessage(pb.Message_PING, nil, 0)
-	resp, err := dht.sendRequest(ctx, p, req)
-	if err != nil {
-		return fmt.Errorf("sending request: %w", err)
-	}
-	if resp.Type != pb.Message_PING {
-		return fmt.Errorf("got unexpected response type: %v", resp.Type)
-	}
-	return nil
+	return dht.protoMessenger.Ping(ctx, p)
 }
 
 // newContextWithLocalTags returns a new context.Context with the InstanceID and
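
The per-peer RPC helpers removed above (putValueToPeer, getValueOrPeers, getValueSingle, findPeerSingle, findProvidersSingle) were thin wrappers around sendRequest, which now lives on the messageManager defined in dht_net.go below. As an illustration only (not code from this commit), here is how one of them could be rebuilt on top of the extracted sending layer; getValueSingleViaManager is a hypothetical name.

package dht

import (
	"context"

	"github.com/libp2p/go-libp2p-core/peer"

	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
)

// getValueSingleViaManager performs a GET_VALUE RPC against peer p using the
// extracted message-sending layer instead of an IpfsDHT method (sketch only).
func getValueSingleViaManager(ctx context.Context, m *messageManager, p peer.ID, key string) (*pb.Message, error) {
	pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0)
	return m.sendRequest(ctx, p, pmes)
}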

dht_net.go

Lines changed: 50 additions & 22 deletions
@@ -9,8 +9,10 @@ import (
 	"time"
 
 	"github.com/libp2p/go-libp2p-core/helpers"
+	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/network"
 	"github.com/libp2p/go-libp2p-core/peer"
+	"github.com/libp2p/go-libp2p-core/protocol"
 
 	"github.com/libp2p/go-libp2p-kad-dht/metrics"
 	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
@@ -208,12 +210,38 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
 	}
 }
 
+type messageManager struct {
+	host      host.Host // the network services we need
+	strmap    map[peer.ID]*messageSender
+	smlk      sync.Mutex
+	protocols []protocol.ID
+}
+
+func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) {
+	m.smlk.Lock()
+	defer m.smlk.Unlock()
+	ms, ok := m.strmap[p]
+	if !ok {
+		return
+	}
+	delete(m.strmap, p)
+
+	// Do this asynchronously as ms.lk can block for a while.
+	go func() {
+		if err := ms.lk.Lock(ctx); err != nil {
+			return
+		}
+		defer ms.lk.Unlock()
+		ms.invalidate()
+	}()
+}
+
 // sendRequest sends out a request, but also makes sure to
 // measure the RTT for latency measurements.
-func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
+func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
 	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
 
-	ms, err := dht.messageSenderForPeer(ctx, p)
+	ms, err := m.messageSenderForPeer(ctx, p)
 	if err != nil {
 		stats.Record(ctx,
 			metrics.SentRequests.M(1),
@@ -240,15 +268,15 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
 		metrics.SentBytes.M(int64(pmes.Size())),
 		metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
 	)
-	dht.peerstore.RecordLatency(p, time.Since(start))
+	m.host.Peerstore().RecordLatency(p, time.Since(start))
 	return rpmes, nil
 }
 
 // sendMessage sends out a message
-func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
+func (m *messageManager) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
 	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
 
-	ms, err := dht.messageSenderForPeer(ctx, p)
+	ms, err := m.messageSenderForPeer(ctx, p)
 	if err != nil {
 		stats.Record(ctx,
 			metrics.SentMessages.M(1),
@@ -274,30 +302,30 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
 	return nil
 }
 
-func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
-	dht.smlk.Lock()
-	ms, ok := dht.strmap[p]
+func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
+	m.smlk.Lock()
+	ms, ok := m.strmap[p]
 	if ok {
-		dht.smlk.Unlock()
+		m.smlk.Unlock()
 		return ms, nil
 	}
-	ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
-	dht.strmap[p] = ms
-	dht.smlk.Unlock()
+	ms = &messageSender{p: p, m: m, lk: newCtxMutex()}
+	m.strmap[p] = ms
+	m.smlk.Unlock()
 
 	if err := ms.prepOrInvalidate(ctx); err != nil {
-		dht.smlk.Lock()
-		defer dht.smlk.Unlock()
+		m.smlk.Lock()
+		defer m.smlk.Unlock()
 
-		if msCur, ok := dht.strmap[p]; ok {
+		if msCur, ok := m.strmap[p]; ok {
 			// Changed. Use the new one, old one is invalid and
 			// not in the map so we can just throw it away.
 			if ms != msCur {
 				return msCur, nil
 			}
 			// Not changed, remove the now invalid stream from the
 			// map.
-			delete(dht.strmap, p)
+			delete(m.strmap, p)
 		}
 		// Invalid but not in map. Must have been removed by a disconnect.
 		return nil, err
@@ -307,11 +335,11 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
 }
 
 type messageSender struct {
-	s   network.Stream
-	r   msgio.ReadCloser
-	lk  ctxMutex
-	p   peer.ID
-	dht *IpfsDHT
+	s  network.Stream
+	r  msgio.ReadCloser
+	lk ctxMutex
+	p  peer.ID
+	m  *messageManager
 
 	invalid   bool
 	singleMes int
@@ -352,7 +380,7 @@ func (ms *messageSender) prep(ctx context.Context) error {
 	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
 	// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
 	// backwards compatibility reasons).
-	nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
+	nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...)
 	if err != nil {
 		return err
 	}
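
The new streamDisconnect has no call site in this excerpt; presumably the DHT's disconnect handling forwards peer-disconnect events to it so cached senders get invalidated. A sketch of one plausible hook follows, with onPeerDisconnected as a hypothetical helper rather than code from this commit.

package dht

import (
	"context"

	"github.com/libp2p/go-libp2p-core/network"
)

// onPeerDisconnected drops the cached message sender for the remote peer of a
// closed connection, so the next request opens a fresh stream (sketch only).
func onPeerDisconnected(ctx context.Context, m *messageManager, c network.Conn) {
	m.streamDisconnect(ctx, c.RemotePeer())
}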

dht_test.go

Lines changed: 4 additions & 4 deletions
@@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) {
 	defer dht.Close()
 
 	foo := peer.ID("asdasd")
-	_, err := dht.messageSenderForPeer(ctx, foo)
+	_, err := dht.protoMessenger.m.messageSenderForPeer(ctx, foo)
 	if err == nil {
 		t.Fatal("that shouldnt have succeeded")
 	}
 
-	dht.smlk.Lock()
-	mscnt := len(dht.strmap)
-	dht.smlk.Unlock()
+	dht.protoMessenger.m.smlk.Lock()
+	mscnt := len(dht.protoMessenger.m.strmap)
+	dht.protoMessenger.m.smlk.Unlock()
 
 	if mscnt > 0 {
 		t.Fatal("should have no message senders in map")