Skip to content

Commit 99b4efa

Browse files
authored
Add DeregisterExtension method (#92)
When the extension manager shuts down, osquery states it should deregisters. When the extension is started again, it can avoid errors when osquery starts the new process before the register watcher has had time to clean up the old one.
1 parent df1df42 commit 99b4efa

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

client.go

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type ExtensionManager interface {
1717
Call(registry, item string, req osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error)
1818
Extensions() (osquery.InternalExtensionList, error)
1919
RegisterExtension(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error)
20+
DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error)
2021
Options() (osquery.InternalOptionList, error)
2122
Query(sql string) (*osquery.ExtensionResponse, error)
2223
GetQueryColumns(sql string) (*osquery.ExtensionResponse, error)
@@ -73,6 +74,11 @@ func (c *ExtensionManagerClient) RegisterExtension(info *osquery.InternalExtensi
7374
return c.Client.RegisterExtension(context.Background(), info, registry)
7475
}
7576

77+
// DeregisterExtension de-registers the extension plugins with the osquery process.
78+
func (c *ExtensionManagerClient) DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
79+
return c.Client.DeregisterExtension(context.Background(), uuid)
80+
}
81+
7682
// Options requests the list of bootstrap or configuration options.
7783
func (c *ExtensionManagerClient) Options() (osquery.InternalOptionList, error) {
7884
return c.Client.Options(context.Background())

mock_manager.go

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ type ExtensionsFunc func() (osquery.InternalExtensionList, error)
1616

1717
type RegisterExtensionFunc func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error)
1818

19+
type DeregisterExtensionFunc func(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error)
20+
1921
type OptionsFunc func() (osquery.InternalOptionList, error)
2022

2123
type QueryFunc func(sql string) (*osquery.ExtensionResponse, error)
@@ -38,6 +40,9 @@ type MockExtensionManager struct {
3840
RegisterExtensionFunc RegisterExtensionFunc
3941
RegisterExtensionFuncInvoked bool
4042

43+
DeRegisterExtensionFunc DeregisterExtensionFunc
44+
DeRegisterExtensionFuncInvoked bool
45+
4146
OptionsFunc OptionsFunc
4247
OptionsFuncInvoked bool
4348

@@ -73,6 +78,11 @@ func (m *MockExtensionManager) RegisterExtension(info *osquery.InternalExtension
7378
return m.RegisterExtensionFunc(info, registry)
7479
}
7580

81+
func (m *MockExtensionManager) DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
82+
m.DeRegisterExtensionFuncInvoked = true
83+
return m.DeRegisterExtensionFunc(uuid)
84+
}
85+
7686
func (m *MockExtensionManager) Options() (osquery.InternalOptionList, error) {
7787
m.OptionsFuncInvoked = true
7888
return m.OptionsFunc()

server.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type ExtensionManagerServer struct {
4949
timeout time.Duration
5050
pingInterval time.Duration // How often to ping osquery server
5151
mutex sync.Mutex
52+
uuid osquery.ExtensionRouteUUID
5253
started bool // Used to ensure tests wait until the server is actually started
5354
}
5455

@@ -153,14 +154,20 @@ func (s *ExtensionManagerServer) Start() error {
153154
if stat.Code != 0 {
154155
return errors.Errorf("status %d registering extension: %s", stat.Code, stat.Message)
155156
}
157+
s.uuid = stat.UUID
156158

157159
listenPath := fmt.Sprintf("%s.%d", s.sockPath, stat.UUID)
158160

159161
processor := osquery.NewExtensionProcessor(s)
160162

161163
s.transport, err = transport.OpenServer(listenPath, s.timeout)
162164
if err != nil {
163-
return errors.Wrapf(err, "opening server socket (%s)", listenPath)
165+
openError := errors.Wrapf(err, "opening server socket (%s)", listenPath)
166+
_, err = s.serverClient.DeregisterExtension(stat.UUID)
167+
if err != nil {
168+
return errors.Wrapf(err, "deregistering extension - follows %s", openError.Error())
169+
}
170+
return openError
164171
}
165172

166173
s.server = thrift.NewTSimpleServer2(processor, s.transport)
@@ -242,10 +249,16 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
242249
return &response, nil
243250
}
244251

245-
// Shutdown stops the server and closes the listening socket.
246-
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) error {
252+
// Shutdown deregisters the extension, stops the server and closes all sockets.
253+
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
247254
s.mutex.Lock()
248255
defer s.mutex.Unlock()
256+
stat, err := s.serverClient.DeregisterExtension(s.uuid)
257+
err = errors.Wrap(err, "deregistering extension")
258+
if err == nil && stat.Code != 0 {
259+
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
260+
}
261+
s.serverClient.Close()
249262
if s.server != nil {
250263
server := s.server
251264
s.server = nil
@@ -258,7 +271,7 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) error {
258271
}()
259272
}
260273

261-
return nil
274+
return
262275
}
263276

264277
// Useful for testing

server_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ func TestNoDeadlockOnError(t *testing.T) {
3636
PingFunc: func() (*osquery.ExtensionStatus, error) {
3737
return &osquery.ExtensionStatus{}, nil
3838
},
39+
DeRegisterExtensionFunc: func(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
40+
return &osquery.ExtensionStatus{}, nil
41+
},
42+
CloseFunc: func() {},
3943
}
4044
server := &ExtensionManagerServer{
4145
serverClient: mock,
@@ -70,6 +74,10 @@ func TestShutdownWhenPingFails(t *testing.T) {
7074
// As if the socket was closed
7175
return nil, syscall.EPIPE
7276
},
77+
DeRegisterExtensionFunc: func(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
78+
return &osquery.ExtensionStatus{}, nil
79+
},
80+
CloseFunc: func() {},
7381
}
7482
server := &ExtensionManagerServer{
7583
serverClient: mock,
@@ -79,6 +87,8 @@ func TestShutdownWhenPingFails(t *testing.T) {
7987
err := server.Run()
8088
assert.Error(t, err)
8189
assert.Contains(t, err.Error(), "broken pipe")
90+
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
91+
assert.True(t, mock.CloseFuncInvoked)
8292
}
8393

8494
// How many parallel tests to run (because sync issues do not occur on every
@@ -104,6 +114,10 @@ func testShutdownDeadlock(t *testing.T) {
104114
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
105115
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
106116
},
117+
DeRegisterExtensionFunc: func(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
118+
return &osquery.ExtensionStatus{}, nil
119+
},
120+
CloseFunc: func() {},
107121
}
108122
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
109123

@@ -172,6 +186,10 @@ func TestShutdownBasic(t *testing.T) {
172186
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
173187
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
174188
},
189+
DeRegisterExtensionFunc: func(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
190+
return &osquery.ExtensionStatus{}, nil
191+
},
192+
CloseFunc: func() {},
175193
}
176194
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
177195

0 commit comments

Comments
 (0)