Skip to content

Commit 6ac4f3c

Browse files
committed
Add support for validating the downstream ip of the connection
1 parent 8a2480a commit 6ac4f3c

File tree

4 files changed

+81
-25
lines changed

4 files changed

+81
-25
lines changed

policy.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
"strings"
77
)
88

9-
// PolicyFunc can be used to decide whether to trust the PROXY info from
10-
// upstream. If set, the connecting address is passed in as an argument.
9+
// PolicyFunc can be used to decide whether to trust the PROXY info based on
10+
// upstream/downstream IP. If set, the connecting addresses(remote and local)
11+
// are passed in as arguments.
1112
//
1213
// See below for the different policies.
1314
//
1415
// In case an error is returned the connection is denied.
15-
type PolicyFunc func(upstream net.Addr) (Policy, error)
16+
type PolicyFunc func(upstream net.Addr, downstream net.Addr) (Policy, error)
1617

1718
// Policy defines how a connection with a PROXY header address is treated.
1819
type Policy int
@@ -43,7 +44,7 @@ const (
4344
// Kubernetes pods local traffic. The def is a policy to use when an upstream
4445
// address doesn't match the skipHeaderCIDR.
4546
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
46-
return func(upstream net.Addr) (Policy, error) {
47+
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
4748
ip, err := ipFromAddr(upstream)
4849
if err != nil {
4950
return def, err
@@ -57,6 +58,25 @@ func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
5758
}
5859
}
5960

61+
// IgnoreProxyHeaderNotOnInterface retuns a PolicyFunc which can be used to
62+
// decide whether to use or ignore PROXY headers depending on the connection
63+
// being made on a specific interface. This policy can be used when the server
64+
// is bound to multiple interfaces but wants to allow on only one interface.
65+
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) PolicyFunc {
66+
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
67+
ip, err := ipFromAddr(downstream)
68+
if err != nil {
69+
return REJECT, err
70+
}
71+
72+
if allowedIP.Equal(ip) {
73+
return USE, nil
74+
}
75+
76+
return IGNORE, nil
77+
}
78+
}
79+
6080
// WithPolicy adds given policy to a connection when passed as option to NewConn()
6181
func WithPolicy(p Policy) func(*Conn) {
6282
return func(c *Conn) {
@@ -117,7 +137,7 @@ func MustStrictWhiteListPolicy(allowed []string) PolicyFunc {
117137
}
118138

119139
func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc {
120-
return func(upstream net.Addr) (Policy, error) {
140+
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
121141
upstreamIP, err := ipFromAddr(upstream)
122142
if err != nil {
123143
// something is wrong with the source IP, better reject the connection

policy_test.go

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) {
2121

2222
for _, tc := range cases {
2323
t.Run(tc.name, func(t *testing.T) {
24-
_, err := tc.policy(failingAddr{})
24+
_, err := tc.policy(failingAddr{}, nil)
2525
if err == nil {
2626
t.Fatal("Expected error, got none")
2727
}
@@ -37,7 +37,7 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t
3737
t.Fatalf("err: %v", err)
3838
}
3939

40-
policy, err := p(upstream)
40+
policy, err := p(upstream, nil)
4141
if err != nil {
4242
t.Fatalf("err: %v", err)
4343
}
@@ -55,7 +55,7 @@ func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *test
5555
t.Fatalf("err: %v", err)
5656
}
5757

58-
policy, err := p(upstream)
58+
policy, err := p(upstream, nil)
5959
if err != nil {
6060
t.Fatalf("err: %v", err)
6161
}
@@ -81,7 +81,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) {
8181

8282
for _, tc := range cases {
8383
t.Run(tc.name, func(t *testing.T) {
84-
policy, err := tc.policy(upstream)
84+
policy, err := tc.policy(upstream, nil)
8585
if err != nil {
8686
t.Fatalf("err: %v", err)
8787
}
@@ -109,7 +109,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.
109109

110110
for _, tc := range cases {
111111
t.Run(tc.name, func(t *testing.T) {
112-
policy, err := tc.policy(upstream)
112+
policy, err := tc.policy(upstream, nil)
113113
if err != nil {
114114
t.Fatalf("err: %v", err)
115115
}
@@ -194,7 +194,7 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
194194
f := SkipProxyHeaderForCIDR(cidr, REJECT)
195195

196196
upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
197-
policy, err := f(upstream)
197+
policy, err := f(upstream, nil)
198198
if err != nil {
199199
t.Fatalf("err: %v", err)
200200
}
@@ -203,11 +203,47 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
203203
}
204204

205205
upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
206-
policy, err = f(upstream)
206+
policy, err = f(upstream, nil)
207207
if err != nil {
208208
t.Fatalf("err: %v", err)
209209
}
210210
if policy != REJECT {
211211
t.Errorf("Expected a REJECT policy for the %s address", upstream)
212212
}
213213
}
214+
215+
func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
216+
downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738")
217+
if err != nil {
218+
t.Fatalf("err: %v", err)
219+
}
220+
221+
var cases = []struct {
222+
name string
223+
policy PolicyFunc
224+
downstreamAddress net.Addr
225+
expectedPolicy Policy
226+
expectError bool
227+
}{
228+
{"ignore header for requests non on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false},
229+
{"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), downstream, USE, false},
230+
{"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true},
231+
}
232+
233+
for _, tc := range cases {
234+
t.Run(tc.name, func(t *testing.T) {
235+
policy, err := tc.policy(nil, tc.downstreamAddress)
236+
if !tc.expectError && err != nil {
237+
t.Fatalf("err: %v", err)
238+
}
239+
if tc.expectError && err == nil {
240+
t.Fatal("Expected error, got none")
241+
}
242+
243+
if policy != tc.expectedPolicy {
244+
t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy)
245+
}
246+
})
247+
}
248+
249+
}

protocol.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (p *Listener) Accept() (net.Conn, error) {
6868

6969
proxyHeaderPolicy := USE
7070
if p.Policy != nil {
71-
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
71+
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr(), conn.LocalAddr())
7272
if err != nil {
7373
// can't decide the policy, we can't accept the connection
7474
conn.Close()

protocol_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
9191
pl := &Listener{
9292
Listener: l,
9393
ReadHeaderTimeout: time.Millisecond * time.Duration(duration),
94-
Policy: func(upstream net.Addr) (Policy, error) {
94+
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
9595
return REQUIRE, nil
9696
},
9797
}
@@ -146,7 +146,7 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
146146
pl := &Listener{
147147
Listener: l,
148148
ReadHeaderTimeout: time.Millisecond * time.Duration(duration),
149-
Policy: func(upstream net.Addr) (Policy, error) {
149+
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
150150
return USE, nil
151151
},
152152
}
@@ -645,7 +645,7 @@ func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) {
645645
}
646646

647647
expectedErr := fmt.Errorf("failure")
648-
policyFunc := func(upstream net.Addr) (Policy, error) { return USE, expectedErr }
648+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return USE, expectedErr }
649649

650650
pl := &Listener{Listener: l, Policy: policyFunc}
651651

@@ -681,7 +681,7 @@ func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) {
681681
t.Fatalf("err: %v", err)
682682
}
683683

684-
policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
684+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }
685685

686686
pl := &Listener{Listener: l, Policy: policyFunc}
687687

@@ -724,7 +724,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
724724
t.Fatalf("err: %v", err)
725725
}
726726

727-
policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil }
727+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REJECT, nil }
728728

729729
pl := &Listener{Listener: l, Policy: policyFunc}
730730

@@ -778,7 +778,7 @@ func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
778778
t.Fatalf("err: %v", err)
779779
}
780780

781-
policyFunc := func(upstream net.Addr) (Policy, error) { return IGNORE, nil }
781+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return IGNORE, nil }
782782

783783
pl := &Listener{Listener: l, Policy: policyFunc}
784784

@@ -891,7 +891,7 @@ func TestReadingIsRefusedOnErrorWhenRemoteAddrRequestedFirst(t *testing.T) {
891891
t.Fatalf("err: %v", err)
892892
}
893893

894-
policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
894+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }
895895

896896
pl := &Listener{Listener: l, Policy: policyFunc}
897897

@@ -935,7 +935,7 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
935935
t.Fatalf("err: %v", err)
936936
}
937937

938-
policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
938+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }
939939

940940
pl := &Listener{Listener: l, Policy: policyFunc}
941941

@@ -979,7 +979,7 @@ func TestSkipProxyProtocolPolicy(t *testing.T) {
979979
t.Fatalf("err: %v", err)
980980
}
981981

982-
policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }
982+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return SKIP, nil }
983983

984984
pl := &Listener{
985985
Listener: l,
@@ -1036,7 +1036,7 @@ func Test_ConnectionCasts(t *testing.T) {
10361036
t.Fatalf("err: %v", err)
10371037
}
10381038

1039-
policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
1039+
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }
10401040

10411041
pl := &Listener{Listener: l, Policy: policyFunc}
10421042

@@ -1198,7 +1198,7 @@ func Test_TLSServer(t *testing.T) {
11981198
s := NewTestTLSServer(l)
11991199
s.Listener = &Listener{
12001200
Listener: s.Listener,
1201-
Policy: func(upstream net.Addr) (Policy, error) {
1201+
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
12021202
return REQUIRE, nil
12031203
},
12041204
}
@@ -1269,7 +1269,7 @@ func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) {
12691269
s := NewTestTLSServer(l)
12701270
s.Listener = &Listener{
12711271
Listener: s.Listener,
1272-
Policy: func(upstream net.Addr) (Policy, error) {
1272+
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
12731273
return REQUIRE, nil
12741274
},
12751275
}

0 commit comments

Comments
 (0)