Skip to content

Commit ca3d29f

Browse files
committed
fix(sampledconn): Correctly handle slow bytes and closed conns
1 parent 9024f8e commit ca3d29f

File tree

2 files changed

+109
-15
lines changed

2 files changed

+109
-15
lines changed

p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
manet "github.com/multiformats/go-multiaddr/net"
1111

1212
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
1314
)
1415

1516
func TestSampledConn(t *testing.T) {
@@ -76,3 +77,102 @@ func TestSampledConn(t *testing.T) {
7677
})
7778
}
7879
}
80+
81+
func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) {
82+
serverConnChan := make(chan manet.Conn, 1)
83+
84+
listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
85+
assert.NoError(t, err)
86+
defer listener.Close()
87+
88+
serverAddr := listener.Multiaddr()
89+
90+
// Server goroutine
91+
go func() {
92+
conn, err := listener.Accept()
93+
assert.NoError(t, err)
94+
serverConnChan <- conn
95+
}()
96+
97+
// Give the server a moment to start
98+
time.Sleep(100 * time.Millisecond)
99+
100+
// Create a TCP client
101+
clientConn, err = manet.Dial(serverAddr)
102+
assert.NoError(t, err)
103+
104+
return <-serverConnChan, clientConn
105+
}
106+
107+
func TestHandleNoBytes(t *testing.T) {
108+
serverConn, clientConn := spawnServerAndClientConn(t)
109+
defer clientConn.Close()
110+
111+
// Server goroutine
112+
go func() {
113+
serverConn.Close()
114+
}()
115+
_, _, err := PeekBytes(clientConn.(interface {
116+
manet.Conn
117+
syscall.Conn
118+
}))
119+
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
120+
}
121+
122+
func TestHandle1ByteAndClose(t *testing.T) {
123+
serverConn, clientConn := spawnServerAndClientConn(t)
124+
defer clientConn.Close()
125+
126+
// Server goroutine
127+
go func() {
128+
defer serverConn.Close()
129+
_, err := serverConn.Write([]byte("h"))
130+
assert.NoError(t, err)
131+
}()
132+
_, _, err := PeekBytes(clientConn.(interface {
133+
manet.Conn
134+
syscall.Conn
135+
}))
136+
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
137+
}
138+
139+
func TestSlowBytes(t *testing.T) {
140+
serverConn, clientConn := spawnServerAndClientConn(t)
141+
142+
interval := 100 * time.Millisecond
143+
144+
// Server goroutine
145+
go func() {
146+
defer serverConn.Close()
147+
148+
time.Sleep(interval)
149+
_, err := serverConn.Write([]byte("h"))
150+
assert.NoError(t, err)
151+
time.Sleep(interval)
152+
_, err = serverConn.Write([]byte("e"))
153+
assert.NoError(t, err)
154+
time.Sleep(interval)
155+
_, err = serverConn.Write([]byte("l"))
156+
assert.NoError(t, err)
157+
time.Sleep(interval)
158+
_, err = serverConn.Write([]byte("lo"))
159+
assert.NoError(t, err)
160+
}()
161+
162+
defer clientConn.Close()
163+
164+
err := clientConn.SetReadDeadline(time.Now().Add(interval * 10))
165+
require.NoError(t, err)
166+
167+
peeked, clientConn, err := PeekBytes(clientConn.(interface {
168+
manet.Conn
169+
syscall.Conn
170+
}))
171+
assert.NoError(t, err)
172+
assert.Equal(t, "hel", string(peeked[:]))
173+
174+
buf := make([]byte, 5)
175+
_, err = io.ReadFull(clientConn, buf)
176+
assert.NoError(t, err)
177+
assert.Equal(t, "hello", string(buf))
178+
}

p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package sampledconn
44

55
import (
66
"errors"
7+
"io"
78
"syscall"
89
)
910

@@ -15,27 +16,20 @@ func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
1516
return s, err
1617
}
1718

18-
readBytes := 0
1919
var readErr error
20+
var n int
2021
err = rawConn.Read(func(fd uintptr) bool {
21-
for readBytes < peekSize {
22-
var n int
23-
n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK)
24-
if errors.Is(readErr, syscall.EAGAIN) {
25-
return false
26-
}
27-
if readErr != nil {
28-
return true
29-
}
30-
readBytes += n
31-
}
32-
return true
22+
n, _, readErr = syscall.Recvfrom(int(fd), s[:], syscall.MSG_PEEK|syscall.MSG_WAITALL)
23+
return !errors.Is(readErr, syscall.EAGAIN)
3324
})
25+
if err != nil {
26+
return s, err
27+
}
3428
if readErr != nil {
3529
return s, readErr
3630
}
37-
if err != nil {
38-
return s, err
31+
if n < peekSize {
32+
return s, io.ErrUnexpectedEOF
3933
}
4034

4135
return s, nil

0 commit comments

Comments
 (0)