Skip to content

Commit 1b0c625

Browse files
committed
add new connpolicy which can validate upstream and downstream addresses
1 parent 6ac4f3c commit 1b0c625

File tree

4 files changed

+204
-48
lines changed

4 files changed

+204
-48
lines changed

policy.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,28 @@ import (
66
"strings"
77
)
88

9-
// PolicyFunc can be used to decide whether to trust the PROXY info based on
10-
// upstream/downstream IP. If set, the connecting addresses(remote and local)
11-
// are passed in as arguments.
9+
// PolicyFunc can be used to decide whether to trust the PROXY info from
10+
// upstream. If set, the connecting address is passed in as an argument.
1211
//
1312
// See below for the different policies.
1413
//
1514
// In case an error is returned the connection is denied.
16-
type PolicyFunc func(upstream net.Addr, downstream net.Addr) (Policy, error)
15+
type PolicyFunc func(upstream net.Addr) (Policy, error)
16+
17+
// ConnPolicyFunc can be used to decide whether to trust the PROXY info
18+
// based on connection policy options. If set, the connecting addresses
19+
// (remote and local) are passed in as argument.
20+
//
21+
// See below for the different policies.
22+
//
23+
// In case an error is returned the connection is denied.
24+
type ConnPolicyFunc func(connPolicyOptions ConnPolicyOptions) (Policy, error)
25+
26+
// ConnPolicyOptions contains the remote and local addresses of a connection.
27+
type ConnPolicyOptions struct {
28+
Upstream net.Addr
29+
Downstream net.Addr
30+
}
1731

1832
// Policy defines how a connection with a PROXY header address is treated.
1933
type Policy int
@@ -44,7 +58,7 @@ const (
4458
// Kubernetes pods local traffic. The def is a policy to use when an upstream
4559
// address doesn't match the skipHeaderCIDR.
4660
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
47-
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
61+
return func(upstream net.Addr) (Policy, error) {
4862
ip, err := ipFromAddr(upstream)
4963
if err != nil {
5064
return def, err
@@ -58,25 +72,6 @@ func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
5872
}
5973
}
6074

61-
// IgnoreProxyHeaderNotOnInterface retuns a PolicyFunc which can be used to
62-
// decide whether to use or ignore PROXY headers depending on the connection
63-
// being made on a specific interface. This policy can be used when the server
64-
// is bound to multiple interfaces but wants to allow on only one interface.
65-
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) PolicyFunc {
66-
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
67-
ip, err := ipFromAddr(downstream)
68-
if err != nil {
69-
return REJECT, err
70-
}
71-
72-
if allowedIP.Equal(ip) {
73-
return USE, nil
74-
}
75-
76-
return IGNORE, nil
77-
}
78-
}
79-
8075
// WithPolicy adds given policy to a connection when passed as option to NewConn()
8176
func WithPolicy(p Policy) func(*Conn) {
8277
return func(c *Conn) {
@@ -137,7 +132,7 @@ func MustStrictWhiteListPolicy(allowed []string) PolicyFunc {
137132
}
138133

139134
func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc {
140-
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
135+
return func(upstream net.Addr) (Policy, error) {
141136
upstreamIP, err := ipFromAddr(upstream)
142137
if err != nil {
143138
// something is wrong with the source IP, better reject the connection
@@ -190,3 +185,22 @@ func ipFromAddr(upstream net.Addr) (net.IP, error) {
190185

191186
return upstreamIP, nil
192187
}
188+
189+
// IgnoreProxyHeaderNotOnInterface retuns a ConnPolicyFunc which can be used to
190+
// decide whether to use or ignore PROXY headers depending on the connection
191+
// being made on a specific interface. This policy can be used when the server
192+
// is bound to multiple interfaces but wants to allow on only one interface.
193+
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc {
194+
return func(connOpts ConnPolicyOptions) (Policy, error) {
195+
ip, err := ipFromAddr(connOpts.Downstream)
196+
if err != nil {
197+
return REJECT, err
198+
}
199+
200+
if allowedIP.Equal(ip) {
201+
return USE, nil
202+
}
203+
204+
return IGNORE, nil
205+
}
206+
}

policy_test.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) {
2121

2222
for _, tc := range cases {
2323
t.Run(tc.name, func(t *testing.T) {
24-
_, err := tc.policy(failingAddr{}, nil)
24+
_, err := tc.policy(failingAddr{})
2525
if err == nil {
2626
t.Fatal("Expected error, got none")
2727
}
@@ -37,7 +37,7 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t
3737
t.Fatalf("err: %v", err)
3838
}
3939

40-
policy, err := p(upstream, nil)
40+
policy, err := p(upstream)
4141
if err != nil {
4242
t.Fatalf("err: %v", err)
4343
}
@@ -55,7 +55,7 @@ func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *test
5555
t.Fatalf("err: %v", err)
5656
}
5757

58-
policy, err := p(upstream, nil)
58+
policy, err := p(upstream)
5959
if err != nil {
6060
t.Fatalf("err: %v", err)
6161
}
@@ -81,7 +81,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) {
8181

8282
for _, tc := range cases {
8383
t.Run(tc.name, func(t *testing.T) {
84-
policy, err := tc.policy(upstream, nil)
84+
policy, err := tc.policy(upstream)
8585
if err != nil {
8686
t.Fatalf("err: %v", err)
8787
}
@@ -109,7 +109,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.
109109

110110
for _, tc := range cases {
111111
t.Run(tc.name, func(t *testing.T) {
112-
policy, err := tc.policy(upstream, nil)
112+
policy, err := tc.policy(upstream)
113113
if err != nil {
114114
t.Fatalf("err: %v", err)
115115
}
@@ -194,7 +194,7 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
194194
f := SkipProxyHeaderForCIDR(cidr, REJECT)
195195

196196
upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
197-
policy, err := f(upstream, nil)
197+
policy, err := f(upstream)
198198
if err != nil {
199199
t.Fatalf("err: %v", err)
200200
}
@@ -203,7 +203,7 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
203203
}
204204

205205
upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
206-
policy, err = f(upstream, nil)
206+
policy, err = f(upstream)
207207
if err != nil {
208208
t.Fatalf("err: %v", err)
209209
}
@@ -220,7 +220,7 @@ func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
220220

221221
var cases = []struct {
222222
name string
223-
policy PolicyFunc
223+
policy ConnPolicyFunc
224224
downstreamAddress net.Addr
225225
expectedPolicy Policy
226226
expectError bool
@@ -232,7 +232,9 @@ func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
232232

233233
for _, tc := range cases {
234234
t.Run(tc.name, func(t *testing.T) {
235-
policy, err := tc.policy(nil, tc.downstreamAddress)
235+
policy, err := tc.policy(ConnPolicyOptions{
236+
Downstream: tc.downstreamAddress,
237+
})
236238
if !tc.expectError && err != nil {
237239
t.Fatalf("err: %v", err)
238240
}

protocol.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ var DefaultReadHeaderTimeout = 10 * time.Second
2222
// connections in order to prevent blocking operations. If no ReadHeaderTimeout
2323
// is set, a default of 200ms will be used. This can be disabled by setting the
2424
// timeout to < 0.
25+
//
26+
// Only one of Policy or ConnPolicy should be provided. If both are provided then
27+
// a panic would occur during accept.
2528
type Listener struct {
2629
Listener net.Listener
2730
Policy PolicyFunc
31+
ConnPolicy ConnPolicyFunc
2832
ValidateHeader Validator
2933
ReadHeaderTimeout time.Duration
3034
}
@@ -67,8 +71,18 @@ func (p *Listener) Accept() (net.Conn, error) {
6771
}
6872

6973
proxyHeaderPolicy := USE
70-
if p.Policy != nil {
71-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr(), conn.LocalAddr())
74+
if p.Policy != nil && p.ConnPolicy != nil {
75+
panic("only one of policy or connpolicy must be provided.")
76+
}
77+
if p.Policy != nil || p.ConnPolicy != nil {
78+
if p.Policy != nil {
79+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
80+
} else {
81+
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
82+
Upstream: conn.RemoteAddr(),
83+
Downstream: conn.LocalAddr(),
84+
})
85+
}
7286
if err != nil {
7387
// can't decide the policy, we can't accept the connection
7488
conn.Close()

0 commit comments

Comments
 (0)