Skip to content

Commit 994b69e

Browse files
committed
fix: basichost: Use NegotiationTimeout as fallback timeout for NewStream (#3020)
1 parent 83d458c commit 994b69e

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

p2p/host/basic/basic_host.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,10 @@ type HostOpts struct {
122122
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
123123
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]
124124

125-
// NegotiationTimeout determines the read and write timeouts on streams.
126-
// If 0 or omitted, it will use DefaultNegotiationTimeout.
127-
// If below 0, timeouts on streams will be deactivated.
125+
// NegotiationTimeout determines the read and write timeouts when negotiating
126+
// protocols for streams. If 0 or omitted, it will use
127+
// DefaultNegotiationTimeout. If below 0, timeouts on streams will be
128+
// deactivated.
128129
NegotiationTimeout time.Duration
129130

130131
// AddrsFactory holds a function which can be used to override or filter the result of Addrs.
@@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
689690
// to create one. If ProtocolID is "", writes no header.
690691
// (Thread-safe)
691692
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
693+
if _, ok := ctx.Deadline(); !ok {
694+
if h.negtimeout > 0 {
695+
var cancel context.CancelFunc
696+
ctx, cancel = context.WithTimeout(ctx, h.negtimeout)
697+
defer cancel()
698+
}
699+
}
700+
692701
// If the caller wants to prevent the host from dialing, it should use the NoDial option.
693702
if nodial, _ := network.GetNoDial(ctx); !nodial {
694703
err := h.Connect(ctx, peer.AddrInfo{ID: p})

p2p/host/basic/basic_host_test.go

+54
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package basichost
22

33
import (
44
"context"
5+
"encoding/binary"
56
"fmt"
67
"io"
78
"reflect"
@@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) {
941942
})
942943
}
943944
}
945+
946+
func TestHostTimeoutNewStream(t *testing.T) {
947+
h1, err := NewHost(swarmt.GenSwarm(t), nil)
948+
require.NoError(t, err)
949+
h1.Start()
950+
defer h1.Close()
951+
952+
const proto = "/testing"
953+
h2 := swarmt.GenSwarm(t)
954+
955+
h2.SetStreamHandler(func(s network.Stream) {
956+
// First message is multistream header. Just echo it
957+
msHeader := []byte("\x19/multistream/1.0.0\n")
958+
_, err := s.Read(msHeader)
959+
assert.NoError(t, err)
960+
_, err = s.Write(msHeader)
961+
assert.NoError(t, err)
962+
963+
buf := make([]byte, 1024)
964+
n, err := s.Read(buf)
965+
assert.NoError(t, err)
966+
967+
msgLen, varintN := binary.Uvarint(buf[:n])
968+
buf = buf[varintN:]
969+
proto := buf[:int(msgLen)]
970+
if string(proto) == "/ipfs/id/1.0.0\n" {
971+
// Signal we don't support identify
972+
na := []byte("na\n")
973+
n := binary.PutUvarint(buf, uint64(len(na)))
974+
copy(buf[n:], na)
975+
976+
_, err = s.Write(buf[:int(n)+len(na)])
977+
assert.NoError(t, err)
978+
} else {
979+
// Stall
980+
time.Sleep(5 * time.Second)
981+
}
982+
t.Log("Resetting")
983+
s.Reset()
984+
})
985+
986+
err = h1.Connect(context.Background(), peer.AddrInfo{
987+
ID: h2.LocalPeer(),
988+
Addrs: h2.ListenAddresses(),
989+
})
990+
require.NoError(t, err)
991+
992+
// No context passed in, fallback to negtimeout
993+
h1.negtimeout = time.Second
994+
_, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto)
995+
require.Error(t, err)
996+
require.ErrorContains(t, err, "context deadline exceeded")
997+
}

0 commit comments

Comments
 (0)