Skip to content

Commit b411f54

Browse files
authored
Fix TestShutdownDeadlock timeout (#119)
1 parent d6f325f commit b411f54

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

server.go

+2 -1
Original file line number | Diff line number | Diff line change
@@ -322,6 +322,7 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
322322
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
323323
s.mutex.Lock()
324324
defer s.mutex.Unlock()
325+
325326
stat, err := s.serverClient.DeregisterExtension(s.uuid)
326327
err = errors.Wrap(err, "deregistering extension")
327328
if err == nil && stat.Code != 0 {
@@ -333,7 +334,7 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
333334
s.server = nil
334335
// Stop the server asynchronously so that the current request
335336
// can complete. Otherwise, this is vulnerable to deadlock if a
336-
// shutdown request is being processed when shutdown is
337+
// shutdown request is being processed when Shutdown is
337338
// explicitly called.
338339
go func() {
339340
server.Stop()

server_test.go

+30 -14
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ import (
77
"io/ioutil"
88
"net"
99
"os"
10+
"runtime/pprof"
1011
"strings"
1112
"sync"
1213
"syscall"
@@ -98,18 +99,19 @@ const parallelTestShutdownDeadlock = 20
9899

99100
func TestShutdownDeadlock(t *testing.T) {
100101
for i := 0; i < parallelTestShutdownDeadlock; i++ {
102+
i := i
101103
t.Run("", func(t *testing.T) {
102104
t.Parallel()
103-
testShutdownDeadlock(t)
105+
testShutdownDeadlock(t, i)
104106
})
105107
}
106108
}
107-
func testShutdownDeadlock(t *testing.T) {
109+
func testShutdownDeadlock(t *testing.T, uuid int) {
108110
tempPath, err := ioutil.TempFile("", "")
109111
require.Nil(t, err)
110112
defer os.Remove(tempPath.Name())
111113

112-
retUUID := osquery.ExtensionRouteUUID(0)
114+
retUUID := osquery.ExtensionRouteUUID(uuid)
113115
mock := &MockExtensionManager{
114116
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
115117
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
@@ -119,16 +121,22 @@ func testShutdownDeadlock(t *testing.T) {
119121
},
120122
CloseFunc: func() {},
121123
}
122-
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
124+
server := ExtensionManagerServer{
125+
serverClient: mock,
126+
sockPath: tempPath.Name(),
127+
timeout: defaultTimeout,
128+
}
123129

124-
wait := sync.WaitGroup{}
130+
var wait sync.WaitGroup
125131

126-
wait.Add(1)
127132
go func() {
133+
// We do not wait for this routine to finish because thrift.TServer.Serve
134+
// seems to sometimes hang after shutdowns. (This test is just testing
135+
// the Shutdown doesn't hang.)
128136
err := server.Start()
129-
require.Nil(t, err)
130-
wait.Done()
137+
require.NoError(t, err)
131138
}()
139+
132140
// Wait for server to be set up
133141
server.waitStarted()
134142

@@ -138,10 +146,17 @@ func testShutdownDeadlock(t *testing.T) {
138146
addr, err := net.ResolveUnixAddr("unix", listenPath)
139147
require.Nil(t, err)
140148
timeout := 500 * time.Millisecond
141-
trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
142-
err = trans.Open()
143-
require.Nil(t, err)
144-
client := osquery.NewExtensionManagerClientFactory(trans,
149+
opened := false
150+
attempt := 0
151+
var transport *thrift.TSocket
152+
for !opened && attempt < 10 {
153+
transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
154+
err = transport.Open()
155+
opened = err == nil
156+
attempt++
157+
}
158+
require.NoError(t, err)
159+
client := osquery.NewExtensionManagerClientFactory(transport,
145160
thrift.NewTBinaryProtocolFactoryDefault())
146161

147162
// Simultaneously call shutdown through a request from the client and
@@ -156,7 +171,7 @@ func testShutdownDeadlock(t *testing.T) {
156171
go func() {
157172
defer wait.Done()
158173
err = server.Shutdown(context.Background())
159-
require.Nil(t, err)
174+
require.NoError(t, err)
160175
}()
161176

162177
// Track whether shutdown completed
@@ -171,7 +186,8 @@ func testShutdownDeadlock(t *testing.T) {
171186
select {
172187
case <-completed:
173188
// Success. Do nothing.
174-
case <-time.After(5 * time.Second):
189+
case <-time.After(10 * time.Second):
190+
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
175191
t.Fatal("hung on shutdown")
176192
}
177193
}

0 commit comments

Comments
 (0)