Skip to content

Commit ffd1c95

Browse files
committed
func: add support for functions as values
This updates #36
1 parent b147f00 commit ffd1c95

File tree

6 files changed

+70
-12
lines changed

6 files changed

+70
-12
lines changed

compiler/compiler/block.go

+7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,20 @@ package compiler
33
import "fmt"
44

55
var blockIndex uint64
6+
var anonFuncIndex uint64
67

78
func getBlockName() string {
89
name := fmt.Sprintf("block-%d", blockIndex)
910
blockIndex++
1011
return name
1112
}
1213

14+
func getAnonFuncName() string {
15+
name := fmt.Sprintf("fn-%d", anonFuncIndex)
16+
anonFuncIndex++
17+
return name
18+
}
19+
1320
func getVarName(prefix string) string {
1421
name := fmt.Sprintf("%s-%d", prefix, blockIndex)
1522
blockIndex++

compiler/compiler/compiler.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (c *Compiler) Compile(root parser.PackageNode) (err error) {
149149
}
150150

151151
func (c *Compiler) GetIR() string {
152-
return fmt.Sprintln(c.module)
152+
return c.module.String()
153153
}
154154

155155
func (c *Compiler) addGlobal() {
@@ -214,17 +214,17 @@ func (c *Compiler) compile(instructions []parser.Node) {
214214
}
215215
}
216216

217-
func (c *Compiler) funcByName(name string) *types.Function {
217+
func (c *Compiler) funcByName(name string) (*types.Function, bool) {
218218
if f, ok := c.globalFuncs[name]; ok {
219-
return f
219+
return f, true
220220
}
221221

222222
// Function in the current package
223223
if f, ok := c.currentPackage.Funcs[name]; ok {
224-
return f
224+
return f, true
225225
}
226226

227-
panic("funcByName: no such func: " + name)
227+
return nil, false
228228
}
229229

230230
func (c *Compiler) varByName(name string) value.Value {
@@ -294,6 +294,9 @@ func (c *Compiler) compileValue(node parser.Node) value.Value {
294294
return c.compileInitStructWithValues(v)
295295
case *parser.TypeCastInterfaceNode:
296296
return c.compileTypeCastInterfaceNode(v)
297+
298+
case *parser.DefineFuncNode:
299+
return c.compileDefineFuncNode(v)
297300
}
298301

299302
panic("compileValue fail: " + fmt.Sprintf("%T: %+v", node, node))

compiler/compiler/func.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package compiler
22

33
import (
4+
"fmt"
45
"github.com/llir/llvm/ir"
56
"github.com/llir/llvm/ir/constant"
67
llvmTypes "github.com/llir/llvm/ir/types"
@@ -11,7 +12,7 @@ import (
1112
"github.com/zegl/tre/compiler/parser"
1213
)
1314

14-
func (c *Compiler) compileDefineFuncNode(v *parser.DefineFuncNode) {
15+
func (c *Compiler) compileDefineFuncNode(v *parser.DefineFuncNode) value.Value {
1516
var compiledName string
1617

1718
if v.IsMethod {
@@ -31,8 +32,10 @@ func (c *Compiler) compileDefineFuncNode(v *parser.DefineFuncNode) {
3132

3233
// Change the name of our function
3334
compiledName = c.currentPackageName + "_method_" + v.MethodOnType.TypeName + "_" + v.Name
34-
} else {
35+
} else if v.IsNamed {
3536
compiledName = c.currentPackageName + "_" + v.Name
37+
} else {
38+
compiledName = c.currentPackageName + "_" + getAnonFuncName()
3639
}
3740

3841
llvmParams := make([]*ir.Param, len(v.Arguments))
@@ -127,12 +130,15 @@ func (c *Compiler) compileDefineFuncNode(v *parser.DefineFuncNode) {
127130

128131
// Make this method available in interfaces via a jump function
129132
typesFunc.JumpFunction = c.compileInterfaceMethodJump(fn)
130-
} else {
133+
} else if v.IsNamed {
131134
c.currentPackage.Funcs[v.Name] = typesFunc
132135
}
133136

134137
entry := fn.NewBlock(getBlockName())
135138

139+
prevContextFunc := c.contextFunc
140+
prevContextBlock := c.contextBlock
141+
136142
c.contextFunc = typesFunc
137143
c.contextBlock = entry
138144
c.pushVariablesStack()
@@ -196,7 +202,15 @@ func (c *Compiler) compileDefineFuncNode(v *parser.DefineFuncNode) {
196202
c.contextBlock.NewRet(constant.NewInt(llvmTypes.I32, 0))
197203
}
198204

205+
c.contextFunc = prevContextFunc
206+
c.contextBlock = prevContextBlock
207+
199208
c.popVariablesStack()
209+
210+
return value.Value{
211+
Type: typesFunc,
212+
Value: typesFunc.LlvmFunction,
213+
}
200214
}
201215

202216
func (c *Compiler) compileInterfaceMethodJump(targetFunc *ir.Function) *ir.Function {
@@ -248,7 +262,7 @@ func (c *Compiler) compileReturnNode(v *parser.ReturnNode) {
248262
// Set value and jump to return block
249263
val := c.compileValue(v.Vals[0])
250264

251-
// Type cast if neccesary
265+
// Type cast if necessary
252266
val = c.valueToInterfaceValue(val, c.contextFunc.LlvmReturnType)
253267

254268
if val.IsVariable {
@@ -264,7 +278,7 @@ func (c *Compiler) compileReturnNode(v *parser.ReturnNode) {
264278
for i, val := range v.Vals {
265279
compVal := c.compileValue(val)
266280

267-
// Type cast if neccesary
281+
// TODO: Type cast if necessary
268282
// compVal = c.valueToInterfaceValue(compVal, c.contextFunc.ReturnType)
269283

270284
// Assign to ptr
@@ -300,7 +314,16 @@ func (c *Compiler) compileCallNode(v *parser.CallNode) value.Value {
300314
var fn *types.Function
301315

302316
if isNameNode {
303-
fn = c.funcByName(name.Name)
317+
if namedFn, ok := c.funcByName(name.Name); ok {
318+
fn = namedFn
319+
} else {
320+
funcByVal := c.compileValue(v.Function)
321+
if checkIfFunc, ok := funcByVal.Type.(*types.Function); ok {
322+
fn = checkIfFunc
323+
} else {
324+
panic(fmt.Sprintf("no such function: %v", v))
325+
}
326+
}
304327
} else {
305328
funcByVal := c.compileValue(v.Function)
306329
if checkIfFunc, ok := funcByVal.Type.(*types.Function); ok {

compiler/compiler/func_len.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ func (c *Compiler) lenFuncCall(v *parser.CallNode) value.Value {
1313
arg := c.compileValue(v.Arguments[0])
1414

1515
if arg.Type.Name() == "string" {
16-
f := c.funcByName("len_string")
16+
f, ok := c.funcByName("len_string")
17+
if !ok {
18+
panic("could not find len_string func")
19+
}
1720

1821
val := arg.Value
1922
if arg.IsVariable {

compiler/testdata/func-as-value.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package main
2+
3+
import "external"
4+
5+
func main() {
6+
f1 := func() int {
7+
return 100
8+
}
9+
10+
external.Printf("%d\n", f1()) // 100
11+
12+
13+
f2 := func(a int) int {
14+
return 100 * a
15+
}
16+
17+
external.Printf("%d\n", f2(2)) // 200
18+
}

compiler/testdata/int64-method.go

+4
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ func main() {
1414
var abc myint
1515
abc = 100
1616
abc.Yolo()
17+
18+
f1 := func() int {
19+
return 100
20+
}
1721
}

0 commit comments

Comments
 (0)