Skip to content

Commit 255c95e

Browse files
committed
keep listener after erroring with invalid upstream
1 parent b718e7c commit 255c95e

File tree

2 files changed

+137
-51
lines changed

2 files changed

+137
-51
lines changed

protocol.go

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,26 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"errors"
6+
"fmt"
57
"io"
68
"net"
79
"sync"
810
"sync/atomic"
911
"time"
1012
)
1113

12-
// DefaultReadHeaderTimeout is how long header processing waits for header to
13-
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
14-
// It's kept as a global variable so to make it easier to find and override,
15-
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
16-
var DefaultReadHeaderTimeout = 10 * time.Second
14+
var (
15+
// DefaultReadHeaderTimeout is how long header processing waits for header to
16+
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
17+
// It's kept as a global variable so to make it easier to find and override,
18+
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
19+
DefaultReadHeaderTimeout = 10 * time.Second
20+
21+
// ErrInvalidUpstream should be returned when an upstream connection address
22+
// is not trusted, and therefore is invalid.
23+
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
24+
)
1725

1826
// Listener is used to wrap an underlying listener,
1927
// whose connections may be using the HAProxy Proxy Protocol.
@@ -63,53 +71,61 @@ func ValidateHeader(v Validator) func(*Conn) {
6371
}
6472
}
6573

66-
// Accept waits for and returns the next connection to the listener.
74+
// Accept waits for and returns the next valid connection to the listener.
6775
func (p *Listener) Accept() (net.Conn, error) {
68-
// Get the underlying connection
69-
conn, err := p.Listener.Accept()
70-
if err != nil {
71-
return nil, err
72-
}
73-
74-
proxyHeaderPolicy := USE
75-
if p.Policy != nil && p.ConnPolicy != nil {
76-
panic("only one of policy or connpolicy must be provided.")
77-
}
78-
if p.Policy != nil || p.ConnPolicy != nil {
79-
if p.Policy != nil {
80-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
81-
} else {
82-
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
83-
Upstream: conn.RemoteAddr(),
84-
Downstream: conn.LocalAddr(),
85-
})
86-
}
76+
for {
77+
// Get the underlying connection
78+
conn, err := p.Listener.Accept()
8779
if err != nil {
88-
// can't decide the policy, we can't accept the connection
89-
conn.Close()
9080
return nil, err
9181
}
92-
// Handle a connection as a regular one
93-
if proxyHeaderPolicy == SKIP {
94-
return conn, nil
82+
83+
proxyHeaderPolicy := USE
84+
if p.Policy != nil && p.ConnPolicy != nil {
85+
panic("only one of policy or connpolicy must be provided.")
9586
}
96-
}
87+
if p.Policy != nil || p.ConnPolicy != nil {
88+
if p.Policy != nil {
89+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
90+
} else {
91+
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
92+
Upstream: conn.RemoteAddr(),
93+
Downstream: conn.LocalAddr(),
94+
})
95+
}
96+
if err != nil {
97+
// can't decide the policy, we can't accept the connection
98+
conn.Close()
9799

98-
newConn := NewConn(
99-
conn,
100-
WithPolicy(proxyHeaderPolicy),
101-
ValidateHeader(p.ValidateHeader),
102-
)
100+
if errors.Is(err, ErrInvalidUpstream) {
101+
// keep listening for other connections
102+
continue
103+
}
103104

104-
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
105-
if p.ReadHeaderTimeout == 0 {
106-
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
107-
}
105+
return nil, err
106+
}
107+
// Handle a connection as a regular one
108+
if proxyHeaderPolicy == SKIP {
109+
return conn, nil
110+
}
111+
}
108112

109-
// Set the readHeaderTimeout of the new conn to the value of the listener
110-
newConn.readHeaderTimeout = p.ReadHeaderTimeout
113+
newConn := NewConn(
114+
conn,
115+
WithPolicy(proxyHeaderPolicy),
116+
ValidateHeader(p.ValidateHeader),
117+
)
111118

112-
return newConn, nil
119+
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
120+
if p.ReadHeaderTimeout == 0 {
121+
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
122+
}
123+
124+
// Set the readHeaderTimeout of the new conn to the value of the listener
125+
newConn.readHeaderTimeout = p.ReadHeaderTimeout
126+
127+
return newConn, nil
128+
}
113129
}
114130

115131
// Close closes the underlying listener.

protocol_test.go

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"io"
1414
"io/ioutil"
1515
"net"
16+
"net/http"
17+
"sync/atomic"
1618
"testing"
1719
"time"
1820
)
@@ -83,7 +85,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
8385
start := time.Now()
8486

8587
l, err := net.Listen("tcp", "127.0.0.1:0")
86-
8788
if err != nil {
8889
t.Fatalf("err: %v", err)
8990
}
@@ -138,7 +139,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
138139
start := time.Now()
139140

140141
l, err := net.Listen("tcp", "127.0.0.1:0")
141-
142142
if err != nil {
143143
t.Fatalf("err: %v", err)
144144
}
@@ -848,6 +848,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
848848
t.Fatalf("client error: %v", err)
849849
}
850850
}
851+
851852
func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
852853
l, err := net.Listen("tcp", "127.0.0.1:0")
853854
if err != nil {
@@ -1275,6 +1276,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
12751276
}
12761277
}
12771278

1279+
func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
1280+
l, err := net.Listen("tcp", "localhost:8080")
1281+
if err != nil {
1282+
t.Fatalf("error creating listener: %v", err)
1283+
}
1284+
1285+
var connectionCounter atomic.Int32
1286+
1287+
newLn := &Listener{
1288+
Listener: l,
1289+
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
1290+
// Return the invalid upstream error on the first call, the listener
1291+
// should remain open and accepting.
1292+
times := connectionCounter.Load()
1293+
if times == 0 {
1294+
connectionCounter.Store(times + 1)
1295+
return REJECT, ErrInvalidUpstream
1296+
}
1297+
1298+
return REJECT, ErrNoProxyProtocol
1299+
},
1300+
}
1301+
1302+
// Kick off the listener and return any error via the chanel.
1303+
errCh := make(chan error)
1304+
defer close(errCh)
1305+
go func(t *testing.T) {
1306+
_, err := newLn.Accept()
1307+
errCh <- err
1308+
}(t)
1309+
1310+
// Make two calls to trigger the listener's accept, the first should experience
1311+
// the ErrInvalidUpstream and keep the listener open, the second should experience
1312+
// a different error which will cause the listener to close.
1313+
_, _ = http.Get("http://localhost:8080")
1314+
// Wait a few seconds to ensure we didn't get anything back on our channel.
1315+
select {
1316+
case err := <-errCh:
1317+
if err != nil {
1318+
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
1319+
}
1320+
case <-time.After(2 * time.Second):
1321+
// No error returned (as expected, we're still listening though)
1322+
}
1323+
1324+
_, _ = http.Get("http://localhost:8080")
1325+
// Wait a few seconds before we fail the test as we should have received an
1326+
// error that was not invalid upstream.
1327+
select {
1328+
case err := <-errCh:
1329+
if err == nil {
1330+
t.Fatalf("errors other than invalid upstream should error")
1331+
}
1332+
if !errors.Is(ErrNoProxyProtocol, err) {
1333+
t.Fatalf("unexpected error type: %v", err)
1334+
}
1335+
case <-time.After(2 * time.Second):
1336+
t.Fatalf("timed out waiting for listener")
1337+
}
1338+
}
1339+
12781340
type TestTLSServer struct {
12791341
Listener net.Listener
12801342

@@ -1483,9 +1545,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
14831545
b, err := ioutil.ReadAll(r)
14841546
return int64(len(b)), err
14851547
}
1548+
14861549
func (c *testConn) Write(p []byte) (int, error) {
14871550
return len(p), nil
14881551
}
1552+
14891553
func (c *testConn) Read(p []byte) (int, error) {
14901554
if c.reads == 0 {
14911555
return 0, io.EOF
@@ -1534,7 +1598,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
15341598
}
15351599

15361600
func benchmarkTCPProxy(size int, b *testing.B) {
1537-
//create and start the echo backend
1601+
// create and start the echo backend
15381602
backend, err := net.Listen("tcp", "127.0.0.1:0")
15391603
if err != nil {
15401604
b.Fatalf("err: %v", err)
@@ -1555,7 +1619,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
15551619
}
15561620
}()
15571621

1558-
//start the proxyprotocol enabled tcp proxy
1622+
// start the proxyprotocol enabled tcp proxy
15591623
l, err := net.Listen("tcp", "127.0.0.1:0")
15601624
if err != nil {
15611625
b.Fatalf("err: %v", err)
@@ -1604,7 +1668,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16041668
},
16051669
}
16061670

1607-
//now for the actual benchmark
1671+
// now for the actual benchmark
16081672
b.ResetTimer()
16091673
for n := 0; n < b.N; n++ {
16101674
conn, err := net.Dial("tcp", pl.Addr().String())
@@ -1615,16 +1679,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16151679
if _, err := header.WriteTo(conn); err != nil {
16161680
b.Fatalf("err: %v", err)
16171681
}
1618-
//send data
1682+
// send data
16191683
go func() {
16201684
_, err = conn.Write(data)
16211685
_ = conn.(*net.TCPConn).CloseWrite()
16221686
if err != nil {
16231687
panic(fmt.Sprintf("Failed to write data: %v", err))
16241688
}
1625-
16261689
}()
1627-
//receive data
1690+
// receive data
16281691
n, err := io.Copy(ioutil.Discard, conn)
16291692
if n != int64(len(data)) {
16301693
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
@@ -1639,24 +1702,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16391702
func BenchmarkTCPProxy16KB(b *testing.B) {
16401703
benchmarkTCPProxy(16*1024, b)
16411704
}
1705+
16421706
func BenchmarkTCPProxy32KB(b *testing.B) {
16431707
benchmarkTCPProxy(32*1024, b)
16441708
}
1709+
16451710
func BenchmarkTCPProxy64KB(b *testing.B) {
16461711
benchmarkTCPProxy(64*1024, b)
16471712
}
1713+
16481714
func BenchmarkTCPProxy128KB(b *testing.B) {
16491715
benchmarkTCPProxy(128*1024, b)
16501716
}
1717+
16511718
func BenchmarkTCPProxy256KB(b *testing.B) {
16521719
benchmarkTCPProxy(256*1024, b)
16531720
}
1721+
16541722
func BenchmarkTCPProxy512KB(b *testing.B) {
16551723
benchmarkTCPProxy(512*1024, b)
16561724
}
1725+
16571726
func BenchmarkTCPProxy1024KB(b *testing.B) {
16581727
benchmarkTCPProxy(1024*1024, b)
16591728
}
1729+
16601730
func BenchmarkTCPProxy2048KB(b *testing.B) {
16611731
benchmarkTCPProxy(2048*1024, b)
16621732
}

0 commit comments

Comments
 (0)