Skip to content

Commit b26d4b7

Browse files
committed
Fix constant folding for floats and ints
1 parent 3d4c219 commit b26d4b7

File tree

2 files changed

+163
-23
lines changed

2 files changed

+163
-23
lines changed

expr_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1386,11 +1386,11 @@ func TestIssue138(t *testing.T) {
13861386
env := map[string]interface{}{}
13871387

13881388
_, err := expr.Compile(`1 / (1 - 1)`, expr.Env(env))
1389-
require.Error(t, err)
1390-
require.Equal(t, "integer divide by zero (1:3)\n | 1 / (1 - 1)\n | ..^", err.Error())
1389+
require.NoError(t, err)
13911390

13921391
_, err = expr.Compile(`1 % 0`, expr.Env(env))
13931392
require.Error(t, err)
1393+
require.Equal(t, "integer divide by zero (1:3)\n | 1 % 0\n | ..^", err.Error())
13941394
}
13951395

13961396
func TestIssue154(t *testing.T) {

optimizer/fold.go

+161-21
Original file line numberDiff line numberDiff line change
@@ -32,48 +32,141 @@ func (fold *fold) Visit(node *Node) {
3232
if i, ok := n.Node.(*IntegerNode); ok {
3333
patchWithType(&IntegerNode{Value: -i.Value}, n.Node.Type())
3434
}
35+
if i, ok := n.Node.(*FloatNode); ok {
36+
patchWithType(&FloatNode{Value: -i.Value}, n.Node.Type())
37+
}
3538
case "+":
3639
if i, ok := n.Node.(*IntegerNode); ok {
3740
patchWithType(&IntegerNode{Value: i.Value}, n.Node.Type())
3841
}
42+
if i, ok := n.Node.(*FloatNode); ok {
43+
patchWithType(&FloatNode{Value: i.Value}, n.Node.Type())
44+
}
3945
}
4046

4147
case *BinaryNode:
4248
switch n.Operator {
4349
case "+":
44-
if a, ok := n.Left.(*IntegerNode); ok {
45-
if b, ok := n.Right.(*IntegerNode); ok {
50+
{
51+
a := toInteger(n.Left)
52+
b := toInteger(n.Right)
53+
if a != nil && b != nil {
4654
patchWithType(&IntegerNode{Value: a.Value + b.Value}, a.Type())
4755
}
4856
}
49-
if a, ok := n.Left.(*StringNode); ok {
50-
if b, ok := n.Right.(*StringNode); ok {
57+
{
58+
a := toInteger(n.Left)
59+
b := toFloat(n.Right)
60+
if a != nil && b != nil {
61+
patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}, a.Type())
62+
}
63+
}
64+
{
65+
a := toFloat(n.Left)
66+
b := toInteger(n.Right)
67+
if a != nil && b != nil {
68+
patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}, a.Type())
69+
}
70+
}
71+
{
72+
a := toFloat(n.Left)
73+
b := toFloat(n.Right)
74+
if a != nil && b != nil {
75+
patchWithType(&FloatNode{Value: a.Value + b.Value}, a.Type())
76+
}
77+
}
78+
{
79+
a := toString(n.Left)
80+
b := toString(n.Right)
81+
if a != nil && b != nil {
5182
patch(&StringNode{Value: a.Value + b.Value})
5283
}
5384
}
5485
case "-":
55-
if a, ok := n.Left.(*IntegerNode); ok {
56-
if b, ok := n.Right.(*IntegerNode); ok {
86+
{
87+
a := toInteger(n.Left)
88+
b := toInteger(n.Right)
89+
if a != nil && b != nil {
5790
patchWithType(&IntegerNode{Value: a.Value - b.Value}, a.Type())
5891
}
5992
}
93+
{
94+
a := toInteger(n.Left)
95+
b := toFloat(n.Right)
96+
if a != nil && b != nil {
97+
patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}, a.Type())
98+
}
99+
}
100+
{
101+
a := toFloat(n.Left)
102+
b := toInteger(n.Right)
103+
if a != nil && b != nil {
104+
patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}, a.Type())
105+
}
106+
}
107+
{
108+
a := toFloat(n.Left)
109+
b := toFloat(n.Right)
110+
if a != nil && b != nil {
111+
patchWithType(&FloatNode{Value: a.Value - b.Value}, a.Type())
112+
}
113+
}
60114
case "*":
61-
if a, ok := n.Left.(*IntegerNode); ok {
62-
if b, ok := n.Right.(*IntegerNode); ok {
115+
{
116+
a := toInteger(n.Left)
117+
b := toInteger(n.Right)
118+
if a != nil && b != nil {
63119
patchWithType(&IntegerNode{Value: a.Value * b.Value}, a.Type())
64120
}
65121
}
122+
{
123+
a := toInteger(n.Left)
124+
b := toFloat(n.Right)
125+
if a != nil && b != nil {
126+
patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}, a.Type())
127+
}
128+
}
129+
{
130+
a := toFloat(n.Left)
131+
b := toInteger(n.Right)
132+
if a != nil && b != nil {
133+
patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}, a.Type())
134+
}
135+
}
136+
{
137+
a := toFloat(n.Left)
138+
b := toFloat(n.Right)
139+
if a != nil && b != nil {
140+
patchWithType(&FloatNode{Value: a.Value * b.Value}, a.Type())
141+
}
142+
}
66143
case "/":
67-
if a, ok := n.Left.(*IntegerNode); ok {
68-
if b, ok := n.Right.(*IntegerNode); ok {
69-
if b.Value == 0 {
70-
fold.err = &file.Error{
71-
Location: (*node).Location(),
72-
Message: "integer divide by zero",
73-
}
74-
return
75-
}
76-
patchWithType(&IntegerNode{Value: a.Value / b.Value}, a.Type())
144+
{
145+
a := toInteger(n.Left)
146+
b := toInteger(n.Right)
147+
if a != nil && b != nil {
148+
patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}, a.Type())
149+
}
150+
}
151+
{
152+
a := toInteger(n.Left)
153+
b := toFloat(n.Right)
154+
if a != nil && b != nil {
155+
patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}, a.Type())
156+
}
157+
}
158+
{
159+
a := toFloat(n.Left)
160+
b := toInteger(n.Right)
161+
if a != nil && b != nil {
162+
patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}, a.Type())
163+
}
164+
}
165+
{
166+
a := toFloat(n.Left)
167+
b := toFloat(n.Right)
168+
if a != nil && b != nil {
169+
patchWithType(&FloatNode{Value: a.Value / b.Value}, a.Type())
77170
}
78171
}
79172
case "%":
@@ -90,9 +183,32 @@ func (fold *fold) Visit(node *Node) {
90183
}
91184
}
92185
case "**", "^":
93-
if a, ok := n.Left.(*IntegerNode); ok {
94-
if b, ok := n.Right.(*IntegerNode); ok {
95-
patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))})
186+
{
187+
a := toInteger(n.Left)
188+
b := toInteger(n.Right)
189+
if a != nil && b != nil {
190+
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}, a.Type())
191+
}
192+
}
193+
{
194+
a := toInteger(n.Left)
195+
b := toFloat(n.Right)
196+
if a != nil && b != nil {
197+
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}, a.Type())
198+
}
199+
}
200+
{
201+
a := toFloat(n.Left)
202+
b := toInteger(n.Right)
203+
if a != nil && b != nil {
204+
patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}, a.Type())
205+
}
206+
}
207+
{
208+
a := toFloat(n.Left)
209+
b := toFloat(n.Right)
210+
if a != nil && b != nil {
211+
patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}, a.Type())
96212
}
97213
}
98214
}
@@ -145,3 +261,27 @@ func (fold *fold) Visit(node *Node) {
145261
}
146262
}
147263
}
264+
265+
func toString(n Node) *StringNode {
266+
switch a := n.(type) {
267+
case *StringNode:
268+
return a
269+
}
270+
return nil
271+
}
272+
273+
func toInteger(n Node) *IntegerNode {
274+
switch a := n.(type) {
275+
case *IntegerNode:
276+
return a
277+
}
278+
return nil
279+
}
280+
281+
func toFloat(n Node) *FloatNode {
282+
switch a := n.(type) {
283+
case *FloatNode:
284+
return a
285+
}
286+
return nil
287+
}

0 commit comments

Comments
 (0)