Skip to content

Commit 0f9133d

Browse files
authored
Add LateFunctionBinding declaration and fix constant folding (#1117)
Adds the declaration for LateFunctionBindings, which can be used to indicate to the Runtime and the Optimizers that the function will be bound at runtime through the Activation. This lets the constant folding optimizer know that the function potentially has side effects and cannot be folded. Without this the optimization will fail with an error for late bound functions where all arguments are constants. The implementation for late bound functions will be added in a subsequent commit.
1 parent 6b7ecea commit 0f9133d

File tree

5 files changed

+251
-3
lines changed

5 files changed

+251
-3
lines changed

cel/decls.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
350350
return decls.FunctionBinding(binding)
351351
}
352352

353+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
354+
// This is useful for functions which have side-effects or are not deterministically computable.
355+
func LateFunctionBinding() OverloadOpt {
356+
return decls.LateFunctionBinding()
357+
}
358+
353359
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
354360
//
355361
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

cel/folding.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
6868
// Walk the list of foldable expression and continue to fold until there are no more folds left.
6969
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
7070
// a logic bug with the selection of expressions.
71-
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
71+
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
72+
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
7273
foldCount := 0
7374
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
7475
for _, fold := range foldableExprs {
@@ -77,6 +78,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
7778
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
7879
continue
7980
}
81+
// Late-bound function calls cannot be folded.
82+
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
83+
continue
84+
}
8085
// Otherwise, assume all context is needed to evaluate the expression.
8186
err := tryFold(ctx, a, fold)
8287
if err != nil {
@@ -85,7 +90,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
8590
}
8691
}
8792
foldCount++
88-
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
93+
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
8994
}
9095
// Once all of the constants have been folded, try to run through the remaining comprehensions
9196
// one last time. In this case, there's no guarantee they'll run, so we only update the
@@ -139,6 +144,15 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
139144
return nil
140145
}
141146

147+
func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
148+
call := expr.AsCall()
149+
function := ctx.Functions()[call.FunctionName()]
150+
if function == nil {
151+
return false
152+
}
153+
return function.HasLateBinding()
154+
}
155+
142156
// maybePruneBranches inspects the non-strict call expression to determine whether
143157
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
144158
// but conditional will not be pruned cleanly, so this is one small area where the
@@ -455,7 +469,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
455469
// Only comprehensions which are not nested are included as possible constant folds, and only
456470
// if all variables referenced in the comprehension stack exist are only iteration or
457471
// accumulation variables.
458-
func constantExprMatcher(e ast.NavigableExpr) bool {
472+
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
459473
switch e.Kind() {
460474
case ast.CallKind:
461475
return constantCallMatcher(e)
@@ -477,6 +491,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
477491
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
478492
constantExprs = false
479493
}
494+
// Late-bound function calls cannot be folded.
495+
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
496+
constantExprs = false
497+
}
480498
})
481499
ast.PreOrderVisit(e, visitor)
482500
return constantExprs

cel/folding_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package cel
1717
import (
1818
"reflect"
1919
"sort"
20+
"strings"
2021
"testing"
2122

2223
"google.golang.org/protobuf/encoding/prototext"
2324
"google.golang.org/protobuf/proto"
2425

2526
"github.com/google/cel-go/common/ast"
27+
"github.com/google/cel-go/common/types/ref"
2628

2729
proto3pb "github.com/google/cel-go/test/proto3pb"
2830
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) {
313315
}
314316
}
315317

318+
func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
319+
tests := []struct {
320+
expr string
321+
folded string
322+
error string
323+
}{
324+
{
325+
expr: `noSideEffect(3)`,
326+
folded: `3`,
327+
},
328+
{
329+
expr: `withSideEffect(3)`,
330+
folded: `withSideEffect(3)`,
331+
},
332+
{
333+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
334+
folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
335+
},
336+
{
337+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`,
338+
folded: `true`,
339+
},
340+
{
341+
expr: `noImpl(3)`,
342+
error: `constant-folding evaluation failed: no such overload: noImpl`,
343+
},
344+
}
345+
e, err := NewEnv(
346+
OptionalTypes(),
347+
EnableMacroCallTracking(),
348+
Function("noSideEffect",
349+
Overload("noSideEffect_int_int",
350+
[]*Type{IntType},
351+
IntType, FunctionBinding(func(args ...ref.Val) ref.Val {
352+
return args[0]
353+
}))),
354+
Function("withSideEffect",
355+
Overload("withSideEffect_int_int",
356+
[]*Type{IntType},
357+
IntType, LateFunctionBinding())),
358+
Function("noImpl",
359+
Overload("noImpl_int_int",
360+
[]*Type{IntType},
361+
IntType)),
362+
)
363+
if err != nil {
364+
t.Fatalf("NewEnv() failed: %v", err)
365+
}
366+
for _, tst := range tests {
367+
tc := tst
368+
t.Run(tc.expr, func(t *testing.T) {
369+
checked, iss := e.Compile(tc.expr)
370+
if iss.Err() != nil {
371+
t.Fatalf("Compile() failed: %v", iss.Err())
372+
}
373+
folder, err := NewConstantFoldingOptimizer()
374+
if err != nil {
375+
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
376+
}
377+
opt := NewStaticOptimizer(folder)
378+
optimized, iss := opt.Optimize(e, checked)
379+
if tc.error != "" {
380+
if iss.Err() == nil {
381+
t.Errorf("got nil, wanted error containing %q", tc.error)
382+
} else if !strings.Contains(iss.Err().Error(), tc.error) {
383+
t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error)
384+
}
385+
return
386+
}
387+
if iss.Err() != nil {
388+
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
389+
}
390+
folded, err := AstToString(optimized)
391+
if err != nil {
392+
t.Fatalf("AstToString() failed: %v", err)
393+
}
394+
if folded != tc.folded {
395+
t.Errorf("got %q, wanted %q", folded, tc.folded)
396+
}
397+
})
398+
}
399+
}
400+
316401
func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
317402
tests := []struct {
318403
expr string

common/decls/decls.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
281281
}
282282
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
283283
}
284+
if overload.HasLateBinding() != o.HasLateBinding() {
285+
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
286+
}
284287
}
285288
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
286289
f.overloads[overload.ID()] = overload
@@ -300,6 +303,19 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
300303
return overloads
301304
}
302305

306+
// Returns true if the function has late bindings. A function cannot mix late bindings with other bindings.
307+
func (f *FunctionDecl) HasLateBinding() bool {
308+
if f == nil {
309+
return false
310+
}
311+
for _, oID := range f.overloadOrdinals {
312+
if f.overloads[oID].HasLateBinding() {
313+
return true
314+
}
315+
}
316+
return false
317+
}
318+
303319
// Bindings produces a set of function bindings, if any are defined.
304320
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
305321
var emptySet []*functions.Overload
@@ -308,8 +324,10 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
308324
}
309325
overloads := []*functions.Overload{}
310326
nonStrict := false
327+
hasLateBinding := false
311328
for _, oID := range f.overloadOrdinals {
312329
o := f.overloads[oID]
330+
hasLateBinding = hasLateBinding || o.HasLateBinding()
313331
if o.hasBinding() {
314332
overload := &functions.Overload{
315333
Operator: o.ID(),
@@ -327,6 +345,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
327345
if len(overloads) != 0 {
328346
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
329347
}
348+
if hasLateBinding {
349+
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
350+
}
330351
overloads = []*functions.Overload{
331352
{
332353
Operator: f.Name(),
@@ -576,6 +597,9 @@ type OverloadDecl struct {
576597
argTypes []*types.Type
577598
resultType *types.Type
578599
isMemberFunction bool
600+
// hasLateBinding indicates that the function has a binding which is not known at compile time.
601+
// This is useful for functions which have side-effects or are not deterministically computable.
602+
hasLateBinding bool
579603
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
580604
nonStrict bool
581605
// operandTrait indicates whether the member argument should have a specific type-trait.
@@ -640,6 +664,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
640664
return o.nonStrict
641665
}
642666

667+
// HasLateBinding returns whether the overload has a binding which is not known at compile time.
668+
func (o *OverloadDecl) HasLateBinding() bool {
669+
if o == nil {
670+
return false
671+
}
672+
return o.hasLateBinding
673+
}
674+
643675
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
644676
// `traits.Indexer`
645677
func (o *OverloadDecl) OperandTrait() int {
@@ -816,6 +848,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
816848
if len(o.ArgTypes()) != 1 {
817849
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
818850
}
851+
if o.hasLateBinding {
852+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
853+
}
819854
o.unaryOp = binding
820855
return o, nil
821856
}
@@ -831,6 +866,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
831866
if len(o.ArgTypes()) != 2 {
832867
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
833868
}
869+
if o.hasLateBinding {
870+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
871+
}
834872
o.binaryOp = binding
835873
return o, nil
836874
}
@@ -843,11 +881,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
843881
if o.hasBinding() {
844882
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
845883
}
884+
if o.hasLateBinding {
885+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
886+
}
846887
o.functionOp = binding
847888
return o, nil
848889
}
849890
}
850891

892+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
893+
// This is useful for functions which have side-effects or are not deterministically computable.
894+
func LateFunctionBinding() OverloadOpt {
895+
return func(o *OverloadDecl) (*OverloadDecl, error) {
896+
if o.hasBinding() {
897+
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
898+
}
899+
o.hasLateBinding = true
900+
return o, nil
901+
}
902+
}
903+
851904
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
852905
//
853906
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

0 commit comments

Comments
 (0)