Skip to content

Commit 10741be

Browse files
committed
keep listener after erroring with invalid upstream
1 parent 9814f02 commit 10741be

File tree

2 files changed

+137
-51
lines changed

2 files changed

+137
-51
lines changed

protocol.go

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,26 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"errors"
6+
"fmt"
57
"io"
68
"net"
79
"sync"
810
"sync/atomic"
911
"time"
1012
)
1113

12-
// DefaultReadHeaderTimeout is how long header processing waits for header to
13-
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
14-
// It's kept as a global variable so to make it easier to find and override,
15-
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
16-
var DefaultReadHeaderTimeout = 10 * time.Second
14+
var (
15+
// DefaultReadHeaderTimeout is how long header processing waits for header to
16+
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
17+
// It's kept as a global variable so to make it easier to find and override,
18+
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
19+
DefaultReadHeaderTimeout = 10 * time.Second
20+
21+
// ErrInvalidUpstream should be returned when an upstream connection address
22+
// is not trusted, and therefore is invalid.
23+
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
24+
)
1725

1826
// Listener is used to wrap an underlying listener,
1927
// whose connections may be using the HAProxy Proxy Protocol.
@@ -72,53 +80,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
7280
}
7381
}
7482

75-
// Accept waits for and returns the next connection to the listener.
83+
// Accept waits for and returns the next valid connection to the listener.
7684
func (p *Listener) Accept() (net.Conn, error) {
77-
// Get the underlying connection
78-
conn, err := p.Listener.Accept()
79-
if err != nil {
80-
return nil, err
81-
}
82-
83-
proxyHeaderPolicy := USE
84-
if p.Policy != nil && p.ConnPolicy != nil {
85-
panic("only one of policy or connpolicy must be provided.")
86-
}
87-
if p.Policy != nil || p.ConnPolicy != nil {
88-
if p.Policy != nil {
89-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
90-
} else {
91-
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
92-
Upstream: conn.RemoteAddr(),
93-
Downstream: conn.LocalAddr(),
94-
})
95-
}
85+
for {
86+
// Get the underlying connection
87+
conn, err := p.Listener.Accept()
9688
if err != nil {
97-
// can't decide the policy, we can't accept the connection
98-
conn.Close()
9989
return nil, err
10090
}
101-
// Handle a connection as a regular one
102-
if proxyHeaderPolicy == SKIP {
103-
return conn, nil
91+
92+
proxyHeaderPolicy := USE
93+
if p.Policy != nil && p.ConnPolicy != nil {
94+
panic("only one of policy or connpolicy must be provided.")
10495
}
105-
}
96+
if p.Policy != nil || p.ConnPolicy != nil {
97+
if p.Policy != nil {
98+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
99+
} else {
100+
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
101+
Upstream: conn.RemoteAddr(),
102+
Downstream: conn.LocalAddr(),
103+
})
104+
}
105+
if err != nil {
106+
// can't decide the policy, we can't accept the connection
107+
conn.Close()
106108

107-
newConn := NewConn(
108-
conn,
109-
WithPolicy(proxyHeaderPolicy),
110-
ValidateHeader(p.ValidateHeader),
111-
)
109+
if errors.Is(err, ErrInvalidUpstream) {
110+
// keep listening for other connections
111+
continue
112+
}
112113

113-
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
114-
if p.ReadHeaderTimeout == 0 {
115-
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
116-
}
114+
return nil, err
115+
}
116+
// Handle a connection as a regular one
117+
if proxyHeaderPolicy == SKIP {
118+
return conn, nil
119+
}
120+
}
117121

118-
// Set the readHeaderTimeout of the new conn to the value of the listener
119-
newConn.readHeaderTimeout = p.ReadHeaderTimeout
122+
newConn := NewConn(
123+
conn,
124+
WithPolicy(proxyHeaderPolicy),
125+
ValidateHeader(p.ValidateHeader),
126+
)
120127

121-
return newConn, nil
128+
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
129+
if p.ReadHeaderTimeout == 0 {
130+
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
131+
}
132+
133+
// Set the readHeaderTimeout of the new conn to the value of the listener
134+
newConn.readHeaderTimeout = p.ReadHeaderTimeout
135+
136+
return newConn, nil
137+
}
122138
}
123139

124140
// Close closes the underlying listener.

protocol_test.go

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"io"
1414
"io/ioutil"
1515
"net"
16+
"net/http"
17+
"sync/atomic"
1618
"testing"
1719
"time"
1820
)
@@ -83,7 +85,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
8385
start := time.Now()
8486

8587
l, err := net.Listen("tcp", "127.0.0.1:0")
86-
8788
if err != nil {
8889
t.Fatalf("err: %v", err)
8990
}
@@ -138,7 +139,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
138139
start := time.Now()
139140

140141
l, err := net.Listen("tcp", "127.0.0.1:0")
141-
142142
if err != nil {
143143
t.Fatalf("err: %v", err)
144144
}
@@ -848,6 +848,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
848848
t.Fatalf("client error: %v", err)
849849
}
850850
}
851+
851852
func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
852853
l, err := net.Listen("tcp", "127.0.0.1:0")
853854
if err != nil {
@@ -1275,6 +1276,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
12751276
}
12761277
}
12771278

1279+
func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
1280+
l, err := net.Listen("tcp", "localhost:8080")
1281+
if err != nil {
1282+
t.Fatalf("error creating listener: %v", err)
1283+
}
1284+
1285+
var connectionCounter atomic.Int32
1286+
1287+
newLn := &Listener{
1288+
Listener: l,
1289+
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
1290+
// Return the invalid upstream error on the first call, the listener
1291+
// should remain open and accepting.
1292+
times := connectionCounter.Load()
1293+
if times == 0 {
1294+
connectionCounter.Store(times + 1)
1295+
return REJECT, ErrInvalidUpstream
1296+
}
1297+
1298+
return REJECT, ErrNoProxyProtocol
1299+
},
1300+
}
1301+
1302+
// Kick off the listener and return any error via the chanel.
1303+
errCh := make(chan error)
1304+
defer close(errCh)
1305+
go func(t *testing.T) {
1306+
_, err := newLn.Accept()
1307+
errCh <- err
1308+
}(t)
1309+
1310+
// Make two calls to trigger the listener's accept, the first should experience
1311+
// the ErrInvalidUpstream and keep the listener open, the second should experience
1312+
// a different error which will cause the listener to close.
1313+
_, _ = http.Get("http://localhost:8080")
1314+
// Wait a few seconds to ensure we didn't get anything back on our channel.
1315+
select {
1316+
case err := <-errCh:
1317+
if err != nil {
1318+
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
1319+
}
1320+
case <-time.After(2 * time.Second):
1321+
// No error returned (as expected, we're still listening though)
1322+
}
1323+
1324+
_, _ = http.Get("http://localhost:8080")
1325+
// Wait a few seconds before we fail the test as we should have received an
1326+
// error that was not invalid upstream.
1327+
select {
1328+
case err := <-errCh:
1329+
if err == nil {
1330+
t.Fatalf("errors other than invalid upstream should error")
1331+
}
1332+
if !errors.Is(ErrNoProxyProtocol, err) {
1333+
t.Fatalf("unexpected error type: %v", err)
1334+
}
1335+
case <-time.After(2 * time.Second):
1336+
t.Fatalf("timed out waiting for listener")
1337+
}
1338+
}
1339+
12781340
type TestTLSServer struct {
12791341
Listener net.Listener
12801342

@@ -1483,9 +1545,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
14831545
b, err := ioutil.ReadAll(r)
14841546
return int64(len(b)), err
14851547
}
1548+
14861549
func (c *testConn) Write(p []byte) (int, error) {
14871550
return len(p), nil
14881551
}
1552+
14891553
func (c *testConn) Read(p []byte) (int, error) {
14901554
if c.reads == 0 {
14911555
return 0, io.EOF
@@ -1534,7 +1598,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
15341598
}
15351599

15361600
func benchmarkTCPProxy(size int, b *testing.B) {
1537-
//create and start the echo backend
1601+
// create and start the echo backend
15381602
backend, err := net.Listen("tcp", "127.0.0.1:0")
15391603
if err != nil {
15401604
b.Fatalf("err: %v", err)
@@ -1555,7 +1619,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
15551619
}
15561620
}()
15571621

1558-
//start the proxyprotocol enabled tcp proxy
1622+
// start the proxyprotocol enabled tcp proxy
15591623
l, err := net.Listen("tcp", "127.0.0.1:0")
15601624
if err != nil {
15611625
b.Fatalf("err: %v", err)
@@ -1604,7 +1668,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16041668
},
16051669
}
16061670

1607-
//now for the actual benchmark
1671+
// now for the actual benchmark
16081672
b.ResetTimer()
16091673
for n := 0; n < b.N; n++ {
16101674
conn, err := net.Dial("tcp", pl.Addr().String())
@@ -1615,16 +1679,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16151679
if _, err := header.WriteTo(conn); err != nil {
16161680
b.Fatalf("err: %v", err)
16171681
}
1618-
//send data
1682+
// send data
16191683
go func() {
16201684
_, err = conn.Write(data)
16211685
_ = conn.(*net.TCPConn).CloseWrite()
16221686
if err != nil {
16231687
panic(fmt.Sprintf("Failed to write data: %v", err))
16241688
}
1625-
16261689
}()
1627-
//receive data
1690+
// receive data
16281691
n, err := io.Copy(ioutil.Discard, conn)
16291692
if n != int64(len(data)) {
16301693
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
@@ -1639,24 +1702,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16391702
func BenchmarkTCPProxy16KB(b *testing.B) {
16401703
benchmarkTCPProxy(16*1024, b)
16411704
}
1705+
16421706
func BenchmarkTCPProxy32KB(b *testing.B) {
16431707
benchmarkTCPProxy(32*1024, b)
16441708
}
1709+
16451710
func BenchmarkTCPProxy64KB(b *testing.B) {
16461711
benchmarkTCPProxy(64*1024, b)
16471712
}
1713+
16481714
func BenchmarkTCPProxy128KB(b *testing.B) {
16491715
benchmarkTCPProxy(128*1024, b)
16501716
}
1717+
16511718
func BenchmarkTCPProxy256KB(b *testing.B) {
16521719
benchmarkTCPProxy(256*1024, b)
16531720
}
1721+
16541722
func BenchmarkTCPProxy512KB(b *testing.B) {
16551723
benchmarkTCPProxy(512*1024, b)
16561724
}
1725+
16571726
func BenchmarkTCPProxy1024KB(b *testing.B) {
16581727
benchmarkTCPProxy(1024*1024, b)
16591728
}
1729+
16601730
func BenchmarkTCPProxy2048KB(b *testing.B) {
16611731
benchmarkTCPProxy(2048*1024, b)
16621732
}

0 commit comments

Comments
 (0)