@@ -12,6 +12,8 @@ import (
12
12
"fmt"
13
13
"io"
14
14
"net"
15
+ "net/http"
16
+ "sync/atomic"
15
17
"testing"
16
18
"time"
17
19
)
@@ -82,7 +84,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
82
84
start := time .Now ()
83
85
84
86
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
85
-
86
87
if err != nil {
87
88
t .Fatalf ("err: %v" , err )
88
89
}
@@ -137,7 +138,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
137
138
start := time .Now ()
138
139
139
140
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
140
-
141
141
if err != nil {
142
142
t .Fatalf ("err: %v" , err )
143
143
}
@@ -847,6 +847,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
847
847
t .Fatalf ("client error: %v" , err )
848
848
}
849
849
}
850
+
850
851
func TestIgnorePolicyIgnoresIpFromProxyHeader (t * testing.T ) {
851
852
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
852
853
if err != nil {
@@ -1274,6 +1275,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
1274
1275
}
1275
1276
}
1276
1277
1278
+ func Test_ConnectionHandlesInvalidUpstreamError (t * testing.T ) {
1279
+ l , err := net .Listen ("tcp" , "localhost:8080" )
1280
+ if err != nil {
1281
+ t .Fatalf ("error creating listener: %v" , err )
1282
+ }
1283
+
1284
+ var connectionCounter atomic.Int32
1285
+
1286
+ newLn := & Listener {
1287
+ Listener : l ,
1288
+ ConnPolicy : func (_ ConnPolicyOptions ) (Policy , error ) {
1289
+ // Return the invalid upstream error on the first call, the listener
1290
+ // should remain open and accepting.
1291
+ times := connectionCounter .Load ()
1292
+ if times == 0 {
1293
+ connectionCounter .Store (times + 1 )
1294
+ return REJECT , ErrInvalidUpstream
1295
+ }
1296
+
1297
+ return REJECT , ErrNoProxyProtocol
1298
+ },
1299
+ }
1300
+
1301
+ // Kick off the listener and return any error via the chanel.
1302
+ errCh := make (chan error )
1303
+ defer close (errCh )
1304
+ go func (t * testing.T ) {
1305
+ _ , err := newLn .Accept ()
1306
+ errCh <- err
1307
+ }(t )
1308
+
1309
+ // Make two calls to trigger the listener's accept, the first should experience
1310
+ // the ErrInvalidUpstream and keep the listener open, the second should experience
1311
+ // a different error which will cause the listener to close.
1312
+ _ , _ = http .Get ("http://localhost:8080" )
1313
+ // Wait a few seconds to ensure we didn't get anything back on our channel.
1314
+ select {
1315
+ case err := <- errCh :
1316
+ if err != nil {
1317
+ t .Fatalf ("invalid upstream shouldn't return an error: %v" , err )
1318
+ }
1319
+ case <- time .After (2 * time .Second ):
1320
+ // No error returned (as expected, we're still listening though)
1321
+ }
1322
+
1323
+ _ , _ = http .Get ("http://localhost:8080" )
1324
+ // Wait a few seconds before we fail the test as we should have received an
1325
+ // error that was not invalid upstream.
1326
+ select {
1327
+ case err := <- errCh :
1328
+ if err == nil {
1329
+ t .Fatalf ("errors other than invalid upstream should error" )
1330
+ }
1331
+ if ! errors .Is (ErrNoProxyProtocol , err ) {
1332
+ t .Fatalf ("unexpected error type: %v" , err )
1333
+ }
1334
+ case <- time .After (2 * time .Second ):
1335
+ t .Fatalf ("timed out waiting for listener" )
1336
+ }
1337
+ }
1338
+
1277
1339
type TestTLSServer struct {
1278
1340
Listener net.Listener
1279
1341
@@ -1482,9 +1544,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
1482
1544
b , err := io .ReadAll (r )
1483
1545
return int64 (len (b )), err
1484
1546
}
1547
+
1485
1548
func (c * testConn ) Write (p []byte ) (int , error ) {
1486
1549
return len (p ), nil
1487
1550
}
1551
+
1488
1552
func (c * testConn ) Read (p []byte ) (int , error ) {
1489
1553
if c .reads == 0 {
1490
1554
return 0 , io .EOF
@@ -1533,7 +1597,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
1533
1597
}
1534
1598
1535
1599
func benchmarkTCPProxy (size int , b * testing.B ) {
1536
- //create and start the echo backend
1600
+ // create and start the echo backend
1537
1601
backend , err := net .Listen ("tcp" , "127.0.0.1:0" )
1538
1602
if err != nil {
1539
1603
b .Fatalf ("err: %v" , err )
@@ -1554,7 +1618,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1554
1618
}
1555
1619
}()
1556
1620
1557
- //start the proxyprotocol enabled tcp proxy
1621
+ // start the proxyprotocol enabled tcp proxy
1558
1622
l , err := net .Listen ("tcp" , "127.0.0.1:0" )
1559
1623
if err != nil {
1560
1624
b .Fatalf ("err: %v" , err )
@@ -1603,7 +1667,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1603
1667
},
1604
1668
}
1605
1669
1606
- //now for the actual benchmark
1670
+ // now for the actual benchmark
1607
1671
b .ResetTimer ()
1608
1672
for n := 0 ; n < b .N ; n ++ {
1609
1673
conn , err := net .Dial ("tcp" , pl .Addr ().String ())
@@ -1614,16 +1678,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1614
1678
if _ , err := header .WriteTo (conn ); err != nil {
1615
1679
b .Fatalf ("err: %v" , err )
1616
1680
}
1617
- //send data
1681
+ // send data
1618
1682
go func () {
1619
1683
_ , err = conn .Write (data )
1620
1684
_ = conn .(* net.TCPConn ).CloseWrite ()
1621
1685
if err != nil {
1622
1686
panic (fmt .Sprintf ("Failed to write data: %v" , err ))
1623
1687
}
1624
-
1625
1688
}()
1626
- //receive data
1689
+ // receive data
1627
1690
n , err := io .Copy (io .Discard , conn )
1628
1691
if n != int64 (len (data )) {
1629
1692
b .Fatalf ("Expected to receive %d bytes, got %d" , len (data ), n )
@@ -1638,24 +1701,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
1638
1701
func BenchmarkTCPProxy16KB (b * testing.B ) {
1639
1702
benchmarkTCPProxy (16 * 1024 , b )
1640
1703
}
1704
+
1641
1705
func BenchmarkTCPProxy32KB (b * testing.B ) {
1642
1706
benchmarkTCPProxy (32 * 1024 , b )
1643
1707
}
1708
+
1644
1709
func BenchmarkTCPProxy64KB (b * testing.B ) {
1645
1710
benchmarkTCPProxy (64 * 1024 , b )
1646
1711
}
1712
+
1647
1713
func BenchmarkTCPProxy128KB (b * testing.B ) {
1648
1714
benchmarkTCPProxy (128 * 1024 , b )
1649
1715
}
1716
+
1650
1717
func BenchmarkTCPProxy256KB (b * testing.B ) {
1651
1718
benchmarkTCPProxy (256 * 1024 , b )
1652
1719
}
1720
+
1653
1721
func BenchmarkTCPProxy512KB (b * testing.B ) {
1654
1722
benchmarkTCPProxy (512 * 1024 , b )
1655
1723
}
1724
+
1656
1725
func BenchmarkTCPProxy1024KB (b * testing.B ) {
1657
1726
benchmarkTCPProxy (1024 * 1024 , b )
1658
1727
}
1728
+
1659
1729
func BenchmarkTCPProxy2048KB (b * testing.B ) {
1660
1730
benchmarkTCPProxy (2048 * 1024 , b )
1661
1731
}
0 commit comments