Skip to content

Commit e3cde12

Browse files
authored
Make ExtensionManagerServer.Shutdown idempotent (#117)
* Make Shutdown idempotent * Protect access to s.serverClient * Add sleep to make retry effective
1 parent b411f54 commit e3cde12

File tree

2 files changed

+74
-37
lines changed

2 files changed

+74
-37
lines changed

server.go

+14-7
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,16 @@ func (s *ExtensionManagerServer) Run() error {
256256
for {
257257
time.Sleep(s.pingInterval)
258258

259+
s.mutex.Lock()
260+
serverClient := s.serverClient
261+
s.mutex.Unlock()
262+
259263
// can't ping if s.Shutdown has already happened
260-
if s.serverClient == nil {
264+
if serverClient == nil {
261265
break
262266
}
263267

264-
status, err := s.serverClient.Ping()
268+
status, err := serverClient.Ping()
265269
if err != nil {
266270
errc <- errors.Wrap(err, "extension ping failed")
267271
break
@@ -323,12 +327,15 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
323327
s.mutex.Lock()
324328
defer s.mutex.Unlock()
325329

326-
stat, err := s.serverClient.DeregisterExtension(s.uuid)
327-
err = errors.Wrap(err, "deregistering extension")
328-
if err == nil && stat.Code != 0 {
329-
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
330+
if s.serverClient != nil {
331+
var stat *osquery.ExtensionStatus
332+
stat, err = s.serverClient.DeregisterExtension(s.uuid)
333+
err = errors.Wrap(err, "deregistering extension")
334+
if err == nil && stat.Code != 0 {
335+
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
336+
}
330337
}
331-
s.serverClient.Close()
338+
332339
if s.server != nil {
333340
server := s.server
334341
s.server = nil

server_test.go

+60-30
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424
// Verify that an error in server.Start will return an error instead of deadlock.
2525
func TestNoDeadlockOnError(t *testing.T) {
2626
registry := make(map[string](map[string]OsqueryPlugin))
27-
for reg, _ := range validRegistryNames {
27+
for reg := range validRegistryNames {
2828
registry[reg] = make(map[string]OsqueryPlugin)
2929
}
3030
mut := sync.Mutex{}
@@ -43,8 +43,9 @@ func TestNoDeadlockOnError(t *testing.T) {
4343
CloseFunc: func() {},
4444
}
4545
server := &ExtensionManagerServer{
46-
serverClient: mock,
47-
registry: registry,
46+
serverClient: mock,
47+
registry: registry,
48+
serverClientShouldShutdown: true,
4849
}
4950

5051
log := func(ctx context.Context, typ logger.LogType, logText string) error {
@@ -63,8 +64,12 @@ func TestNoDeadlockOnError(t *testing.T) {
6364
// Ensure that the extension server will shutdown and return if the osquery
6465
// instance it is talking to stops responding to pings.
6566
func TestShutdownWhenPingFails(t *testing.T) {
67+
tempPath, err := ioutil.TempFile("", "")
68+
require.Nil(t, err)
69+
defer os.Remove(tempPath.Name())
70+
6671
registry := make(map[string](map[string]OsqueryPlugin))
67-
for reg, _ := range validRegistryNames {
72+
for reg := range validRegistryNames {
6873
registry[reg] = make(map[string]OsqueryPlugin)
6974
}
7075
mock := &MockExtensionManager{
@@ -81,11 +86,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
8186
CloseFunc: func() {},
8287
}
8388
server := &ExtensionManagerServer{
84-
serverClient: mock,
85-
registry: registry,
89+
serverClient: mock,
90+
registry: registry,
91+
serverClientShouldShutdown: true,
92+
pingInterval: 1 * time.Second,
93+
sockPath: tempPath.Name(),
8694
}
8795

88-
err := server.Run()
96+
err = server.Run()
8997
assert.Error(t, err)
9098
assert.Contains(t, err.Error(), "broken pipe")
9199
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
@@ -106,6 +114,7 @@ func TestShutdownDeadlock(t *testing.T) {
106114
})
107115
}
108116
}
117+
109118
func testShutdownDeadlock(t *testing.T, uuid int) {
110119
tempPath, err := ioutil.TempFile("", "")
111120
require.Nil(t, err)
@@ -122,9 +131,10 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
122131
CloseFunc: func() {},
123132
}
124133
server := ExtensionManagerServer{
125-
serverClient: mock,
126-
sockPath: tempPath.Name(),
127-
timeout: defaultTimeout,
134+
serverClient: mock,
135+
sockPath: tempPath.Name(),
136+
timeout: defaultTimeout,
137+
serverClientShouldShutdown: true,
128138
}
129139

130140
var wait sync.WaitGroup
@@ -152,8 +162,12 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
152162
for !opened && attempt < 10 {
153163
transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
154164
err = transport.Open()
155-
opened = err == nil
156165
attempt++
166+
if err != nil {
167+
time.Sleep(1 * time.Second)
168+
} else {
169+
opened = true
170+
}
157171
}
158172
require.NoError(t, err)
159173
client := osquery.NewExtensionManagerClientFactory(transport,
@@ -193,9 +207,13 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
193207
}
194208

195209
func TestShutdownBasic(t *testing.T) {
196-
tempPath, err := ioutil.TempFile("", "")
197-
require.Nil(t, err)
198-
defer os.Remove(tempPath.Name())
210+
dir := t.TempDir()
211+
212+
tempPath := func() string {
213+
tmp, err := os.CreateTemp(dir, "")
214+
require.NoError(t, err)
215+
return tmp.Name()
216+
}
199217

200218
retUUID := osquery.ExtensionRouteUUID(0)
201219
mock := &MockExtensionManager{
@@ -207,26 +225,38 @@ func TestShutdownBasic(t *testing.T) {
207225
},
208226
CloseFunc: func() {},
209227
}
210-
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
211228

212-
completed := make(chan struct{})
213-
go func() {
214-
err := server.Start()
229+
for _, server := range []*ExtensionManagerServer{
230+
// Create the extension manager without using NewExtensionManagerServer.
231+
{serverClient: mock, sockPath: tempPath()},
232+
// Create the extension manager using ExtensionManagerServer.
233+
{serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true},
234+
} {
235+
completed := make(chan struct{})
236+
go func() {
237+
err := server.Start()
238+
require.NoError(t, err)
239+
close(completed)
240+
}()
241+
242+
server.waitStarted()
243+
244+
err := server.Shutdown(context.Background())
215245
require.NoError(t, err)
216-
close(completed)
217-
}()
218246

219-
server.waitStarted()
220-
err = server.Shutdown(context.Background())
221-
require.NoError(t, err)
247+
// Test that server.Shutdown is idempotent.
248+
err = server.Shutdown(context.Background())
249+
require.NoError(t, err)
250+
251+
// Either indicate successful shutdown, or fatal the test because it
252+
// hung
253+
select {
254+
case <-completed:
255+
// Success. Do nothing.
256+
case <-time.After(5 * time.Second):
257+
t.Fatal("hung on shutdown")
258+
}
222259

223-
// Either indicate successful shutdown, or fatal the test because it
224-
// hung
225-
select {
226-
case <-completed:
227-
// Success. Do nothing.
228-
case <-time.After(5 * time.Second):
229-
t.Fatal("hung on shutdown")
230260
}
231261
}
232262

0 commit comments

Comments
 (0)