Skip to content

Commit e3c037c

Browse files
committed
wip: add back-off mechanism when validating Fleet Desktop token
1 parent b011418 commit e3c037c

File tree

3 files changed

+234
-38
lines changed

3 files changed

+234
-38
lines changed

orbit/cmd/desktop/desktop.go

+19-38
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,15 @@ func main() {
168168
return newToken
169169
})
170170

171-
refetchToken := func() {
171+
refetchTokenValueFromDisk := func() {
172172
if _, err := tokenReader.Read(); err != nil {
173173
log.Error().Err(err).Msg("refetch token")
174174
}
175175
log.Debug().Msg("successfully refetched the token from disk")
176176
}
177177

178+
tokenChecker := token.NewChecker(identifierPath, client)
179+
178180
disableTray := func() {
179181
log.Debug().Msg("disabling tray items")
180182
myDeviceItem.SetTitle("Connecting...")
@@ -184,46 +186,24 @@ func main() {
184186
migrateMDMItem.Hide()
185187
}
186188

187-
// checkToken performs API test calls to enable the "My device" item as
188-
// soon as the device auth token is registered by Fleet.
189-
checkToken := func() <-chan interface{} {
190-
done := make(chan interface{})
191-
192-
go func() {
193-
ticker := time.NewTicker(5 * time.Second)
194-
defer ticker.Stop()
195-
defer close(done)
196-
197-
for {
198-
refetchToken()
199-
_, err := client.DesktopSummary(tokenReader.GetCached())
200-
201-
if err == nil || errors.Is(err, service.ErrMissingLicense) {
202-
log.Debug().Msg("enabling tray items")
203-
myDeviceItem.SetTitle("My device")
204-
myDeviceItem.Enable()
205-
transparencyItem.Enable()
206-
return
207-
}
208-
209-
log.Error().Err(err).Msg("get device URL")
210-
211-
<-ticker.C
212-
}
213-
}()
214-
215-
return done
189+
// checkToken perform API calls until a token is valid. When we
190+
// find a valid token it refreshes the cached alue and enables
191+
// basic tray items.
192+
checkToken := func() {
193+
tokenChecker.AwaitValid()
194+
refetchTokenValueFromDisk()
195+
log.Debug().Msg("enabling tray items")
196+
myDeviceItem.SetTitle("My device")
197+
myDeviceItem.Enable()
198+
transparencyItem.Enable()
216199
}
217200

218-
// start a check as soon as the app starts
219-
deviceEnabledChan := checkToken()
220-
221201
// this loop checks the `mtime` value of the token file and:
222202
// 1. if the token file was modified, it disables the tray items until we
223203
// verify the token is valid
224204
// 2. calls (blocking) `checkToken` to verify the token is valid
225205
go func() {
226-
<-deviceEnabledChan
206+
tokenChecker.AwaitValid()
227207
tic := time.NewTicker(1 * time.Second)
228208
defer tic.Stop()
229209

@@ -234,9 +214,8 @@ func main() {
234214
case err != nil:
235215
log.Error().Err(err).Msg("check token file")
236216
case expired:
237-
log.Info().Msg("token file changed, rechecking")
238217
disableTray()
239-
<-checkToken()
218+
checkToken()
240219
}
241220
}
242221
}()
@@ -279,7 +258,6 @@ func main() {
279258
// poll the server to check the policy status of the host and update the
280259
// tray icon accordingly
281260
go func() {
282-
<-deviceEnabledChan
283261
tic := time.NewTicker(5 * time.Minute)
284262
defer tic.Stop()
285263

@@ -290,11 +268,13 @@ func main() {
290268
case err == nil:
291269
// OK
292270
case errors.Is(err, service.ErrMissingLicense):
271+
// This case is for devices using the free plan.
293272
myDeviceItem.SetTitle("My device")
273+
transparencyItem.Enable()
294274
continue
295275
case errors.Is(err, service.ErrUnauthenticated):
296276
disableTray()
297-
<-checkToken()
277+
checkToken()
298278
continue
299279
default:
300280
log.Error().Err(err).Msg("get failing policies")
@@ -326,6 +306,7 @@ func main() {
326306
}
327307
}
328308
myDeviceItem.Enable()
309+
transparencyItem.Enable()
329310

330311
shouldRunMigrator := sum.Notifications.NeedsMDMMigration || sum.Notifications.RenewEnrollmentProfile
331312

orbit/pkg/token/checker.go

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package token
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"time"
7+
8+
"github.com/cenkalti/backoff"
9+
"github.com/fleetdm/fleet/v4/server/service"
10+
)
11+
12+
type client interface {
13+
CheckToken(string) error
14+
}
15+
16+
type reader interface {
17+
Read() (string, error)
18+
GetCached() string
19+
}
20+
21+
// TODO: docs
22+
type RemoteChecker struct {
23+
reader reader
24+
client client
25+
}
26+
27+
// TODO: docs
28+
func NewChecker(path string, client client) *RemoteChecker {
29+
return &RemoteChecker{
30+
reader: &Reader{Path: path},
31+
client: client,
32+
}
33+
}
34+
35+
func (c *RemoteChecker) isValid(err error) bool {
36+
return err == nil || errors.Is(err, service.ErrMissingLicense)
37+
}
38+
39+
// TODO: docs
40+
func (c *RemoteChecker) AwaitValid() {
41+
// TODO: find appropriate values
42+
// randomized interval = RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor])
43+
retryStrategy := backoff.NewExponentialBackOff()
44+
retryStrategy.InitialInterval = 2 * time.Second
45+
retryStrategy.MaxElapsedTime = 1 * time.Minute
46+
47+
for {
48+
if err := backoff.Retry(
49+
func() error {
50+
if _, err := c.reader.Read(); err != nil {
51+
return err
52+
}
53+
err := c.client.CheckToken(c.reader.GetCached())
54+
if !c.isValid(err) {
55+
// TODO: better error
56+
return errors.New("invalid")
57+
}
58+
return nil
59+
},
60+
retryStrategy,
61+
); err != nil {
62+
// TODO: what do we do here?
63+
fmt.Println("backoff gave up")
64+
continue
65+
}
66+
67+
return
68+
}
69+
}

orbit/pkg/token/checker_test.go

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package token
2+
3+
import (
4+
"errors"
5+
"testing"
6+
"time"
7+
8+
"github.com/fleetdm/fleet/v4/server/service"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
type mockClient struct {
13+
TokenValidationFunc func(string) error
14+
}
15+
16+
func (s *mockClient) CheckToken(token string) error {
17+
return s.TokenValidationFunc(token)
18+
}
19+
20+
type mockReader struct {
21+
ReadFunc func() (string, error)
22+
CachedVal string
23+
}
24+
25+
func (s *mockReader) Read() (string, error) {
26+
return s.ReadFunc()
27+
}
28+
29+
func (s *mockReader) GetCached() string {
30+
return s.CachedVal
31+
}
32+
33+
func TestNewChecker(t *testing.T) {
34+
client := &mockClient{}
35+
checker := NewChecker("path/to/token", client)
36+
37+
require.NotNil(t, checker)
38+
}
39+
40+
func TestIsValid(t *testing.T) {
41+
client := &mockClient{}
42+
checker := NewChecker("path/to/token", client)
43+
44+
testCases := []struct {
45+
err error
46+
expected bool
47+
}{
48+
{nil, true},
49+
{errors.New("random error"), false},
50+
{service.ErrMissingLicense, true},
51+
}
52+
53+
for _, tc := range testCases {
54+
result := checker.isValid(tc.err)
55+
require.Equal(t, tc.expected, result)
56+
}
57+
}
58+
59+
func TestAwaitValid_Success(t *testing.T) {
60+
client := &mockClient{
61+
TokenValidationFunc: func(token string) error {
62+
return nil
63+
},
64+
}
65+
reader := &mockReader{
66+
ReadFunc: func() (string, error) {
67+
return "valid token", nil
68+
},
69+
}
70+
71+
checker := &RemoteChecker{
72+
reader: reader,
73+
client: client,
74+
}
75+
76+
done := make(chan bool)
77+
go func() {
78+
checker.AwaitValid()
79+
done <- true
80+
}()
81+
82+
select {
83+
case <-done:
84+
// test passed
85+
case <-time.After(2 * time.Second):
86+
t.Fatal("Test timed out - AwaitValid did not complete in expected time")
87+
}
88+
}
89+
90+
func TestAwaitValid_Failure(t *testing.T) {
91+
client := &mockClient{
92+
TokenValidationFunc: func(token string) error {
93+
return errors.New("invalid token")
94+
},
95+
}
96+
reader := &mockReader{
97+
ReadFunc: func() (string, error) {
98+
return "invalid token", nil
99+
},
100+
}
101+
102+
checker := &RemoteChecker{
103+
reader: reader,
104+
client: client,
105+
}
106+
107+
done := make(chan bool)
108+
go func() {
109+
checker.AwaitValid()
110+
done <- true
111+
}()
112+
113+
select {
114+
case <-done:
115+
t.Fatal("Test failed - AwaitValid should not have completed")
116+
case <-time.After(2 * time.Second):
117+
// test passed
118+
}
119+
}
120+
121+
func TestAwaitValid_ReaderError(t *testing.T) {
122+
client := &mockClient{}
123+
reader := &mockReader{
124+
ReadFunc: func() (string, error) {
125+
return "", errors.New("reader error")
126+
},
127+
}
128+
129+
checker := &RemoteChecker{
130+
reader: reader,
131+
client: client,
132+
}
133+
134+
done := make(chan bool)
135+
go func() {
136+
checker.AwaitValid()
137+
done <- true
138+
}()
139+
140+
select {
141+
case <-done:
142+
t.Fatal("Test failed - AwaitValid should not have completed")
143+
case <-time.After(2 * time.Second):
144+
// test passed
145+
}
146+
}

0 commit comments

Comments
 (0)