Skip to content

Commit 80f0ea6

Browse files
committed
Correctly add OpDeref if needed
Fixes #739
1 parent 2d9f616 commit 80f0ea6

File tree

10 files changed

+118
-15
lines changed

10 files changed

+118
-15
lines changed

builtin/builtin.go

+9
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,9 @@ var Builtins = []*Function{
493493
}
494494
return anyType, fmt.Errorf("invalid number of arguments (expected 0, got %d)", len(args))
495495
},
496+
Deref: func(i int, arg reflect.Type) bool {
497+
return false
498+
},
496499
},
497500
{
498501
Name: "duration",
@@ -567,6 +570,12 @@ var Builtins = []*Function{
567570
}
568571
return timeType, nil
569572
},
573+
Deref: func(i int, arg reflect.Type) bool {
574+
if arg.AssignableTo(locationType) {
575+
return false
576+
}
577+
return true
578+
},
570579
},
571580
{
572581
Name: "timezone",

builtin/builtin_test.go

+13-3
Original file line numberDiff line numberDiff line change
@@ -642,11 +642,17 @@ func Test_int_unwraps_underlying_value(t *testing.T) {
642642
func TestBuiltin_with_deref(t *testing.T) {
643643
x := 42
644644
arr := []any{1, 2, 3}
645+
arrStr := []string{"1", "2", "3"}
645646
m := map[string]any{"a": 1, "b": 2}
647+
jsonString := `["1"]`
648+
str := "1,2,3"
646649
env := map[string]any{
647-
"x": &x,
648-
"arr": &arr,
649-
"m": &m,
650+
"x": &x,
651+
"arr": &arr,
652+
"arrStr": &arrStr,
653+
"m": &m,
654+
"json": &jsonString,
655+
"str": &str,
650656
}
651657

652658
tests := []struct {
@@ -669,6 +675,10 @@ func TestBuiltin_with_deref(t *testing.T) {
669675
{`uniq(arr)`, []any{1, 2, 3}},
670676
{`concat(arr, arr)`, []any{1, 2, 3, 1, 2, 3}},
671677
{`flatten([arr, [arr]])`, []any{1, 2, 3, 1, 2, 3}},
678+
{`toJSON(arr)`, "[\n 1,\n 2,\n 3\n]"},
679+
{`fromJSON(json)`, []any{"1"}},
680+
{`split(str, ",")`, []string{"1", "2", "3"}},
681+
{`join(arrStr, ",")`, "1,2,3"},
672682
}
673683

674684
for _, test := range tests {

builtin/function.go

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type Function struct {
1111
Safe func(args ...any) (any, uint, error)
1212
Types []reflect.Type
1313
Validate func(args []reflect.Type) (reflect.Type, error)
14+
Deref func(i int, arg reflect.Type) bool
1415
Predicate bool
1516
}
1617

checker/checker.go

+8
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,14 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []
975975
lastErr = err
976976
continue
977977
}
978+
979+
// As we found the correct function overload, we can stop the loop.
980+
// Also, we need to set the correct nature of the callee so compiler,
981+
// can correctly handle OpDeref opcode.
982+
if callNode, ok := node.(*ast.CallNode); ok {
983+
callNode.Callee.SetType(t)
984+
}
985+
978986
return outNature
979987
}
980988
if lastErr != nil {

checker/nature/nature.go

+18
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ type Nature struct {
2525
FieldIndex []int // Index of field in type.
2626
}
2727

28+
func (n Nature) IsNil() bool {
29+
return n.Nil
30+
}
31+
32+
func (n Nature) IsAny() bool {
33+
return n.Kind() == reflect.Interface && n.NumMethods() == 0
34+
}
35+
36+
func (n Nature) IsUnknown() bool {
37+
switch {
38+
case n.Type == nil && !n.Nil:
39+
return true
40+
case isAny(n):
41+
return true
42+
}
43+
return false
44+
}
45+
2846
func (n Nature) String() string {
2947
if n.Type != nil {
3048
return n.Type.String()

compiler/compiler.go

+32-11
Original file line numberDiff line numberDiff line change
@@ -750,17 +750,15 @@ func (c *compiler) CallNode(node *ast.CallNode) {
750750
}
751751
for i, arg := range node.Arguments {
752752
c.compile(arg)
753-
if k := kind(arg.Type()); k == reflect.Ptr || k == reflect.Interface {
754-
var in reflect.Type
755-
if fn.IsVariadic() && i >= fnNumIn-1 {
756-
in = fn.In(fn.NumIn() - 1).Elem()
757-
} else {
758-
in = fn.In(i + fnInOffset)
759-
}
760-
if k = kind(in); k != reflect.Ptr && k != reflect.Interface {
761-
c.emit(OpDeref)
762-
}
753+
754+
var in reflect.Type
755+
if fn.IsVariadic() && i >= fnNumIn-1 {
756+
in = fn.In(fn.NumIn() - 1).Elem()
757+
} else {
758+
in = fn.In(i + fnInOffset)
763759
}
760+
761+
c.derefParam(in, arg)
764762
}
765763
} else {
766764
for _, arg := range node.Arguments {
@@ -1059,8 +1057,19 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
10591057

10601058
if id, ok := builtin.Index[node.Name]; ok {
10611059
f := builtin.Builtins[id]
1062-
for _, arg := range node.Arguments {
1060+
for i, arg := range node.Arguments {
10631061
c.compile(arg)
1062+
argType := arg.Type()
1063+
if argType.Kind() == reflect.Ptr || arg.Nature().IsUnknown() {
1064+
if f.Deref == nil {
1065+
// By default, builtins expect arguments to be dereferenced.
1066+
c.emit(OpDeref)
1067+
} else {
1068+
if f.Deref(i, argType) {
1069+
c.emit(OpDeref)
1070+
}
1071+
}
1072+
}
10641073
}
10651074

10661075
if f.Fast != nil {
@@ -1218,6 +1227,18 @@ func (c *compiler) derefInNeeded(node ast.Node) {
12181227
}
12191228
}
12201229

1230+
func (c *compiler) derefParam(in reflect.Type, param ast.Node) {
1231+
if param.Nature().Nil {
1232+
return
1233+
}
1234+
if param.Type().AssignableTo(in) {
1235+
return
1236+
}
1237+
if in.Kind() != reflect.Ptr && param.Type().Kind() == reflect.Ptr {
1238+
c.emit(OpDeref)
1239+
}
1240+
}
1241+
12211242
func (c *compiler) optimize() {
12221243
for i, op := range c.bytecode {
12231244
switch op {

expr.go

+1
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ func Eval(input string, env any) (any, error) {
244244
if err != nil {
245245
return nil, err
246246
}
247+
println(program.Disassemble())
247248

248249
output, err := Run(program, env)
249250
if err != nil {

patcher/value/value_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ func Test_valueAddInt(t *testing.T) {
8888

8989
program, err := expr.Compile("ValueOne + ValueTwo", expr.Env(env), ValueGetter)
9090
require.NoError(t, err)
91-
9291
out, err := vm.Run(program, env)
9392

9493
require.NoError(t, err)

test/deref/deref_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,20 @@ func TestDeref_ignore_struct_func_args(t *testing.T) {
320320
require.NoError(t, err)
321321
require.Equal(t, "UTC", out)
322322
}
323+
324+
func TestDeref_keep_pointer_if_arg_in_interface(t *testing.T) {
325+
x := 42
326+
env := map[string]any{
327+
"x": &x,
328+
"fn": func(p any) int {
329+
return *p.(*int) + 1
330+
},
331+
}
332+
333+
program, err := expr.Compile(`fn(x)`, expr.Env(env))
334+
require.NoError(t, err)
335+
336+
out, err := expr.Run(program, env)
337+
require.NoError(t, err)
338+
require.Equal(t, 43, out)
339+
}

test/issues/739/issue_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package issue_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
"github.com/expr-lang/expr/internal/testify/require"
8+
)
9+
10+
func TestIssue739(t *testing.T) {
11+
jsonString := `{"Num": 1}`
12+
env := map[string]any{
13+
"aJSONString": &jsonString,
14+
}
15+
16+
result, err := expr.Eval("fromJSON(aJSONString)", env)
17+
require.NoError(t, err)
18+
require.Contains(t, result, "Num")
19+
}

0 commit comments

Comments
 (0)