Skip to content

Update the Go AST representation to handle a second iteration variable #1031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ go_library(
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
Expand All @@ -35,12 +35,13 @@ go_test(
embed = [
":go_default_library",
],
deps = [
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
Expand Down
4 changes: 3 additions & 1 deletion common/ast/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehe
if err != nil {
return nil, err
}
return factory.NewComprehension(id,
return factory.NewComprehensionTwoVar(id,
iterRange,
comp.GetIterVar(),
comp.GetIterVar2(),
comp.GetAccuVar(),
accuInit,
loopCond,
Expand Down Expand Up @@ -363,6 +364,7 @@ func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error)
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: comp.IterVar(),
IterVar2: comp.IterVar2(),
IterRange: iterRange,
AccuVar: comp.AccuVar(),
AccuInit: accuInit,
Expand Down
205 changes: 201 additions & 4 deletions common/ast/conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
Expand All @@ -35,6 +36,7 @@ import (
)

func TestConvertAST(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
goAST *ast.AST
pbAST *exprpb.CheckedExpr
Expand Down Expand Up @@ -68,6 +70,115 @@ func TestConvertAST(t *testing.T) {
},
},
},
{
goAST: ast.NewAST(
fac.NewComprehensionTwoVar(1,
fac.NewIdent(2, "data"),
"i",
"v",
"__result__",
fac.NewList(3, []ast.Expr{}, []int32{}),
fac.NewLiteral(4, types.True),
fac.NewCall(8, operators.Add,
fac.NewAccuIdent(9),
fac.NewCall(5, operators.Add,
fac.NewIdent(6, "i"),
fac.NewIdent(7, "v"),
)),
fac.NewAccuIdent(10),
), nil),
pbAST: &exprpb.CheckedExpr{
Expr: &exprpb.Expr{
Id: 1,
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterRange: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "data",
},
},
},
IterVar: "i",
IterVar2: "v",
AccuVar: "__result__",
AccuInit: &exprpb.Expr{
Id: 3,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{},
},
},
LoopCondition: &exprpb.Expr{
Id: 4,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_BoolValue{
BoolValue: true,
},
},
},
},
LoopStep: &exprpb.Expr{
Id: 8,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: operators.Add,
Args: []*exprpb.Expr{
{
Id: 9,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "__result__",
},
},
},
{
Id: 5,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: operators.Add,
Args: []*exprpb.Expr{
{
Id: 6,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "i",
},
},
},
{
Id: 7,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "v",
},
},
},
},
},
},
},
},
},
},
},
Result: &exprpb.Expr{
Id: 10,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "__result__",
},
},
},
},
},
},
SourceInfo: &exprpb.SourceInfo{},
TypeMap: map[int64]*exprpb.Type{},
ReferenceMap: map[int64]*exprpb.Reference{},
},
},
}

for i, tst := range tests {
Expand All @@ -83,11 +194,13 @@ func TestConvertAST(t *testing.T) {
!reflect.DeepEqual(checkedAST.TypeMap(), goAST.TypeMap()) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST)
}
if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) ||
!checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) {
t.Error("converted reference info values not equal")
if len(checkedAST.ReferenceMap()) > 2 {
if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) ||
!checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) {
t.Error("converted reference info values not equal")
}
}
checkedExpr, err := ast.ToProto(goAST)
checkedExpr, err := ast.ToProto(checkedAST)
if err != nil {
t.Fatalf("ASTToProto() failed: %v", err)
}
Expand All @@ -98,6 +211,90 @@ func TestConvertAST(t *testing.T) {
}
}

func TestConvertProtoToEntryExpr(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
goAST ast.EntryExpr
pbAST *exprpb.Expr_CreateStruct_Entry
}{
{
goAST: fac.NewMapEntry(1,
fac.NewIdent(2, "var_key"),
fac.NewLiteral(3, types.String("hello")),
true),
pbAST: &exprpb.Expr_CreateStruct_Entry{
Id: 1,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "var_key",
},
},
},
},
Value: &exprpb.Expr{
Id: 3,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_StringValue{
StringValue: "hello",
},
},
},
},
OptionalEntry: true,
},
},
{
goAST: fac.NewStructField(1,
"field_name",
fac.NewLiteral(2, types.String("hello")),
false),
pbAST: &exprpb.Expr_CreateStruct_Entry{
Id: 1,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: "field_name",
},
Value: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_StringValue{
StringValue: "hello",
},
},
},
},
OptionalEntry: false,
},
},
}

for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
goAST := tc.goAST
pbAST := tc.pbAST
gotGoAST, err := ast.ProtoToEntryExpr(pbAST)
if err != nil {
t.Fatalf("ProtoToEntryExpr() failed: %v", err)
}
if !reflect.DeepEqual(goAST, gotGoAST) {
t.Errorf("conversion to go AST did not produce identical results: got %v, wanted %v", gotGoAST, goAST)
}
gotProtoAST, err := ast.EntryExprToProto(gotGoAST)
if err != nil {
t.Fatalf("EntryExprToProto() failed: %v", err)
}
if !proto.Equal(gotProtoAST, pbAST) {
t.Errorf("conversion to protobuf did not produce identical results: got %v, wanted %v", gotProtoAST, pbAST)
}
})
}
}

func TestConvertExpr(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
Expand Down
24 changes: 24 additions & 0 deletions common/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,22 @@ type ComprehensionExpr interface {
IterRange() Expr

// IterVar returns the iteration variable name.
//
// For one-variable comprehensions, the iter var refers to the element value
// when iterating over a list, or the map key when iterating over a map.
//
// For two-variable comprehneions, the iter var refers to the list index or the
// map key.
IterVar() string

// IterVar2 returns the second iteration variable name.
//
// When the value is non-empty, the comprehension is a two-variable comprehension.
IterVar2() string

// HasIterVar2 returns true if the second iteration variable is non-empty.
HasIterVar2() bool

// AccuVar returns the accumulation variable name.
AccuVar() string

Expand Down Expand Up @@ -397,6 +411,7 @@ func (e *expr) SetKindCase(other Expr) {
e.exprKindCase = &baseComprehensionExpr{
iterRange: c.IterRange(),
iterVar: c.IterVar(),
iterVar2: c.IterVar2(),
accuVar: c.AccuVar(),
accuInit: c.AccuInit(),
loopCond: c.LoopCondition(),
Expand Down Expand Up @@ -505,6 +520,7 @@ var _ ComprehensionExpr = &baseComprehensionExpr{}
type baseComprehensionExpr struct {
iterRange Expr
iterVar string
iterVar2 string
accuVar string
accuInit Expr
loopCond Expr
Expand All @@ -527,6 +543,14 @@ func (e *baseComprehensionExpr) IterVar() string {
return e.iterVar
}

func (e *baseComprehensionExpr) IterVar2() string {
return e.iterVar2
}

func (e *baseComprehensionExpr) HasIterVar2() bool {
return e.iterVar2 != ""
}

func (e *baseComprehensionExpr) AccuVar() string {
return e.accuVar
}
Expand Down
14 changes: 12 additions & 2 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ type ExprFactory interface {
// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr

// NewComprehension creates an Expr value representing a comprehension over a value range.
// NewComprehension creates an Expr value representing a one-variable comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr

// NewComprehensionTwoVar creates an Expr value representing a two-variable comprehension over a value range.
NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr

// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr

Expand Down Expand Up @@ -111,11 +114,17 @@ func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr
}

func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
// Set the iter_var2 to empty string to indicate the second variable is omitted
return fac.NewComprehensionTwoVar(id, iterRange, iterVar, "", accuVar, accuInit, loopCond, loopStep, result)
}

func (fac *baseExprFactory) NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
iterVar2: iterVar2,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
Expand Down Expand Up @@ -223,9 +232,10 @@ func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
return fac.NewComprehensionTwoVar(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.IterVar2(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),
Expand Down
8 changes: 8 additions & 0 deletions common/ast/navigable.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ func (comp navigableComprehensionImpl) IterVar() string {
return comp.Expr.AsComprehension().IterVar()
}

func (comp navigableComprehensionImpl) IterVar2() string {
return comp.Expr.AsComprehension().IterVar2()
}

func (comp navigableComprehensionImpl) HasIterVar2() bool {
return comp.Expr.AsComprehension().HasIterVar2()
}

func (comp navigableComprehensionImpl) AccuVar() string {
return comp.Expr.AsComprehension().AccuVar()
}
Expand Down
Loading