Skip to content

Commit 98ac070

Browse files
committed
policy: add REFUSE
in strict whitelist policies we want to refuse a connection from a not allowed upstream address whether the proxy header is set or not set. Before this change if the upstream address is not allowed: 1) if the policy returns REJECT, the connection is allowed if no proxy header is sent 2) if the policy returns REQUIRE, the connection is allowed if a proxy header is set, even if the upstream address is not allowed to set it. The new REFUSE policy can be returned for not allowed addresses so that the connection is always refused.
1 parent b718e7c commit 98ac070

File tree

4 files changed

+88
-77
lines changed

4 files changed

+88
-77
lines changed

policy.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ const (
5151
// Note: an example usage can be found in the SkipProxyHeaderForCIDR
5252
// function.
5353
SKIP
54+
// REFUSE is the same as REJECT if a proxy header is set and the same as
55+
// REQUIRE if a proxy header is not set.
56+
REFUSE
5457
)
5558

5659
// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
@@ -117,7 +120,7 @@ func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) {
117120
return nil, err
118121
}
119122

120-
return whitelistPolicy(allowFrom, REJECT), nil
123+
return whitelistPolicy(allowFrom, REFUSE), nil
121124
}
122125

123126
// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic

policy_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t
4242
t.Fatalf("err: %v", err)
4343
}
4444

45-
if policy != REJECT {
46-
t.Fatalf("Expected policy REJECT, got %v", policy)
45+
if policy != REFUSE {
46+
t.Fatalf("Expected policy REFUSE, got %v", policy)
4747
}
4848
}
4949

protocol.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ func (p *Conn) readHeader() error {
288288
// let's act as if there was no error when PROXY protocol is not present.
289289
if err == ErrNoProxyProtocol {
290290
// but not if it is required that the connection has one
291-
if p.ProxyHeaderPolicy == REQUIRE {
291+
if p.ProxyHeaderPolicy == REQUIRE || p.ProxyHeaderPolicy == REFUSE {
292292
return err
293293
}
294294

@@ -298,7 +298,7 @@ func (p *Conn) readHeader() error {
298298
// proxy protocol header was found
299299
if err == nil && header != nil {
300300
switch p.ProxyHeaderPolicy {
301-
case REJECT:
301+
case REJECT, REFUSE:
302302
// this connection is not allowed to send one
303303
return ErrSuperfluousProxyHeader
304304
case USE, REQUIRE:

protocol_test.go

Lines changed: 80 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -752,100 +752,108 @@ func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) {
752752
}
753753

754754
func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) {
755-
l, err := net.Listen("tcp", "127.0.0.1:0")
756-
if err != nil {
757-
t.Fatalf("err: %v", err)
755+
policyFuncs := []PolicyFunc{
756+
func(upstream net.Addr) (Policy, error) { return REQUIRE, nil },
757+
func(upstream net.Addr) (Policy, error) { return REFUSE, nil },
758758
}
759+
for _, policyFunc := range policyFuncs {
760+
l, err := net.Listen("tcp", "127.0.0.1:0")
761+
if err != nil {
762+
t.Fatalf("err: %v", err)
763+
}
759764

760-
policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
765+
pl := &Listener{Listener: l, Policy: policyFunc}
761766

762-
pl := &Listener{Listener: l, Policy: policyFunc}
767+
cliResult := make(chan error)
768+
go func() {
769+
conn, err := net.Dial("tcp", pl.Addr().String())
770+
if err != nil {
771+
cliResult <- err
772+
return
773+
}
774+
defer conn.Close()
763775

764-
cliResult := make(chan error)
765-
go func() {
766-
conn, err := net.Dial("tcp", pl.Addr().String())
776+
if _, err := conn.Write([]byte("ping")); err != nil {
777+
cliResult <- err
778+
return
779+
}
780+
781+
close(cliResult)
782+
}()
783+
784+
conn, err := pl.Accept()
767785
if err != nil {
768-
cliResult <- err
769-
return
786+
t.Fatalf("err: %v", err)
770787
}
771788
defer conn.Close()
772789

773-
if _, err := conn.Write([]byte("ping")); err != nil {
774-
cliResult <- err
775-
return
790+
recv := make([]byte, 4)
791+
if _, err = conn.Read(recv); err != ErrNoProxyProtocol {
792+
t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err)
793+
}
794+
err = <-cliResult
795+
if err != nil {
796+
t.Fatalf("client error: %v", err)
776797
}
777-
778-
close(cliResult)
779-
}()
780-
781-
conn, err := pl.Accept()
782-
if err != nil {
783-
t.Fatalf("err: %v", err)
784-
}
785-
defer conn.Close()
786-
787-
recv := make([]byte, 4)
788-
if _, err = conn.Read(recv); err != ErrNoProxyProtocol {
789-
t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err)
790-
}
791-
err = <-cliResult
792-
if err != nil {
793-
t.Fatalf("client error: %v", err)
794798
}
795799
}
796800

797801
func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
798-
l, err := net.Listen("tcp", "127.0.0.1:0")
799-
if err != nil {
800-
t.Fatalf("err: %v", err)
802+
policyFuncs := []PolicyFunc{
803+
func(upstream net.Addr) (Policy, error) { return REJECT, nil },
804+
func(upstream net.Addr) (Policy, error) { return REFUSE, nil },
801805
}
806+
for _, policyFunc := range policyFuncs {
807+
l, err := net.Listen("tcp", "127.0.0.1:0")
808+
if err != nil {
809+
t.Fatalf("err: %v", err)
810+
}
802811

803-
policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil }
812+
pl := &Listener{Listener: l, Policy: policyFunc}
804813

805-
pl := &Listener{Listener: l, Policy: policyFunc}
814+
cliResult := make(chan error)
815+
go func() {
816+
conn, err := net.Dial("tcp", pl.Addr().String())
817+
if err != nil {
818+
cliResult <- err
819+
return
820+
}
821+
defer conn.Close()
822+
header := &Header{
823+
Version: 2,
824+
Command: PROXY,
825+
TransportProtocol: TCPv4,
826+
SourceAddr: &net.TCPAddr{
827+
IP: net.ParseIP("10.1.1.1"),
828+
Port: 1000,
829+
},
830+
DestinationAddr: &net.TCPAddr{
831+
IP: net.ParseIP("20.2.2.2"),
832+
Port: 2000,
833+
},
834+
}
835+
if _, err := header.WriteTo(conn); err != nil {
836+
cliResult <- err
837+
return
838+
}
806839

807-
cliResult := make(chan error)
808-
go func() {
809-
conn, err := net.Dial("tcp", pl.Addr().String())
840+
close(cliResult)
841+
}()
842+
843+
conn, err := pl.Accept()
810844
if err != nil {
811-
cliResult <- err
812-
return
845+
t.Fatalf("err: %v", err)
813846
}
814847
defer conn.Close()
815-
header := &Header{
816-
Version: 2,
817-
Command: PROXY,
818-
TransportProtocol: TCPv4,
819-
SourceAddr: &net.TCPAddr{
820-
IP: net.ParseIP("10.1.1.1"),
821-
Port: 1000,
822-
},
823-
DestinationAddr: &net.TCPAddr{
824-
IP: net.ParseIP("20.2.2.2"),
825-
Port: 2000,
826-
},
848+
849+
recv := make([]byte, 4)
850+
if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader {
851+
t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err)
827852
}
828-
if _, err := header.WriteTo(conn); err != nil {
829-
cliResult <- err
830-
return
853+
err = <-cliResult
854+
if err != nil {
855+
t.Fatalf("client error: %v", err)
831856
}
832-
833-
close(cliResult)
834-
}()
835-
836-
conn, err := pl.Accept()
837-
if err != nil {
838-
t.Fatalf("err: %v", err)
839-
}
840-
defer conn.Close()
841-
842-
recv := make([]byte, 4)
843-
if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader {
844-
t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err)
845-
}
846-
err = <-cliResult
847-
if err != nil {
848-
t.Fatalf("client error: %v", err)
849857
}
850858
}
851859
func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {

0 commit comments

Comments
 (0)