Skip to content

Commit 956f8fe

Browse files
committed
keep listener trying to accept connections and don't error on invalid upstream
1 parent b718e7c commit 956f8fe

File tree

1 file changed

+59
-43
lines changed

1 file changed

+59
-43
lines changed

protocol.go

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,26 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"errors"
6+
"fmt"
57
"io"
68
"net"
79
"sync"
810
"sync/atomic"
911
"time"
1012
)
1113

12-
// DefaultReadHeaderTimeout is how long header processing waits for header to
13-
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
14-
// It's kept as a global variable so to make it easier to find and override,
15-
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
16-
var DefaultReadHeaderTimeout = 10 * time.Second
14+
var (
15+
// DefaultReadHeaderTimeout is how long header processing waits for header to
16+
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
17+
// It's kept as a global variable so to make it easier to find and override,
18+
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
19+
DefaultReadHeaderTimeout = 10 * time.Second
20+
21+
// ErrInvalidUpstream should be returned when an upstream connection address
22+
// is not trusted, and therefore is invalid.
23+
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
24+
)
1725

1826
// Listener is used to wrap an underlying listener,
1927
// whose connections may be using the HAProxy Proxy Protocol.
@@ -63,53 +71,61 @@ func ValidateHeader(v Validator) func(*Conn) {
6371
}
6472
}
6573

66-
// Accept waits for and returns the next connection to the listener.
74+
// Accept waits for and returns the next valid connection to the listener.
6775
func (p *Listener) Accept() (net.Conn, error) {
68-
// Get the underlying connection
69-
conn, err := p.Listener.Accept()
70-
if err != nil {
71-
return nil, err
72-
}
73-
74-
proxyHeaderPolicy := USE
75-
if p.Policy != nil && p.ConnPolicy != nil {
76-
panic("only one of policy or connpolicy must be provided.")
77-
}
78-
if p.Policy != nil || p.ConnPolicy != nil {
79-
if p.Policy != nil {
80-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
81-
} else {
82-
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
83-
Upstream: conn.RemoteAddr(),
84-
Downstream: conn.LocalAddr(),
85-
})
86-
}
76+
for {
77+
// Get the underlying connection
78+
conn, err := p.Listener.Accept()
8779
if err != nil {
88-
// can't decide the policy, we can't accept the connection
89-
conn.Close()
9080
return nil, err
9181
}
92-
// Handle a connection as a regular one
93-
if proxyHeaderPolicy == SKIP {
94-
return conn, nil
82+
83+
proxyHeaderPolicy := USE
84+
if p.Policy != nil && p.ConnPolicy != nil {
85+
panic("only one of policy or connpolicy must be provided.")
9586
}
96-
}
87+
if p.Policy != nil || p.ConnPolicy != nil {
88+
if p.Policy != nil {
89+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
90+
} else {
91+
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
92+
Upstream: conn.RemoteAddr(),
93+
Downstream: conn.LocalAddr(),
94+
})
95+
}
96+
if err != nil {
97+
// can't decide the policy, we can't accept the connection
98+
conn.Close()
9799

98-
newConn := NewConn(
99-
conn,
100-
WithPolicy(proxyHeaderPolicy),
101-
ValidateHeader(p.ValidateHeader),
102-
)
100+
if errors.Is(err, ErrInvalidUpstream) {
101+
// keep listening for other connections
102+
continue
103+
}
103104

104-
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
105-
if p.ReadHeaderTimeout == 0 {
106-
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
107-
}
105+
return nil, err
106+
}
107+
// Handle a connection as a regular one
108+
if proxyHeaderPolicy == SKIP {
109+
return conn, nil
110+
}
111+
}
108112

109-
// Set the readHeaderTimeout of the new conn to the value of the listener
110-
newConn.readHeaderTimeout = p.ReadHeaderTimeout
113+
newConn := NewConn(
114+
conn,
115+
WithPolicy(proxyHeaderPolicy),
116+
ValidateHeader(p.ValidateHeader),
117+
)
111118

112-
return newConn, nil
119+
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
120+
if p.ReadHeaderTimeout == 0 {
121+
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
122+
}
123+
124+
// Set the readHeaderTimeout of the new conn to the value of the listener
125+
newConn.readHeaderTimeout = p.ReadHeaderTimeout
126+
127+
return newConn, nil
128+
}
113129
}
114130

115131
// Close closes the underlying listener.

0 commit comments

Comments
 (0)