Skip to content

Commit 2bcde89

Browse files
authored
Introduce Go context-aware Wait functions for blocking operation (cli#39)
1 parent 6f7124e commit 2bcde89

10 files changed

+133
-71
lines changed

device/device_flow.go

+19-23
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
package device
1414

1515
import (
16+
"context"
1617
"errors"
1718
"fmt"
1819
"net/http"
@@ -103,16 +104,16 @@ const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code"
103104

104105
// PollToken polls the server at pollURL until an access token is granted or denied.
105106
//
106-
// Deprecated: use PollTokenWithOptions.
107+
// Deprecated: use Wait.
107108
func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) {
108-
return PollTokenWithOptions(c, pollURL, PollOptions{
109+
return Wait(context.Background(), c, pollURL, WaitOptions{
109110
ClientID: clientID,
110111
DeviceCode: code,
111112
})
112113
}
113114

114-
// PollOptions specifies parameters to poll the server with until authentication completes.
115-
type PollOptions struct {
115+
// WaitOptions specifies parameters to poll the server with until authentication completes.
116+
type WaitOptions struct {
116117
// ClientID is the app client ID value.
117118
ClientID string
118119
// ClientSecret is the app client secret value. Optional: only pass if the server requires it.
@@ -122,30 +123,28 @@ type PollOptions struct {
122123
// GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional.
123124
GrantType string
124125

125-
timeNow func() time.Time
126-
timeSleep func(time.Duration)
126+
newPoller pollerFactory
127127
}
128128

129-
// PollTokenWithOptions polls the server at uri until authorization completes.
130-
func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) {
131-
timeNow := opts.timeNow
132-
if timeNow == nil {
133-
timeNow = time.Now
134-
}
135-
timeSleep := opts.timeSleep
136-
if timeSleep == nil {
137-
timeSleep = time.Sleep
138-
}
139-
129+
// Wait polls the server at uri until authorization completes.
130+
func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) {
140131
checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
141-
expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second)
132+
expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second
142133
grantType := opts.GrantType
143134
if opts.GrantType == "" {
144135
grantType = defaultGrantType
145136
}
146137

138+
makePoller := opts.newPoller
139+
if makePoller == nil {
140+
makePoller = newPoller
141+
}
142+
_, poll := makePoller(ctx, checkInterval, expiresIn)
143+
147144
for {
148-
timeSleep(checkInterval)
145+
if err := poll.Wait(); err != nil {
146+
return nil, err
147+
}
149148

150149
values := url.Values{
151150
"client_id": {opts.ClientID},
@@ -158,6 +157,7 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
158157
values.Add("client_secret", opts.ClientSecret)
159158
}
160159

160+
// TODO: pass tctx down to the HTTP layer
161161
resp, err := api.PostForm(c, uri, values)
162162
if err != nil {
163163
return nil, err
@@ -170,9 +170,5 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
170170
} else if !(errors.As(err, &apiError) && apiError.Code == "authorization_pending") {
171171
return nil, err
172172
}
173-
174-
if timeNow().After(expiresAt) {
175-
return nil, ErrTimeout
176-
}
177173
}
178174
}

device/device_flow_test.go

+33-39
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package device
22

33
import (
44
"bytes"
5+
"context"
6+
"errors"
57
"io/ioutil"
68
"net/http"
79
"net/url"
@@ -230,28 +232,16 @@ func TestRequestCode(t *testing.T) {
230232
}
231233

232234
func TestPollToken(t *testing.T) {
233-
var totalSlept time.Duration
234-
mockSleep := func(d time.Duration) {
235-
totalSlept += d
236-
}
237-
duration := func(d string) time.Duration {
238-
res, _ := time.ParseDuration(d)
239-
return res
240-
}
241-
clock := func(durations ...string) func() time.Time {
242-
count := 0
243-
now := time.Now()
244-
return func() time.Time {
245-
t := now.Add(duration(durations[count]))
246-
count++
247-
return t
235+
makeFakePoller := func(maxWaits int) pollerFactory {
236+
return func(ctx context.Context, interval, expiresIn time.Duration) (context.Context, poller) {
237+
return ctx, &fakePoller{maxWaits: maxWaits}
248238
}
249239
}
250240

251241
type args struct {
252242
http apiClient
253243
url string
254-
opts PollOptions
244+
opts WaitOptions
255245
}
256246
tests := []struct {
257247
name string
@@ -279,7 +269,7 @@ func TestPollToken(t *testing.T) {
279269
},
280270
},
281271
url: "https://github.com/oauth",
282-
opts: PollOptions{
272+
opts: WaitOptions{
283273
ClientID: "CLIENT-ID",
284274
DeviceCode: &CodeResponse{
285275
DeviceCode: "DEVIC",
@@ -288,14 +278,12 @@ func TestPollToken(t *testing.T) {
288278
ExpiresIn: 99,
289279
Interval: 5,
290280
},
291-
timeSleep: mockSleep,
292-
timeNow: clock("0", "5s", "10s"),
281+
newPoller: makeFakePoller(2),
293282
},
294283
},
295284
want: &api.AccessToken{
296285
Token: "123abc",
297286
},
298-
slept: duration("10s"),
299287
posts: []postArgs{
300288
{
301289
url: "https://github.com/oauth",
@@ -328,7 +316,7 @@ func TestPollToken(t *testing.T) {
328316
},
329317
},
330318
url: "https://github.com/oauth",
331-
opts: PollOptions{
319+
opts: WaitOptions{
332320
ClientID: "CLIENT-ID",
333321
ClientSecret: "SEKRIT",
334322
GrantType: "device_code",
@@ -339,14 +327,12 @@ func TestPollToken(t *testing.T) {
339327
ExpiresIn: 99,
340328
Interval: 5,
341329
},
342-
timeSleep: mockSleep,
343-
timeNow: clock("0", "5s", "10s"),
330+
newPoller: makeFakePoller(1),
344331
},
345332
},
346333
want: &api.AccessToken{
347334
Token: "123abc",
348335
},
349-
slept: duration("5s"),
350336
posts: []postArgs{
351337
{
352338
url: "https://github.com/oauth",
@@ -377,21 +363,19 @@ func TestPollToken(t *testing.T) {
377363
},
378364
},
379365
url: "https://github.com/oauth",
380-
opts: PollOptions{
366+
opts: WaitOptions{
381367
ClientID: "CLIENT-ID",
382368
DeviceCode: &CodeResponse{
383369
DeviceCode: "DEVIC",
384370
UserCode: "123-abc",
385371
VerificationURI: "http://verify.me",
386-
ExpiresIn: 99,
372+
ExpiresIn: 14,
387373
Interval: 5,
388374
},
389-
timeSleep: mockSleep,
390-
timeNow: clock("0", "5s", "15m"),
375+
newPoller: makeFakePoller(2),
391376
},
392377
},
393-
wantErr: "authentication timed out",
394-
slept: duration("10s"),
378+
wantErr: "context deadline exceeded",
395379
posts: []postArgs{
396380
{
397381
url: "https://github.com/oauth",
@@ -424,7 +408,7 @@ func TestPollToken(t *testing.T) {
424408
},
425409
},
426410
url: "https://github.com/oauth",
427-
opts: PollOptions{
411+
opts: WaitOptions{
428412
ClientID: "CLIENT-ID",
429413
DeviceCode: &CodeResponse{
430414
DeviceCode: "DEVIC",
@@ -433,12 +417,10 @@ func TestPollToken(t *testing.T) {
433417
ExpiresIn: 99,
434418
Interval: 5,
435419
},
436-
timeSleep: mockSleep,
437-
timeNow: clock("0", "5s"),
420+
newPoller: makeFakePoller(1),
438421
},
439422
},
440423
wantErr: "access_denied",
441-
slept: duration("5s"),
442424
posts: []postArgs{
443425
{
444426
url: "https://github.com/oauth",
@@ -453,8 +435,7 @@ func TestPollToken(t *testing.T) {
453435
}
454436
for _, tt := range tests {
455437
t.Run(tt.name, func(t *testing.T) {
456-
totalSlept = 0
457-
got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts)
438+
got, err := Wait(context.Background(), &tt.args.http, tt.args.url, tt.args.opts)
458439
if (err != nil) != (tt.wantErr != "") {
459440
t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr)
460441
return
@@ -468,9 +449,22 @@ func TestPollToken(t *testing.T) {
468449
if !reflect.DeepEqual(tt.args.http.calls, tt.posts) {
469450
t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts)
470451
}
471-
if totalSlept != tt.slept {
472-
t.Errorf("slept %v, wanted %v", totalSlept, tt.slept)
473-
}
474452
})
475453
}
476454
}
455+
456+
type fakePoller struct {
457+
maxWaits int
458+
count int
459+
}
460+
461+
func (p *fakePoller) Wait() error {
462+
if p.count == p.maxWaits {
463+
return errors.New("context deadline exceeded")
464+
}
465+
p.count++
466+
return nil
467+
}
468+
469+
func (p *fakePoller) Cancel() {
470+
}

device/examples_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package device
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"os"
@@ -22,7 +23,7 @@ func Example() {
2223
fmt.Printf("Copy code: %s\n", code.UserCode)
2324
fmt.Printf("then open: %s\n", code.VerificationURI)
2425

25-
accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{
26+
accessToken, err := Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
2627
ClientID: clientID,
2728
DeviceCode: code,
2829
})

device/poller.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package device
2+
3+
import (
4+
"context"
5+
"time"
6+
)
7+
8+
type poller interface {
9+
Wait() error
10+
Cancel()
11+
}
12+
13+
type pollerFactory func(context.Context, time.Duration, time.Duration) (context.Context, poller)
14+
15+
func newPoller(ctx context.Context, checkInteval, expiresIn time.Duration) (context.Context, poller) {
16+
c, cancel := context.WithTimeout(ctx, expiresIn)
17+
return c, &intervalPoller{
18+
ctx: c,
19+
interval: checkInteval,
20+
cancelFunc: cancel,
21+
}
22+
}
23+
24+
type intervalPoller struct {
25+
ctx context.Context
26+
interval time.Duration
27+
cancelFunc func()
28+
}
29+
30+
func (p intervalPoller) Wait() error {
31+
t := time.NewTimer(p.interval)
32+
select {
33+
case <-p.ctx.Done():
34+
t.Stop()
35+
return p.ctx.Err()
36+
case <-t.C:
37+
return nil
38+
}
39+
}
40+
41+
func (p intervalPoller) Cancel() {
42+
p.cancelFunc()
43+
}

oauth_device.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package oauth
22

33
import (
44
"bufio"
5+
"context"
56
"fmt"
67
"io"
78
"net/http"
@@ -58,7 +59,7 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
5859
return nil, fmt.Errorf("error opening the web browser: %w", err)
5960
}
6061

61-
return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{
62+
return device.Wait(context.TODO(), httpClient, host.TokenURL, device.WaitOptions{
6263
ClientID: oa.ClientID,
6364
DeviceCode: code,
6465
})

oauth_webapp.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package oauth
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67

@@ -52,5 +53,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
5253
httpClient = http.DefaultClient
5354
}
5455

55-
return flow.AccessToken(httpClient, host.TokenURL, oa.ClientSecret)
56+
return flow.Wait(context.TODO(), httpClient, host.TokenURL, webapp.WaitOptions{
57+
ClientSecret: oa.ClientSecret,
58+
})
5659
}

webapp/examples_test.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package webapp
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"os"
@@ -42,7 +43,9 @@ func Example() {
4243
}
4344

4445
httpClient := http.DefaultClient
45-
accessToken, err := flow.AccessToken(httpClient, "https://github.com/login/oauth/access_token", clientSecret)
46+
accessToken, err := flow.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
47+
ClientSecret: clientSecret,
48+
})
4649
if err != nil {
4750
panic(err)
4851
}

0 commit comments

Comments
 (0)