@@ -2,18 +2,26 @@ package proxyproto
2
2
3
3
import (
4
4
"bufio"
5
+ "errors"
6
+ "fmt"
5
7
"io"
6
8
"net"
7
9
"sync"
8
10
"sync/atomic"
9
11
"time"
10
12
)
11
13
12
- // DefaultReadHeaderTimeout is how long header processing waits for header to
13
- // be read from the wire, if Listener.ReaderHeaderTimeout is not set.
14
- // It's kept as a global variable so to make it easier to find and override,
15
- // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
16
- var DefaultReadHeaderTimeout = 10 * time .Second
14
+ var (
15
+ // DefaultReadHeaderTimeout is how long header processing waits for header to
16
+ // be read from the wire, if Listener.ReaderHeaderTimeout is not set.
17
+ // It's kept as a global variable so to make it easier to find and override,
18
+ // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
19
+ DefaultReadHeaderTimeout = 10 * time .Second
20
+
21
+ // ErrInvalidUpstream should be returned when an upstream connection address
22
+ // is not trusted, and therefore is invalid.
23
+ ErrInvalidUpstream = fmt .Errorf ("proxyproto: upstream connection address not trusted for PROXY information" )
24
+ )
17
25
18
26
// Listener is used to wrap an underlying listener,
19
27
// whose connections may be using the HAProxy Proxy Protocol.
@@ -63,53 +71,61 @@ func ValidateHeader(v Validator) func(*Conn) {
63
71
}
64
72
}
65
73
66
- // Accept waits for and returns the next connection to the listener.
74
+ // Accept waits for and returns the next valid connection to the listener.
67
75
func (p * Listener ) Accept () (net.Conn , error ) {
68
- // Get the underlying connection
69
- conn , err := p .Listener .Accept ()
70
- if err != nil {
71
- return nil , err
72
- }
73
-
74
- proxyHeaderPolicy := USE
75
- if p .Policy != nil && p .ConnPolicy != nil {
76
- panic ("only one of policy or connpolicy must be provided." )
77
- }
78
- if p .Policy != nil || p .ConnPolicy != nil {
79
- if p .Policy != nil {
80
- proxyHeaderPolicy , err = p .Policy (conn .RemoteAddr ())
81
- } else {
82
- proxyHeaderPolicy , err = p .ConnPolicy (ConnPolicyOptions {
83
- Upstream : conn .RemoteAddr (),
84
- Downstream : conn .LocalAddr (),
85
- })
86
- }
76
+ for {
77
+ // Get the underlying connection
78
+ conn , err := p .Listener .Accept ()
87
79
if err != nil {
88
- // can't decide the policy, we can't accept the connection
89
- conn .Close ()
90
80
return nil , err
91
81
}
92
- // Handle a connection as a regular one
93
- if proxyHeaderPolicy == SKIP {
94
- return conn , nil
82
+
83
+ proxyHeaderPolicy := USE
84
+ if p .Policy != nil && p .ConnPolicy != nil {
85
+ panic ("only one of policy or connpolicy must be provided." )
95
86
}
96
- }
87
+ if p .Policy != nil || p .ConnPolicy != nil {
88
+ if p .Policy != nil {
89
+ proxyHeaderPolicy , err = p .Policy (conn .RemoteAddr ())
90
+ } else {
91
+ proxyHeaderPolicy , err = p .ConnPolicy (ConnPolicyOptions {
92
+ Upstream : conn .RemoteAddr (),
93
+ Downstream : conn .LocalAddr (),
94
+ })
95
+ }
96
+ if err != nil {
97
+ // can't decide the policy, we can't accept the connection
98
+ conn .Close ()
97
99
98
- newConn := NewConn (
99
- conn ,
100
- WithPolicy (proxyHeaderPolicy ),
101
- ValidateHeader (p .ValidateHeader ),
102
- )
100
+ if errors .Is (err , ErrInvalidUpstream ) {
101
+ // keep listening for other connections
102
+ continue
103
+ }
103
104
104
- // If the ReadHeaderTimeout for the listener is unset, use the default timeout.
105
- if p .ReadHeaderTimeout == 0 {
106
- p .ReadHeaderTimeout = DefaultReadHeaderTimeout
107
- }
105
+ return nil , err
106
+ }
107
+ // Handle a connection as a regular one
108
+ if proxyHeaderPolicy == SKIP {
109
+ return conn , nil
110
+ }
111
+ }
108
112
109
- // Set the readHeaderTimeout of the new conn to the value of the listener
110
- newConn .readHeaderTimeout = p .ReadHeaderTimeout
113
+ newConn := NewConn (
114
+ conn ,
115
+ WithPolicy (proxyHeaderPolicy ),
116
+ ValidateHeader (p .ValidateHeader ),
117
+ )
111
118
112
- return newConn , nil
119
+ // If the ReadHeaderTimeout for the listener is unset, use the default timeout.
120
+ if p .ReadHeaderTimeout == 0 {
121
+ p .ReadHeaderTimeout = DefaultReadHeaderTimeout
122
+ }
123
+
124
+ // Set the readHeaderTimeout of the new conn to the value of the listener
125
+ newConn .readHeaderTimeout = p .ReadHeaderTimeout
126
+
127
+ return newConn , nil
128
+ }
113
129
}
114
130
115
131
// Close closes the underlying listener.
0 commit comments