Skip to content

Commit f951fe0

Browse files
committed
martian: add 1m graceful close timeout for bicopy
This will prevent the proxy from leaking connections with defective servers. This cannot be tested due to #1013. Enable the test again once that is fixed. For now I tested it manually. Contributes to #800.
1 parent 77bc57e commit f951fe0

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

internal/martian/closewriter.go renamed to internal/martian/close.go

+8
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,11 @@ func asCloseWriter(w io.Writer) (closeWriter, bool) {
4343

4444
return reflectx.LookupImpl[closeWriter](reflect.ValueOf(w))
4545
}
46+
47+
func asCloser(w any) (io.Closer, bool) {
48+
if c, ok := w.(io.Closer); ok {
49+
return c, ok
50+
}
51+
52+
return reflectx.LookupImpl[io.Closer](reflect.ValueOf(w))
53+
}

internal/martian/copy.go

+41-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"context"
2222
"io"
2323
"sync"
24+
"time"
2425

2526
"github.com/saucelabs/forwarder/internal/martian/log"
2627
)
@@ -43,13 +44,35 @@ var copyBufPool = sync.Pool{
4344
},
4445
}
4546

47+
var bicopyGracefulTimeout = 1 * time.Minute
48+
4649
func bicopy(ctx context.Context, cc ...copier) {
50+
ctx, cancel := context.WithCancel(ctx)
51+
defer cancel()
52+
4753
donec := make(chan struct{}, len(cc))
4854
for i := range cc {
4955
go cc[i].copy(ctx, donec)
5056
}
51-
for range cc {
57+
58+
for i := range cc {
5259
<-donec
60+
if i == 0 {
61+
// Forcibly close all tunnels 1 minute after the first tunnel finished.
62+
go gracefulCloseAfter(ctx, bicopyGracefulTimeout, cc...)
63+
}
64+
}
65+
}
66+
67+
func gracefulCloseAfter(ctx context.Context, d time.Duration, cc ...copier) {
68+
select {
69+
case <-ctx.Done():
70+
return
71+
case <-time.After(d):
72+
log.Infof(ctx, "forcibly closing tunnel after %v of graceful period", d)
73+
}
74+
for i := range cc {
75+
cc[i].close(ctx)
5376
}
5477
}
5578

@@ -67,6 +90,13 @@ func (c copier) copy(ctx context.Context, donec chan<- struct{}) {
6790
if _, err := io.CopyBuffer(c.dst, c.src, buf); err != nil && !isClosedConnError(err) {
6891
log.Errorf(ctx, "failed to copy %s tunnel: %v", c.name, err)
6992
}
93+
c.closeWriter(ctx)
94+
95+
log.Debugf(ctx, "%s tunnel finished copying", c.name)
96+
donec <- struct{}{}
97+
}
98+
99+
func (c copier) closeWriter(ctx context.Context) {
70100
var closeErr error
71101
if cw, ok := asCloseWriter(c.dst); ok {
72102
closeErr = cw.CloseWrite()
@@ -78,7 +108,15 @@ func (c copier) copy(ctx context.Context, donec chan<- struct{}) {
78108
if closeErr != nil {
79109
log.Infof(ctx, "failed to close write side of %s tunnel: %v", c.name, closeErr)
80110
}
111+
}
81112

82-
log.Debugf(ctx, "%s tunnel finished copying", c.name)
83-
donec <- struct{}{}
113+
func (c copier) close(ctx context.Context) {
114+
cc, ok := asCloser(c.dst)
115+
if !ok {
116+
log.Errorf(ctx, "cannot close %s tunnel (%T)", c.name, c.dst)
117+
return
118+
}
119+
if err := cc.Close(); err != nil && !isClosedConnError(err) {
120+
log.Infof(ctx, "failed to close %s tunnel: %v", c.name, err)
121+
}
84122
}

internal/martian/proxy_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -2045,3 +2045,64 @@ func TestReadHeaderConnectionReset(t *testing.T) {
20452045
t.Fatalf("conn.Read(): got %v, want io.EOF", err)
20462046
}
20472047
}
2048+
2049+
func TestTunnelGracefulClose(t *testing.T) {
2050+
t.Parallel()
2051+
2052+
t.Skip("panic: close of closed channel, See #1013")
2053+
2054+
l, err := net.Listen("tcp", "localhost:0")
2055+
if err != nil {
2056+
t.Fatalf("net.Listen(): got %v, want no error", err)
2057+
}
2058+
2059+
done := make(chan struct{})
2060+
2061+
// Malicious server that hangs indefinitely.
2062+
go func() {
2063+
t.Logf("Waiting for server side connection")
2064+
conn, err := l.Accept()
2065+
if err != nil {
2066+
t.Errorf("Got error while accepting connection on destination listener: %v", err)
2067+
return
2068+
}
2069+
defer conn.Close()
2070+
t.Logf("Accepted server side connection")
2071+
2072+
<-done
2073+
}()
2074+
2075+
bicopyGracefulTimeout = 100 * time.Millisecond
2076+
h := testHelper{
2077+
Proxy: func(p *Proxy) {
2078+
tr := new(ProxyTrace)
2079+
tr.WroteResponse = func(info WroteResponseInfo) {
2080+
close(done)
2081+
}
2082+
p.Trace = tr
2083+
},
2084+
}
2085+
conn, cancel := h.proxyConn(t)
2086+
defer cancel()
2087+
defer conn.Close()
2088+
2089+
req, err := http.NewRequest(http.MethodConnect, "//"+l.Addr().String(), http.NoBody)
2090+
if err != nil {
2091+
t.Fatalf("http.NewRequest(): got %v, want no error", err)
2092+
}
2093+
2094+
if err := req.Write(conn); err != nil {
2095+
t.Fatalf("req.Write(): got %v, want no error", err)
2096+
}
2097+
res, err := http.ReadResponse(bufio.NewReader(conn), req)
2098+
if err != nil {
2099+
t.Fatalf("http.ReadResponse(): got %v, want no error", err)
2100+
}
2101+
res.Body.Close()
2102+
conn.Close()
2103+
select {
2104+
case <-done:
2105+
case <-time.After(5 * time.Second):
2106+
t.Fatalf("timed out waiting for tunnel to close all connections")
2107+
}
2108+
}

0 commit comments

Comments
 (0)