Skip to content

Commit 589416c

Browse files
committed
support context propagation
- context.Context instance passed in ContextEval can be propagated to binding function to cancel the process.
1 parent 2337cc0 commit 589416c

20 files changed

+481
-334
lines changed

cel/cel_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,41 @@ func TestContextEval(t *testing.T) {
10141014
}
10151015
}
10161016

1017+
func TestContextEvalPropagation(t *testing.T) {
1018+
env, err := NewEnv(Function("test",
1019+
Overload("test_int", []*Type{}, IntType,
1020+
FunctionBindingContext(func(ctx context.Context, _ ...ref.Val) ref.Val {
1021+
md := ctx.Value("metadata")
1022+
if md == nil {
1023+
return types.NewErr("cannot find metadata value")
1024+
}
1025+
return types.Int(md.(int))
1026+
}),
1027+
),
1028+
))
1029+
if err != nil {
1030+
t.Fatalf("NewEnv() failed: %v", err)
1031+
}
1032+
ast, iss := env.Compile("test()")
1033+
if iss.Err() != nil {
1034+
t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
1035+
}
1036+
prg, err := env.Program(ast)
1037+
if err != nil {
1038+
t.Fatalf("env.Program() failed: %v", err)
1039+
}
1040+
1041+
expected := 10
1042+
ctx := context.WithValue(context.Background(), "metadata", expected)
1043+
out, _, err := prg.ContextEval(ctx, map[string]interface{}{})
1044+
if err != nil {
1045+
t.Fatalf("prg.ContextEval() failed: %v", err)
1046+
}
1047+
if out != types.Int(expected) {
1048+
t.Errorf("prg.ContextEval() got %v, but wanted %d", out, expected)
1049+
}
1050+
}
1051+
10171052
func BenchmarkContextEval(b *testing.B) {
10181053
env := testEnv(b,
10191054
Variable("items", ListType(IntType)),
@@ -1428,7 +1463,7 @@ func TestCustomInterpreterDecorator(t *testing.T) {
14281463
if !lhsIsConst || !rhsIsConst {
14291464
return i, nil
14301465
}
1431-
val := call.Eval(interpreter.EmptyActivation())
1466+
val := call.Eval(context.Background(), interpreter.EmptyActivation())
14321467
if types.IsError(val) {
14331468
return nil, val.(*types.Err)
14341469
}

cel/decls.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,24 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
287287
return decls.FunctionBinding(binding)
288288
}
289289

290+
// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
291+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
292+
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
293+
return decls.UnaryBindingContext(binding)
294+
}
295+
296+
// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
297+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
298+
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
299+
return decls.BinaryBindingContext(binding)
300+
}
301+
302+
// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
303+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
304+
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
305+
return decls.FunctionBindingContext(binding)
306+
}
307+
290308
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
291309
//
292310
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

cel/decls_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package cel
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"math"
2021
"reflect"
@@ -673,7 +674,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
673674
}
674675
prg, err := e.Program(ast, Functions(&functions.Overload{
675676
Operator: overloads.SizeString,
676-
Unary: func(arg ref.Val) ref.Val {
677+
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
677678
str, ok := arg.(types.String)
678679
if !ok {
679680
return types.MaybeNoSuchOverloadErr(arg)
@@ -682,7 +683,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
682683
},
683684
}, &functions.Overload{
684685
Operator: overloads.SizeStringInst,
685-
Unary: func(arg ref.Val) ref.Val {
686+
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
686687
str, ok := arg.(types.String)
687688
if !ok {
688689
return types.MaybeNoSuchOverloadErr(arg)

cel/library.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package cel
1616

1717
import (
18+
"context"
1819
"math"
1920
"strconv"
2021
"strings"
@@ -494,17 +495,17 @@ func (opt *evalOptionalOr) ID() int64 {
494495

495496
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
496497
// proceeds with the right-hand side evaluation.
497-
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
498+
func (opt *evalOptionalOr) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
498499
// short-circuit lhs.
499-
optLHS := opt.lhs.Eval(ctx)
500+
optLHS := opt.lhs.Eval(ctx, vars)
500501
optVal, ok := optLHS.(*types.Optional)
501502
if !ok {
502503
return optLHS
503504
}
504505
if optVal.HasValue() {
505506
return optVal
506507
}
507-
return opt.rhs.Eval(ctx)
508+
return opt.rhs.Eval(ctx, vars)
508509
}
509510

510511
// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
@@ -522,17 +523,17 @@ func (opt *evalOptionalOrValue) ID() int64 {
522523

523524
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
524525
// proceeds with the right-hand side evaluation.
525-
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
526+
func (opt *evalOptionalOrValue) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
526527
// short-circuit lhs.
527-
optLHS := opt.lhs.Eval(ctx)
528+
optLHS := opt.lhs.Eval(ctx, vars)
528529
optVal, ok := optLHS.(*types.Optional)
529530
if !ok {
530531
return optLHS
531532
}
532533
if optVal.HasValue() {
533534
return optVal.GetValue()
534535
}
535-
return opt.rhs.Eval(ctx)
536+
return opt.rhs.Eval(ctx, vars)
536537
}
537538

538539
type timeUTCLibrary struct{}

cel/program.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,11 @@ func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorat
264264

265265
// Eval implements the Program interface method.
266266
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
267+
return p.eval(context.Background(), input)
268+
}
269+
270+
// Eval implements the Program interface method.
271+
func (p *prog) eval(ctx context.Context, input any) (v ref.Val, det *EvalDetails, err error) {
267272
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
268273
// return values makes it possible to modify the error response during the recovery
269274
// function.
@@ -291,7 +296,7 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
291296
if p.defaultVars != nil {
292297
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
293298
}
294-
v = p.interpretable.Eval(vars)
299+
v = p.interpretable.Eval(ctx, vars)
295300
// The output of an internal Eval may have a value (`v`) that is a types.Err. This step
296301
// translates the CEL value to a Go error response. This interface does not quite match the
297302
// RPC signature which allows for multiple errors to be returned, but should be sufficient.
@@ -321,7 +326,7 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
321326
default:
322327
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
323328
}
324-
return p.Eval(vars)
329+
return p.eval(ctx, vars)
325330
}
326331

327332
// progFactory is a helper alias for marking a program creation factory function.
@@ -349,6 +354,11 @@ func newProgGen(factory progFactory) (Program, error) {
349354

350355
// Eval implements the Program interface method.
351356
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
357+
return gen.eval(context.Background(), input)
358+
}
359+
360+
// Eval implements the Program interface method.
361+
func (gen *progGen) eval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
352362
// The factory based Eval() differs from the standard evaluation model in that it generates a
353363
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
354364
// results.
@@ -368,7 +378,7 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
368378
}
369379

370380
// Evaluate the input, returning the result and the 'state' within EvalDetails.
371-
v, _, err := p.Eval(input)
381+
v, _, err := p.ContextEval(ctx, input)
372382
if err != nil {
373383
return v, det, err
374384
}

common/decls/decls.go

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package decls
1717

1818
import (
19+
"context"
1920
"fmt"
2021
"strings"
2122

@@ -242,23 +243,23 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
242243
// All of the defined overloads are wrapped into a top-level function which
243244
// performs dynamic dispatch to the proper overload based on the argument types.
244245
bindings := append([]*functions.Overload{}, overloads...)
245-
funcDispatch := func(args ...ref.Val) ref.Val {
246+
funcDispatch := func(ctx context.Context, args ...ref.Val) ref.Val {
246247
for _, oID := range f.overloadOrdinals {
247248
o := f.overloads[oID]
248249
// During dynamic dispatch over multiple functions, signature agreement checks
249250
// are preserved in order to assist with the function resolution step.
250251
switch len(args) {
251252
case 1:
252253
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
253-
return o.unaryOp(args[0])
254+
return o.unaryOp(ctx, args[0])
254255
}
255256
case 2:
256257
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
257-
return o.binaryOp(args[0], args[1])
258+
return o.binaryOp(ctx, args[0], args[1])
258259
}
259260
}
260261
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
261-
return o.functionOp(args...)
262+
return o.functionOp(ctx, args...)
262263
}
263264
// eventually this will fall through to the noSuchOverload below.
264265
}
@@ -333,8 +334,10 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
333334
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
334335
}
335336
f.singleton = &functions.Overload{
336-
Operator: f.Name(),
337-
Unary: fn,
337+
Operator: f.Name(),
338+
Unary: func(ctx context.Context, val ref.Val) ref.Val {
339+
return fn(val)
340+
},
338341
OperandTrait: trait,
339342
}
340343
return f, nil
@@ -355,8 +358,10 @@ func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
355358
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
356359
}
357360
f.singleton = &functions.Overload{
358-
Operator: f.Name(),
359-
Binary: fn,
361+
Operator: f.Name(),
362+
Binary: func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
363+
return fn(lhs, rhs)
364+
},
360365
OperandTrait: trait,
361366
}
362367
return f, nil
@@ -377,8 +382,10 @@ func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOp
377382
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
378383
}
379384
f.singleton = &functions.Overload{
380-
Operator: f.Name(),
381-
Function: fn,
385+
Operator: f.Name(),
386+
Function: func(ctx context.Context, values ...ref.Val) ref.Val {
387+
return fn(values...)
388+
},
382389
OperandTrait: trait,
383390
}
384391
return f, nil
@@ -460,11 +467,11 @@ type OverloadDecl struct {
460467

461468
// Function implementation options. Optional, but encouraged.
462469
// unaryOp is a function binding that takes a single argument.
463-
unaryOp functions.UnaryOp
470+
unaryOp functions.UnaryContextOp
464471
// binaryOp is a function binding that takes two arguments.
465-
binaryOp functions.BinaryOp
472+
binaryOp functions.BinaryContextOp
466473
// functionOp is a catch-all for zero-arity and three-plus arity functions.
467-
functionOp functions.FunctionOp
474+
functionOp functions.FunctionContextOp
468475
}
469476

470477
// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
@@ -580,41 +587,41 @@ func (o *OverloadDecl) hasBinding() bool {
580587
}
581588

582589
// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
583-
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp {
590+
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryContextOp {
584591
if o.unaryOp == nil {
585592
return nil
586593
}
587-
return func(arg ref.Val) ref.Val {
594+
return func(ctx context.Context, arg ref.Val) ref.Val {
588595
if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) {
589596
return MaybeNoSuchOverload(funcName, arg)
590597
}
591-
return o.unaryOp(arg)
598+
return o.unaryOp(ctx, arg)
592599
}
593600
}
594601

595602
// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
596-
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp {
603+
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryContextOp {
597604
if o.binaryOp == nil {
598605
return nil
599606
}
600-
return func(arg1, arg2 ref.Val) ref.Val {
607+
return func(ctx context.Context, arg1, arg2 ref.Val) ref.Val {
601608
if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) {
602609
return MaybeNoSuchOverload(funcName, arg1, arg2)
603610
}
604-
return o.binaryOp(arg1, arg2)
611+
return o.binaryOp(ctx, arg1, arg2)
605612
}
606613
}
607614

608615
// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
609-
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp {
616+
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionContextOp {
610617
if o.functionOp == nil {
611618
return nil
612619
}
613-
return func(args ...ref.Val) ref.Val {
620+
return func(ctx context.Context, args ...ref.Val) ref.Val {
614621
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
615622
return MaybeNoSuchOverload(funcName, args...)
616623
}
617-
return o.functionOp(args...)
624+
return o.functionOp(ctx, args...)
618625
}
619626
}
620627

@@ -667,6 +674,30 @@ type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
667674
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
668675
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
669676
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
677+
return UnaryBindingContext(func(ctx context.Context, val ref.Val) ref.Val {
678+
return binding(val)
679+
})
680+
}
681+
682+
// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
683+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
684+
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
685+
return BinaryBindingContext(func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
686+
return binding(lhs, rhs)
687+
})
688+
}
689+
690+
// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
691+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
692+
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
693+
return FunctionBindingContext(func(ctx context.Context, values ...ref.Val) ref.Val {
694+
return binding(values...)
695+
})
696+
}
697+
698+
// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
699+
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
700+
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
670701
return func(o *OverloadDecl) (*OverloadDecl, error) {
671702
if o.hasBinding() {
672703
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
@@ -679,9 +710,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
679710
}
680711
}
681712

682-
// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
713+
// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
683714
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
684-
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
715+
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
685716
return func(o *OverloadDecl) (*OverloadDecl, error) {
686717
if o.hasBinding() {
687718
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
@@ -694,9 +725,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
694725
}
695726
}
696727

697-
// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
728+
// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
698729
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
699-
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
730+
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
700731
return func(o *OverloadDecl) (*OverloadDecl, error) {
701732
if o.hasBinding() {
702733
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())

0 commit comments

Comments
 (0)