@@ -24,7 +24,7 @@ import (
24
24
// Verify that an error in server.Start will return an error instead of deadlock.
25
25
func TestNoDeadlockOnError (t * testing.T ) {
26
26
registry := make (map [string ](map [string ]OsqueryPlugin ))
27
- for reg , _ := range validRegistryNames {
27
+ for reg := range validRegistryNames {
28
28
registry [reg ] = make (map [string ]OsqueryPlugin )
29
29
}
30
30
mut := sync.Mutex {}
@@ -43,8 +43,9 @@ func TestNoDeadlockOnError(t *testing.T) {
43
43
CloseFunc : func () {},
44
44
}
45
45
server := & ExtensionManagerServer {
46
- serverClient : mock ,
47
- registry : registry ,
46
+ serverClient : mock ,
47
+ registry : registry ,
48
+ serverClientShouldShutdown : true ,
48
49
}
49
50
50
51
log := func (ctx context.Context , typ logger.LogType , logText string ) error {
@@ -63,8 +64,12 @@ func TestNoDeadlockOnError(t *testing.T) {
63
64
// Ensure that the extension server will shutdown and return if the osquery
64
65
// instance it is talking to stops responding to pings.
65
66
func TestShutdownWhenPingFails (t * testing.T ) {
67
+ tempPath , err := ioutil .TempFile ("" , "" )
68
+ require .Nil (t , err )
69
+ defer os .Remove (tempPath .Name ())
70
+
66
71
registry := make (map [string ](map [string ]OsqueryPlugin ))
67
- for reg , _ := range validRegistryNames {
72
+ for reg := range validRegistryNames {
68
73
registry [reg ] = make (map [string ]OsqueryPlugin )
69
74
}
70
75
mock := & MockExtensionManager {
@@ -81,11 +86,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
81
86
CloseFunc : func () {},
82
87
}
83
88
server := & ExtensionManagerServer {
84
- serverClient : mock ,
85
- registry : registry ,
89
+ serverClient : mock ,
90
+ registry : registry ,
91
+ serverClientShouldShutdown : true ,
92
+ pingInterval : 1 * time .Second ,
93
+ sockPath : tempPath .Name (),
86
94
}
87
95
88
- err : = server .Run ()
96
+ err = server .Run ()
89
97
assert .Error (t , err )
90
98
assert .Contains (t , err .Error (), "broken pipe" )
91
99
assert .True (t , mock .DeRegisterExtensionFuncInvoked )
@@ -106,6 +114,7 @@ func TestShutdownDeadlock(t *testing.T) {
106
114
})
107
115
}
108
116
}
117
+
109
118
func testShutdownDeadlock (t * testing.T , uuid int ) {
110
119
tempPath , err := ioutil .TempFile ("" , "" )
111
120
require .Nil (t , err )
@@ -122,9 +131,10 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
122
131
CloseFunc : func () {},
123
132
}
124
133
server := ExtensionManagerServer {
125
- serverClient : mock ,
126
- sockPath : tempPath .Name (),
127
- timeout : defaultTimeout ,
134
+ serverClient : mock ,
135
+ sockPath : tempPath .Name (),
136
+ timeout : defaultTimeout ,
137
+ serverClientShouldShutdown : true ,
128
138
}
129
139
130
140
var wait sync.WaitGroup
@@ -152,8 +162,12 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
152
162
for ! opened && attempt < 10 {
153
163
transport = thrift .NewTSocketFromAddrTimeout (addr , timeout , timeout )
154
164
err = transport .Open ()
155
- opened = err == nil
156
165
attempt ++
166
+ if err != nil {
167
+ time .Sleep (1 * time .Second )
168
+ } else {
169
+ opened = true
170
+ }
157
171
}
158
172
require .NoError (t , err )
159
173
client := osquery .NewExtensionManagerClientFactory (transport ,
@@ -193,9 +207,13 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
193
207
}
194
208
195
209
func TestShutdownBasic (t * testing.T ) {
196
- tempPath , err := ioutil .TempFile ("" , "" )
197
- require .Nil (t , err )
198
- defer os .Remove (tempPath .Name ())
210
+ dir := t .TempDir ()
211
+
212
+ tempPath := func () string {
213
+ tmp , err := os .CreateTemp (dir , "" )
214
+ require .NoError (t , err )
215
+ return tmp .Name ()
216
+ }
199
217
200
218
retUUID := osquery .ExtensionRouteUUID (0 )
201
219
mock := & MockExtensionManager {
@@ -207,26 +225,38 @@ func TestShutdownBasic(t *testing.T) {
207
225
},
208
226
CloseFunc : func () {},
209
227
}
210
- server := ExtensionManagerServer {serverClient : mock , sockPath : tempPath .Name ()}
211
228
212
- completed := make (chan struct {})
213
- go func () {
214
- err := server .Start ()
229
+ for _ , server := range []* ExtensionManagerServer {
230
+ // Create the extension manager without using NewExtensionManagerServer.
231
+ {serverClient : mock , sockPath : tempPath ()},
232
+ // Create the extension manager using ExtensionManagerServer.
233
+ {serverClient : mock , sockPath : tempPath (), serverClientShouldShutdown : true },
234
+ } {
235
+ completed := make (chan struct {})
236
+ go func () {
237
+ err := server .Start ()
238
+ require .NoError (t , err )
239
+ close (completed )
240
+ }()
241
+
242
+ server .waitStarted ()
243
+
244
+ err := server .Shutdown (context .Background ())
215
245
require .NoError (t , err )
216
- close (completed )
217
- }()
218
246
219
- server .waitStarted ()
220
- err = server .Shutdown (context .Background ())
221
- require .NoError (t , err )
247
+ // Test that server.Shutdown is idempotent.
248
+ err = server .Shutdown (context .Background ())
249
+ require .NoError (t , err )
250
+
251
+ // Either indicate successful shutdown, or fatal the test because it
252
+ // hung
253
+ select {
254
+ case <- completed :
255
+ // Success. Do nothing.
256
+ case <- time .After (5 * time .Second ):
257
+ t .Fatal ("hung on shutdown" )
258
+ }
222
259
223
- // Either indicate successful shutdown, or fatal the test because it
224
- // hung
225
- select {
226
- case <- completed :
227
- // Success. Do nothing.
228
- case <- time .After (5 * time .Second ):
229
- t .Fatal ("hung on shutdown" )
230
260
}
231
261
}
232
262
0 commit comments