Skip to content

Commit 6bcabfb

Browse files
Add context and lock functionality to client interface (#108)
1 parent d2e851b commit 6bcabfb

File tree

10 files changed

+520
-38
lines changed

10 files changed

+520
-38
lines changed

client.go

+161-30
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,76 @@ import (
66

77
"github.com/osquery/osquery-go/gen/osquery"
88
"github.com/osquery/osquery-go/transport"
9-
"github.com/pkg/errors"
109

1110
"github.com/apache/thrift/lib/go/thrift"
11+
"github.com/pkg/errors"
12+
)
13+
14+
const (
15+
defaultWaitTime = 200 * time.Millisecond
16+
defaultMaxWaitTime = 1 * time.Minute
1217
)
1318

1419
// ExtensionManagerClient is a wrapper for the osquery Thrift extensions API.
1520
type ExtensionManagerClient struct {
16-
Client osquery.ExtensionManager
21+
client osquery.ExtensionManager
1722
transport thrift.TTransport
23+
24+
waitTime time.Duration
25+
maxWaitTime time.Duration
26+
lock *locker
27+
}
28+
29+
type ClientOption func(*ExtensionManagerClient)
30+
31+
// WaitTime sets the default amount of wait time for the osquery socket to free up. You can override this on a per
32+
// call basis by setting a context deadline
33+
func DefaultWaitTime(d time.Duration) ClientOption {
34+
return func(c *ExtensionManagerClient) {
35+
c.waitTime = d
36+
}
37+
}
38+
39+
// MaxWaitTime is the maximum amount of time something is allowed to wait for the osquery socket. This takes precedence
40+
// over the context deadline.
41+
func MaxWaitTime(d time.Duration) ClientOption {
42+
return func(c *ExtensionManagerClient) {
43+
c.maxWaitTime = d
44+
}
1845
}
1946

2047
// NewClient creates a new client communicating to osquery over the socket at
2148
// the provided path. If resolving the address or connecting to the socket
2249
// fails, this function will error.
23-
func NewClient(path string, timeout time.Duration) (*ExtensionManagerClient, error) {
24-
trans, err := transport.Open(path, timeout)
25-
if err != nil {
26-
return nil, err
50+
func NewClient(path string, socketOpenTimeout time.Duration, opts ...ClientOption) (*ExtensionManagerClient, error) {
51+
c := &ExtensionManagerClient{
52+
waitTime: defaultWaitTime,
53+
maxWaitTime: defaultMaxWaitTime,
2754
}
2855

29-
client := osquery.NewExtensionManagerClientFactory(
30-
trans,
31-
thrift.NewTBinaryProtocolFactoryDefault(),
32-
)
56+
for _, opt := range opts {
57+
opt(c)
58+
}
3359

34-
return &ExtensionManagerClient{client, trans}, nil
60+
if c.waitTime > c.maxWaitTime {
61+
return nil, errors.New("default wait time larger than max wait time")
62+
}
63+
64+
c.lock = NewLocker(c.waitTime, c.maxWaitTime)
65+
66+
if c.client == nil {
67+
trans, err := transport.Open(path, socketOpenTimeout)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
c.client = osquery.NewExtensionManagerClientFactory(
73+
trans,
74+
thrift.NewTBinaryProtocolFactoryDefault(),
75+
)
76+
}
77+
78+
return c, nil
3579
}
3680

3781
// Close should be called to close the transport when use of the client is
@@ -42,48 +86,120 @@ func (c *ExtensionManagerClient) Close() {
4286
}
4387
}
4488

45-
// Ping requests metadata from the extension manager.
89+
// Ping requests metadata from the extension manager, using a new background context
4690
func (c *ExtensionManagerClient) Ping() (*osquery.ExtensionStatus, error) {
47-
return c.Client.Ping(context.Background())
91+
return c.PingContext(context.Background())
4892
}
4993

50-
// Call requests a call to an extension (or core) registry plugin.
94+
// PingContext requests metadata from the extension manager.
95+
func (c *ExtensionManagerClient) PingContext(ctx context.Context) (*osquery.ExtensionStatus, error) {
96+
if err := c.lock.Lock(ctx); err != nil {
97+
return nil, err
98+
}
99+
defer c.lock.Unlock()
100+
return c.client.Ping(ctx)
101+
}
102+
103+
// Call requests a call to an extension (or core) registry plugin, using a new background context
51104
func (c *ExtensionManagerClient) Call(registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
52-
return c.Client.Call(context.Background(), registry, item, request)
105+
return c.CallContext(context.Background(), registry, item, request)
53106
}
54107

55-
// Extensions requests the list of active registered extensions.
108+
// CallContext requests a call to an extension (or core) registry plugin.
109+
func (c *ExtensionManagerClient) CallContext(ctx context.Context, registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
110+
if err := c.lock.Lock(ctx); err != nil {
111+
return nil, err
112+
}
113+
defer c.lock.Unlock()
114+
return c.client.Call(ctx, registry, item, request)
115+
}
116+
117+
// Extensions requests the list of active registered extensions, using a new background context
56118
func (c *ExtensionManagerClient) Extensions() (osquery.InternalExtensionList, error) {
57-
return c.Client.Extensions(context.Background())
119+
return c.ExtensionsContext(context.Background())
120+
}
121+
122+
// ExtensionsContext requests the list of active registered extensions.
123+
func (c *ExtensionManagerClient) ExtensionsContext(ctx context.Context) (osquery.InternalExtensionList, error) {
124+
if err := c.lock.Lock(ctx); err != nil {
125+
return nil, err
126+
}
127+
defer c.lock.Unlock()
128+
return c.client.Extensions(ctx)
58129
}
59130

60-
// RegisterExtension registers the extension plugins with the osquery process.
131+
// RegisterExtension registers the extension plugins with the osquery process, using a new background context
61132
func (c *ExtensionManagerClient) RegisterExtension(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
62-
return c.Client.RegisterExtension(context.Background(), info, registry)
133+
return c.RegisterExtensionContext(context.Background(), info, registry)
134+
}
135+
136+
// RegisterExtensionContext registers the extension plugins with the osquery process.
137+
func (c *ExtensionManagerClient) RegisterExtensionContext(ctx context.Context, info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
138+
if err := c.lock.Lock(ctx); err != nil {
139+
return nil, err
140+
}
141+
defer c.lock.Unlock()
142+
return c.client.RegisterExtension(ctx, info, registry)
63143
}
64144

65-
// DeregisterExtension de-registers the extension plugins with the osquery process.
145+
// DeregisterExtension de-registers the extension plugins with the osquery process, using a new background context
66146
func (c *ExtensionManagerClient) DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
67-
return c.Client.DeregisterExtension(context.Background(), uuid)
147+
return c.DeregisterExtensionContext(context.Background(), uuid)
148+
}
149+
150+
// DeregisterExtensionContext de-registers the extension plugins with the osquery process.
151+
func (c *ExtensionManagerClient) DeregisterExtensionContext(ctx context.Context, uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
152+
if err := c.lock.Lock(ctx); err != nil {
153+
return nil, err
154+
}
155+
defer c.lock.Unlock()
156+
return c.client.DeregisterExtension(ctx, uuid)
68157
}
69158

70-
// Options requests the list of bootstrap or configuration options.
159+
// Options requests the list of bootstrap or configuration options, using a new background context.
71160
func (c *ExtensionManagerClient) Options() (osquery.InternalOptionList, error) {
72-
return c.Client.Options(context.Background())
161+
return c.OptionsContext(context.Background())
162+
}
163+
164+
// OptionsContext requests the list of bootstrap or configuration options.
165+
func (c *ExtensionManagerClient) OptionsContext(ctx context.Context) (osquery.InternalOptionList, error) {
166+
if err := c.lock.Lock(ctx); err != nil {
167+
return nil, err
168+
}
169+
defer c.lock.Unlock()
170+
return c.client.Options(ctx)
73171
}
74172

75-
// Query requests a query to be run and returns the extension response.
173+
// Query requests a query to be run and returns the extension
174+
// response, using a new background context. Consider using the
175+
// QueryRow or QueryRows helpers for a more friendly interface.
176+
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
177+
return c.QueryContext(context.Background(), sql)
178+
}
179+
180+
// QueryContext requests a query to be run and returns the extension response.
76181
// Consider using the QueryRow or QueryRows helpers for a more friendly
77182
// interface.
78-
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
79-
return c.Client.Query(context.Background(), sql)
183+
func (c *ExtensionManagerClient) QueryContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
184+
if err := c.lock.Lock(ctx); err != nil {
185+
return nil, err
186+
}
187+
defer c.lock.Unlock()
188+
return c.client.Query(ctx, sql)
80189
}
81190

82191
// QueryRows is a helper that executes the requested query and returns the
83192
// results. It handles checking both the transport level errors and the osquery
84193
// internal errors by returning a normal Go error type.
85194
func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, error) {
86-
res, err := c.Query(sql)
195+
return c.QueryRowsContext(context.Background(), sql)
196+
}
197+
198+
// QueryRowsContext is a helper that executes the requested query and returns the
199+
// results. It handles checking both the transport level errors and the osquery
200+
// internal errors by returning a normal Go error type.
201+
func (c *ExtensionManagerClient) QueryRowsContext(ctx context.Context, sql string) ([]map[string]string, error) {
202+
res, err := c.QueryContext(ctx, sql)
87203
if err != nil {
88204
return nil, errors.Wrap(err, "transport error in query")
89205
}
@@ -100,7 +216,13 @@ func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, err
100216
// QueryRow behaves similarly to QueryRows, but it returns an error if the
101217
// query does not return exactly one row.
102218
func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error) {
103-
res, err := c.QueryRows(sql)
219+
return c.QueryRowContext(context.Background(), sql)
220+
}
221+
222+
// QueryRowContext behaves similarly to QueryRows, but it returns an error if the
223+
// query does not return exactly one row.
224+
func (c *ExtensionManagerClient) QueryRowContext(ctx context.Context, sql string) (map[string]string, error) {
225+
res, err := c.QueryRowsContext(ctx, sql)
104226
if err != nil {
105227
return nil, err
106228
}
@@ -110,7 +232,16 @@ func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error)
110232
return res[0], nil
111233
}
112234

113-
// GetQueryColumns requests the columns returned by the parsed query.
235+
// GetQueryColumns requests the columns returned by the parsed query, using a new background context.
114236
func (c *ExtensionManagerClient) GetQueryColumns(sql string) (*osquery.ExtensionResponse, error) {
115-
return c.Client.GetQueryColumns(context.Background(), sql)
237+
return c.GetQueryColumnsContext(context.Background(), sql)
238+
}
239+
240+
// GetQueryColumnsContext requests the columns returned by the parsed query.
241+
func (c *ExtensionManagerClient) GetQueryColumnsContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
242+
if err := c.lock.Lock(ctx); err != nil {
243+
return nil, err
244+
}
245+
defer c.lock.Unlock()
246+
return c.client.GetQueryColumns(ctx, sql)
116247
}

client_test.go

+94-1
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@ package osquery
33
import (
44
"context"
55
"errors"
6+
"fmt"
7+
"os"
8+
"sync"
69
"testing"
10+
"time"
711

812
"github.com/osquery/osquery-go/gen/osquery"
913
"github.com/osquery/osquery-go/mock"
1014
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
1116
)
1217

1318
func TestQueryRows(t *testing.T) {
19+
t.Parallel()
1420
mock := &mock.ExtensionManager{}
15-
client := &ExtensionManagerClient{Client: mock}
21+
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock))
22+
require.NoError(t, err)
1623

1724
// Transport related error
1825
mock.QueryFunc = func(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
@@ -77,3 +84,89 @@ func TestQueryRows(t *testing.T) {
7784
row, err = client.QueryRow("select 1 union select 2")
7885
assert.NotNil(t, err)
7986
}
87+
88+
// TestLocking tests the the client correctly locks access to the osquery socket. Thrift only supports a single
89+
// actor on the socket at a time, this means that in parallel go code, it's very easy to have messages get
90+
// crossed and generate errors. This tests to ensure the locking works
91+
func TestLocking(t *testing.T) {
92+
t.Parallel()
93+
94+
sock := os.Getenv("OSQ_SOCKET")
95+
if sock == "" {
96+
t.Skip("no osquery socket specified")
97+
}
98+
99+
osq, err := NewClient(sock, 5*time.Second)
100+
require.NoError(t, err)
101+
102+
// The issue we're testing is about multithreaded access. Let's hammer on it!
103+
wait := sync.WaitGroup{}
104+
for i := 0; i < 100; i++ {
105+
wait.Add(1)
106+
go func() {
107+
defer wait.Done()
108+
109+
status, err := osq.Ping()
110+
require.NoError(t, err, "call to Ping()")
111+
if err != nil {
112+
require.Equal(t, 0, status.Code, fmt.Errorf("ping returned %d: %s", status.Code, status.Message))
113+
}
114+
}()
115+
}
116+
117+
wait.Wait()
118+
}
119+
120+
func TestLockTimeouts(t *testing.T) {
121+
t.Parallel()
122+
mock := &mock.ExtensionManager{}
123+
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock), DefaultWaitTime(100*time.Millisecond), DefaultWaitTime(5*time.Second))
124+
require.NoError(t, err)
125+
126+
wait := sync.WaitGroup{}
127+
128+
errChan := make(chan error, 10)
129+
for i := 0; i < 3; i++ {
130+
wait.Add(1)
131+
go func() {
132+
defer wait.Done()
133+
134+
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
135+
defer cancel()
136+
137+
errChan <- client.SlowLocker(ctx, 75*time.Millisecond)
138+
}()
139+
}
140+
141+
wait.Wait()
142+
close(errChan)
143+
144+
var successCount, errCount int
145+
for err := range errChan {
146+
if err == nil {
147+
successCount += 1
148+
} else {
149+
errCount += 1
150+
}
151+
}
152+
153+
assert.Equal(t, 2, successCount, "expected success count")
154+
assert.Equal(t, 1, errCount, "expected error count")
155+
}
156+
157+
// WithOsqueryThriftClient sets the underlying thrift client. This can be used to set a mock
158+
func WithOsqueryThriftClient(client osquery.ExtensionManager) ClientOption {
159+
return func(c *ExtensionManagerClient) {
160+
c.client = client
161+
}
162+
}
163+
164+
// SlowLocker attempts to emulate a slow sql routine, so we can test how lock timeouts work.
165+
func (c *ExtensionManagerClient) SlowLocker(ctx context.Context, d time.Duration) error {
166+
if err := c.lock.Lock(ctx); err != nil {
167+
return err
168+
}
169+
defer c.lock.Unlock()
170+
time.Sleep(d)
171+
return nil
172+
}

go.mod

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ module github.com/osquery/osquery-go
33
require (
44
github.com/Microsoft/go-winio v0.4.9
55
github.com/apache/thrift v0.16.0
6-
github.com/davecgh/go-spew v1.1.1 // indirect
76
github.com/pkg/errors v0.8.0
7+
github.com/stretchr/testify v1.8.3
8+
)
9+
10+
require (
11+
github.com/davecgh/go-spew v1.1.1 // indirect
812
github.com/pmezard/go-difflib v1.0.0 // indirect
9-
github.com/stretchr/testify v1.2.2
1013
golang.org/x/sys v0.1.0 // indirect
14+
gopkg.in/yaml.v3 v3.0.1 // indirect
1115
)
1216

13-
go 1.16
17+
go 1.19

0 commit comments

Comments
 (0)