@@ -13,6 +13,8 @@ import (
13
13
"io"
14
14
"io/ioutil"
15
15
"net"
16
+ "net/http"
17
+ "sync/atomic"
16
18
"testing"
17
19
"time"
18
20
)
@@ -83,7 +85,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
83
85
start := time .Now ()
84
86
85
87
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
86
-
87
88
if err != nil {
88
89
t .Fatalf ("err: %v" , err )
89
90
}
@@ -138,7 +139,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
138
139
start := time .Now ()
139
140
140
141
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
141
-
142
142
if err != nil {
143
143
t .Fatalf ("err: %v" , err )
144
144
}
@@ -848,6 +848,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
848
848
t .Fatalf ("client error: %v" , err )
849
849
}
850
850
}
851
+
851
852
func TestIgnorePolicyIgnoresIpFromProxyHeader (t * testing.T ) {
852
853
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
853
854
if err != nil {
@@ -1275,6 +1276,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
1275
1276
}
1276
1277
}
1277
1278
1279
+ func Test_ConnectionHandlesInvalidUpstreamError (t * testing.T ) {
1280
+ l , err := net .Listen ("tcp" , "localhost:8080" )
1281
+ if err != nil {
1282
+ t .Fatalf ("error creating listener: %v" , err )
1283
+ }
1284
+
1285
+ var connectionCounter atomic.Int32
1286
+
1287
+ newLn := & Listener {
1288
+ Listener : l ,
1289
+ ConnPolicy : func (_ ConnPolicyOptions ) (Policy , error ) {
1290
+ // Return the invalid upstream error on the first call, the listener
1291
+ // should remain open and accepting.
1292
+ times := connectionCounter .Load ()
1293
+ if times == 0 {
1294
+ connectionCounter .Store (times + 1 )
1295
+ return REJECT , ErrInvalidUpstream
1296
+ }
1297
+
1298
+ return REJECT , ErrNoProxyProtocol
1299
+ },
1300
+ }
1301
+
1302
+ // Kick off the listener and return any error via the chanel.
1303
+ errCh := make (chan error )
1304
+ defer close (errCh )
1305
+ go func (t * testing.T ) {
1306
+ _ , err := newLn .Accept ()
1307
+ errCh <- err
1308
+ }(t )
1309
+
1310
+ // Make two calls to trigger the listener's accept, the first should experience
1311
+ // the ErrInvalidUpstream and keep the listener open, the second should experience
1312
+ // a different error which will cause the listener to close.
1313
+ _ , _ = http .Get ("http://localhost:8080" )
1314
+ // Wait a few seconds to ensure we didn't get anything back on our channel.
1315
+ select {
1316
+ case err := <- errCh :
1317
+ if err != nil {
1318
+ t .Fatalf ("invalid upstream shouldn't return an error: %v" , err )
1319
+ }
1320
+ case <- time .After (2 * time .Second ):
1321
+ // No error returned (as expected, we're still listening though)
1322
+ }
1323
+
1324
+ _ , _ = http .Get ("http://localhost:8080" )
1325
+ // Wait a few seconds before we fail the test as we should have received an
1326
+ // error that was not invalid upstream.
1327
+ select {
1328
+ case err := <- errCh :
1329
+ if err == nil {
1330
+ t .Fatalf ("errors other than invalid upstream should error" )
1331
+ }
1332
+ if ! errors .Is (ErrNoProxyProtocol , err ) {
1333
+ t .Fatalf ("unexpected error type: %v" , err )
1334
+ }
1335
+ case <- time .After (2 * time .Second ):
1336
+ t .Fatalf ("timed out waiting for listener" )
1337
+ }
1338
+ }
1339
+
1278
1340
type TestTLSServer struct {
1279
1341
Listener net.Listener
1280
1342
@@ -1483,9 +1545,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
1483
1545
b , err := ioutil .ReadAll (r )
1484
1546
return int64 (len (b )), err
1485
1547
}
1548
+
1486
1549
func (c * testConn ) Write (p []byte ) (int , error ) {
1487
1550
return len (p ), nil
1488
1551
}
1552
+
1489
1553
func (c * testConn ) Read (p []byte ) (int , error ) {
1490
1554
if c .reads == 0 {
1491
1555
return 0 , io .EOF
@@ -1534,7 +1598,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
1534
1598
}
1535
1599
1536
1600
func benchmarkTCPProxy (size int , b * testing.B ) {
1537
- //create and start the echo backend
1601
+ // create and start the echo backend
1538
1602
backend , err := net .Listen ("tcp" , "127.0.0.1:0" )
1539
1603
if err != nil {
1540
1604
b .Fatalf ("err: %v" , err )
@@ -1555,7 +1619,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1555
1619
}
1556
1620
}()
1557
1621
1558
- //start the proxyprotocol enabled tcp proxy
1622
+ // start the proxyprotocol enabled tcp proxy
1559
1623
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
1560
1624
if err != nil {
1561
1625
b .Fatalf ("err: %v" , err )
@@ -1604,7 +1668,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1604
1668
},
1605
1669
}
1606
1670
1607
- //now for the actual benchmark
1671
+ // now for the actual benchmark
1608
1672
b .ResetTimer ()
1609
1673
for n := 0 ; n < b .N ; n ++ {
1610
1674
conn , err := net .Dial ("tcp" , pl .Addr ().String ())
@@ -1615,16 +1679,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1615
1679
if _ , err := header .WriteTo (conn ); err != nil {
1616
1680
b .Fatalf ("err: %v" , err )
1617
1681
}
1618
- //send data
1682
+ // send data
1619
1683
go func () {
1620
1684
_ , err = conn .Write (data )
1621
1685
_ = conn .(* net.TCPConn ).CloseWrite ()
1622
1686
if err != nil {
1623
1687
panic (fmt .Sprintf ("Failed to write data: %v" , err ))
1624
1688
}
1625
-
1626
1689
}()
1627
- //receive data
1690
+ // receive data
1628
1691
n , err := io .Copy (ioutil .Discard , conn )
1629
1692
if n != int64 (len (data )) {
1630
1693
b .Fatalf ("Expected to receive %d bytes, got %d" , len (data ), n )
@@ -1639,24 +1702,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1639
1702
func BenchmarkTCPProxy16KB (b * testing.B ) {
1640
1703
benchmarkTCPProxy (16 * 1024 , b )
1641
1704
}
1705
+
1642
1706
func BenchmarkTCPProxy32KB (b * testing.B ) {
1643
1707
benchmarkTCPProxy (32 * 1024 , b )
1644
1708
}
1709
+
1645
1710
func BenchmarkTCPProxy64KB (b * testing.B ) {
1646
1711
benchmarkTCPProxy (64 * 1024 , b )
1647
1712
}
1713
+
1648
1714
func BenchmarkTCPProxy128KB (b * testing.B ) {
1649
1715
benchmarkTCPProxy (128 * 1024 , b )
1650
1716
}
1717
+
1651
1718
func BenchmarkTCPProxy256KB (b * testing.B ) {
1652
1719
benchmarkTCPProxy (256 * 1024 , b )
1653
1720
}
1721
+
1654
1722
func BenchmarkTCPProxy512KB (b * testing.B ) {
1655
1723
benchmarkTCPProxy (512 * 1024 , b )
1656
1724
}
1725
+
1657
1726
func BenchmarkTCPProxy1024KB (b * testing.B ) {
1658
1727
benchmarkTCPProxy (1024 * 1024 , b )
1659
1728
}
1729
+
1660
1730
func BenchmarkTCPProxy2048KB (b * testing.B ) {
1661
1731
benchmarkTCPProxy (2048 * 1024 , b )
1662
1732
}
0 commit comments