Skip to content

Commit b323cec

Browse files
committed
keep listener after erroring with invalid upstream
1 parent 2df67b4 commit b323cec

File tree

2 files changed

+137
-51
lines changed

2 files changed

+137
-51
lines changed

protocol.go

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,26 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"errors"
6+
"fmt"
57
"io"
68
"net"
79
"sync"
810
"sync/atomic"
911
"time"
1012
)
1113

12-
// DefaultReadHeaderTimeout is how long header processing waits for header to
13-
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
14-
// It's kept as a global variable so to make it easier to find and override,
15-
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
16-
var DefaultReadHeaderTimeout = 10 * time.Second
14+
var (
15+
// DefaultReadHeaderTimeout is how long header processing waits for header to
16+
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
17+
// It's kept as a global variable so to make it easier to find and override,
18+
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
19+
DefaultReadHeaderTimeout = 10 * time.Second
20+
21+
// ErrInvalidUpstream should be returned when an upstream connection address
22+
// is not trusted, and therefore is invalid.
23+
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
24+
)
1725

1826
// Listener is used to wrap an underlying listener,
1927
// whose connections may be using the HAProxy Proxy Protocol.
@@ -73,53 +81,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
7381
}
7482
}
7583

76-
// Accept waits for and returns the next connection to the listener.
84+
// Accept waits for and returns the next valid connection to the listener.
7785
func (p *Listener) Accept() (net.Conn, error) {
78-
// Get the underlying connection
79-
conn, err := p.Listener.Accept()
80-
if err != nil {
81-
return nil, err
82-
}
83-
84-
proxyHeaderPolicy := USE
85-
if p.Policy != nil && p.ConnPolicy != nil {
86-
panic("only one of policy or connpolicy must be provided.")
87-
}
88-
if p.Policy != nil || p.ConnPolicy != nil {
89-
if p.Policy != nil {
90-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
91-
} else {
92-
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
93-
Upstream: conn.RemoteAddr(),
94-
Downstream: conn.LocalAddr(),
95-
})
96-
}
86+
for {
87+
// Get the underlying connection
88+
conn, err := p.Listener.Accept()
9789
if err != nil {
98-
// can't decide the policy, we can't accept the connection
99-
conn.Close()
10090
return nil, err
10191
}
102-
// Handle a connection as a regular one
103-
if proxyHeaderPolicy == SKIP {
104-
return conn, nil
92+
93+
proxyHeaderPolicy := USE
94+
if p.Policy != nil && p.ConnPolicy != nil {
95+
panic("only one of policy or connpolicy must be provided.")
10596
}
106-
}
97+
if p.Policy != nil || p.ConnPolicy != nil {
98+
if p.Policy != nil {
99+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
100+
} else {
101+
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
102+
Upstream: conn.RemoteAddr(),
103+
Downstream: conn.LocalAddr(),
104+
})
105+
}
106+
if err != nil {
107+
// can't decide the policy, we can't accept the connection
108+
conn.Close()
107109

108-
newConn := NewConn(
109-
conn,
110-
WithPolicy(proxyHeaderPolicy),
111-
ValidateHeader(p.ValidateHeader),
112-
)
110+
if errors.Is(err, ErrInvalidUpstream) {
111+
// keep listening for other connections
112+
continue
113+
}
113114

114-
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
115-
if p.ReadHeaderTimeout == 0 {
116-
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
117-
}
115+
return nil, err
116+
}
117+
// Handle a connection as a regular one
118+
if proxyHeaderPolicy == SKIP {
119+
return conn, nil
120+
}
121+
}
118122

119-
// Set the readHeaderTimeout of the new conn to the value of the listener
120-
newConn.readHeaderTimeout = p.ReadHeaderTimeout
123+
newConn := NewConn(
124+
conn,
125+
WithPolicy(proxyHeaderPolicy),
126+
ValidateHeader(p.ValidateHeader),
127+
)
121128

122-
return newConn, nil
129+
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
130+
if p.ReadHeaderTimeout == 0 {
131+
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
132+
}
133+
134+
// Set the readHeaderTimeout of the new conn to the value of the listener
135+
newConn.readHeaderTimeout = p.ReadHeaderTimeout
136+
137+
return newConn, nil
138+
}
123139
}
124140

125141
// Close closes the underlying listener.

protocol_test.go

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"fmt"
1313
"io"
1414
"net"
15+
"net/http"
16+
"sync/atomic"
1517
"testing"
1618
"time"
1719
)
@@ -82,7 +84,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
8284
start := time.Now()
8385

8486
l, err := net.Listen("tcp", "127.0.0.1:0")
85-
8687
if err != nil {
8788
t.Fatalf("err: %v", err)
8889
}
@@ -137,7 +138,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
137138
start := time.Now()
138139

139140
l, err := net.Listen("tcp", "127.0.0.1:0")
140-
141141
if err != nil {
142142
t.Fatalf("err: %v", err)
143143
}
@@ -847,6 +847,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
847847
t.Fatalf("client error: %v", err)
848848
}
849849
}
850+
850851
func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
851852
l, err := net.Listen("tcp", "127.0.0.1:0")
852853
if err != nil {
@@ -1274,6 +1275,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
12741275
}
12751276
}
12761277

1278+
func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
1279+
l, err := net.Listen("tcp", "localhost:8080")
1280+
if err != nil {
1281+
t.Fatalf("error creating listener: %v", err)
1282+
}
1283+
1284+
var connectionCounter atomic.Int32
1285+
1286+
newLn := &Listener{
1287+
Listener: l,
1288+
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
1289+
// Return the invalid upstream error on the first call, the listener
1290+
// should remain open and accepting.
1291+
times := connectionCounter.Load()
1292+
if times == 0 {
1293+
connectionCounter.Store(times + 1)
1294+
return REJECT, ErrInvalidUpstream
1295+
}
1296+
1297+
return REJECT, ErrNoProxyProtocol
1298+
},
1299+
}
1300+
1301+
// Kick off the listener and return any error via the chanel.
1302+
errCh := make(chan error)
1303+
defer close(errCh)
1304+
go func(t *testing.T) {
1305+
_, err := newLn.Accept()
1306+
errCh <- err
1307+
}(t)
1308+
1309+
// Make two calls to trigger the listener's accept, the first should experience
1310+
// the ErrInvalidUpstream and keep the listener open, the second should experience
1311+
// a different error which will cause the listener to close.
1312+
_, _ = http.Get("http://localhost:8080")
1313+
// Wait a few seconds to ensure we didn't get anything back on our channel.
1314+
select {
1315+
case err := <-errCh:
1316+
if err != nil {
1317+
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
1318+
}
1319+
case <-time.After(2 * time.Second):
1320+
// No error returned (as expected, we're still listening though)
1321+
}
1322+
1323+
_, _ = http.Get("http://localhost:8080")
1324+
// Wait a few seconds before we fail the test as we should have received an
1325+
// error that was not invalid upstream.
1326+
select {
1327+
case err := <-errCh:
1328+
if err == nil {
1329+
t.Fatalf("errors other than invalid upstream should error")
1330+
}
1331+
if !errors.Is(ErrNoProxyProtocol, err) {
1332+
t.Fatalf("unexpected error type: %v", err)
1333+
}
1334+
case <-time.After(2 * time.Second):
1335+
t.Fatalf("timed out waiting for listener")
1336+
}
1337+
}
1338+
12771339
type TestTLSServer struct {
12781340
Listener net.Listener
12791341

@@ -1482,9 +1544,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
14821544
b, err := io.ReadAll(r)
14831545
return int64(len(b)), err
14841546
}
1547+
14851548
func (c *testConn) Write(p []byte) (int, error) {
14861549
return len(p), nil
14871550
}
1551+
14881552
func (c *testConn) Read(p []byte) (int, error) {
14891553
if c.reads == 0 {
14901554
return 0, io.EOF
@@ -1533,7 +1597,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
15331597
}
15341598

15351599
func benchmarkTCPProxy(size int, b *testing.B) {
1536-
//create and start the echo backend
1600+
// create and start the echo backend
15371601
backend, err := net.Listen("tcp", "127.0.0.1:0")
15381602
if err != nil {
15391603
b.Fatalf("err: %v", err)
@@ -1554,7 +1618,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
15541618
}
15551619
}()
15561620

1557-
//start the proxyprotocol enabled tcp proxy
1621+
// start the proxyprotocol enabled tcp proxy
15581622
l, err := net.Listen("tcp", "127.0.0.1:0")
15591623
if err != nil {
15601624
b.Fatalf("err: %v", err)
@@ -1603,7 +1667,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16031667
},
16041668
}
16051669

1606-
//now for the actual benchmark
1670+
// now for the actual benchmark
16071671
b.ResetTimer()
16081672
for n := 0; n < b.N; n++ {
16091673
conn, err := net.Dial("tcp", pl.Addr().String())
@@ -1614,16 +1678,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16141678
if _, err := header.WriteTo(conn); err != nil {
16151679
b.Fatalf("err: %v", err)
16161680
}
1617-
//send data
1681+
// send data
16181682
go func() {
16191683
_, err = conn.Write(data)
16201684
_ = conn.(*net.TCPConn).CloseWrite()
16211685
if err != nil {
16221686
panic(fmt.Sprintf("Failed to write data: %v", err))
16231687
}
1624-
16251688
}()
1626-
//receive data
1689+
// receive data
16271690
n, err := io.Copy(io.Discard, conn)
16281691
if n != int64(len(data)) {
16291692
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
@@ -1638,24 +1701,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16381701
func BenchmarkTCPProxy16KB(b *testing.B) {
16391702
benchmarkTCPProxy(16*1024, b)
16401703
}
1704+
16411705
func BenchmarkTCPProxy32KB(b *testing.B) {
16421706
benchmarkTCPProxy(32*1024, b)
16431707
}
1708+
16441709
func BenchmarkTCPProxy64KB(b *testing.B) {
16451710
benchmarkTCPProxy(64*1024, b)
16461711
}
1712+
16471713
func BenchmarkTCPProxy128KB(b *testing.B) {
16481714
benchmarkTCPProxy(128*1024, b)
16491715
}
1716+
16501717
func BenchmarkTCPProxy256KB(b *testing.B) {
16511718
benchmarkTCPProxy(256*1024, b)
16521719
}
1720+
16531721
func BenchmarkTCPProxy512KB(b *testing.B) {
16541722
benchmarkTCPProxy(512*1024, b)
16551723
}
1724+
16561725
func BenchmarkTCPProxy1024KB(b *testing.B) {
16571726
benchmarkTCPProxy(1024*1024, b)
16581727
}
1728+
16591729
func BenchmarkTCPProxy2048KB(b *testing.B) {
16601730
benchmarkTCPProxy(2048*1024, b)
16611731
}

0 commit comments

Comments
 (0)