Skip to content

Commit 42d6fa1

Browse files
committed
multi-service: preserve context values
1 parent 51616ed commit 42d6fa1

File tree

3 files changed

+99
-3
lines changed

3 files changed

+99
-3
lines changed

service/multi/context.go

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package multi
2+
3+
import (
4+
"context"
5+
"time"
6+
)
7+
8+
// ContextWithoutCancel returns a derived context that points to the parent context
9+
// and is not canceled when parent is canceled.
10+
// The returned context returns no Deadline or Err, and its Done channel is nil.
11+
// Calling Cause on the returned context returns nil.
12+
func ContextWithoutCancel(parent context.Context) context.Context {
13+
if parent == nil {
14+
panic("cannot create context from nil parent")
15+
}
16+
return withoutCancelCtx{parent}
17+
}
18+
19+
type withoutCancelCtx struct {
20+
c context.Context
21+
}
22+
23+
func (withoutCancelCtx) Deadline() (deadline time.Time, ok bool) {
24+
return
25+
}
26+
27+
func (withoutCancelCtx) Done() <-chan struct{} {
28+
return nil
29+
}
30+
31+
func (withoutCancelCtx) Err() error {
32+
return nil
33+
}
34+
35+
func (c withoutCancelCtx) Value(key interface{}) interface{} {
36+
return c.c.Value(key)
37+
}
38+
39+
// func (c withoutCancelCtx) String() string {
40+
// return c.c.String() + ".WithoutCancel"
41+
// }

service/multi/context_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package multi
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
8+
"github.com/micromdm/nanolib/log"
9+
"github.com/micromdm/nanomdm/mdm"
10+
"github.com/micromdm/nanomdm/service"
11+
"github.com/micromdm/nanomdm/test"
12+
)
13+
14+
type ctxTest1 struct{}
15+
16+
type testSvc struct {
17+
wg *sync.WaitGroup
18+
capture string
19+
service.CheckinAndCommandService
20+
}
21+
22+
func (ts *testSvc) Authenticate(r *mdm.Request, _ *mdm.Authenticate) error {
23+
ts.capture, _ = r.Context.Value(&ctxTest1{}).(string)
24+
ts.wg.Done()
25+
return nil
26+
}
27+
28+
func TestContextPassthru(t *testing.T) {
29+
nopSvc1 := &test.NopService{}
30+
31+
var ctx context.Context = context.Background()
32+
33+
ctx = context.WithValue(ctx, &ctxTest1{}, "test-ctx-val")
34+
35+
r := &mdm.Request{Context: ctx}
36+
37+
var wg sync.WaitGroup
38+
39+
wg.Add(1)
40+
ts := &testSvc{
41+
wg: &wg,
42+
CheckinAndCommandService: &test.NopService{},
43+
}
44+
45+
multi := New(log.NopLogger, nopSvc1, ts)
46+
47+
err := multi.Authenticate(r, &mdm.Authenticate{})
48+
if err != nil {
49+
t.Fatal(err)
50+
}
51+
52+
wg.Wait()
53+
54+
if have, want := ts.capture, "test-ctx-val"; have != want {
55+
t.Errorf("have: %v, want: %v", have, want)
56+
}
57+
}

service/multi/multi.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
type MultiService struct {
2020
logger log.Logger
2121
svcs []service.CheckinAndCommandService
22-
ctx context.Context
2322
}
2423

2524
func New(logger log.Logger, svcs ...service.CheckinAndCommandService) *MultiService {
@@ -29,7 +28,6 @@ func New(logger log.Logger, svcs ...service.CheckinAndCommandService) *MultiServ
2928
return &MultiService{
3029
logger: logger,
3130
svcs: svcs,
32-
ctx: context.Background(),
3331
}
3432
}
3533

@@ -52,7 +50,7 @@ func (ms *MultiService) runOthers(ctx context.Context, r errorRunner) {
5250
// RequestWithContext returns a clone of r and sets its context to ctx.
5351
func (ms *MultiService) RequestWithContext(r *mdm.Request) *mdm.Request {
5452
r2 := r.Clone()
55-
r2.Context = ms.ctx
53+
r2.Context = ContextWithoutCancel(r.Context)
5654
return r2
5755
}
5856

0 commit comments

Comments
 (0)