Skip to content

Commit 183e46f

Browse files
initial commit
0 parents  commit 183e46f

10 files changed

+659
-0
lines changed

ast.go

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package main
2+
3+
import (
4+
"io"
5+
"log"
6+
"strings"
7+
)
8+
9+
type astNode struct {
10+
left, right *astNode
11+
token Token
12+
depth int // depth in the ast, helps in stringify
13+
evalr evaluator
14+
}
15+
16+
type parserFn func(*parser) *astNode
17+
18+
var nilASTNode *astNode = nil
19+
20+
func newASTNode(token Token, left, right *astNode, evalr evaluator) *astNode {
21+
return &astNode{left, right, token, 0, evalr}
22+
}
23+
24+
type parser struct {
25+
tokens []Token
26+
processedToken int
27+
lastErr error
28+
}
29+
30+
func (ast *astNode) String() string {
31+
marginStr := ""
32+
var out strings.Builder
33+
for leftMargin := ast.depth; leftMargin > 0; leftMargin-- {
34+
marginStr += "\t"
35+
}
36+
out.WriteString("\n")
37+
out.WriteString(marginStr)
38+
out.WriteString("Token: ")
39+
out.WriteString(ast.token.String())
40+
if ast.left != nilASTNode {
41+
ast.left.depth = ast.depth + 1
42+
out.WriteString("\n")
43+
out.WriteString(marginStr)
44+
out.WriteString("Left: ")
45+
out.WriteString(ast.left.String())
46+
}
47+
if ast.right != nilASTNode {
48+
ast.right.depth = ast.depth + 1
49+
out.WriteString("\n")
50+
out.WriteString(marginStr)
51+
out.WriteString("Right: ")
52+
out.WriteString(ast.right.String())
53+
54+
}
55+
return out.String()
56+
}
57+
58+
func (p *parser) next() Token {
59+
60+
if p.processedToken < len(p.tokens) {
61+
token := p.tokens[p.processedToken]
62+
p.processedToken++
63+
return token
64+
}
65+
p.lastErr = io.EOF
66+
return Token{}
67+
}
68+
69+
func (p *parser) peek() Token {
70+
if p.processedToken < len(p.tokens) {
71+
return p.tokens[p.processedToken]
72+
}
73+
p.lastErr = io.EOF
74+
return Token{}
75+
}
76+
77+
func (p *parser) eof() bool {
78+
defer log.Print("LastError: ", p.lastErr)
79+
return p.lastErr == io.EOF
80+
}
81+
82+
func parseTokens(tokens []Token, prsFn parserFn) *astNode {
83+
par := parser{
84+
tokens: tokens,
85+
}
86+
return prsFn(&par)
87+
}

ast_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
)
7+
8+
func TestAST(t *testing.T) {
9+
lexer := NewLexer(bytes.NewReader([]byte("(1)+(2)")))
10+
11+
err := lexer.Analyze()
12+
if err != nil {
13+
t.Error("error while doing lexical analysis: ", err)
14+
t.Fail()
15+
return
16+
}
17+
t.Logf("Token List: %v", lexer.tokenList)
18+
actual := parseTokens(lexer.tokenList, parseExpr)
19+
t.Log("Actual: ", actual.String())
20+
expected := newASTNode(NewToken(Plus, '+'),
21+
22+
newASTNode(NewToken(Number, '1'), nil, nil, nil),
23+
newASTNode(NewToken(Number, '2'), nil, nil, nil), nil)
24+
25+
if !compareAST(expected, actual) {
26+
t.Error("Expected AST didn't match with Actual AST")
27+
t.Error("Expected: ", expected.String())
28+
}
29+
}
30+
31+
func compareAST(x, y *astNode) bool {
32+
33+
if x == nil && y == nil {
34+
return true
35+
}
36+
37+
if x != nil && y != nil {
38+
return x.token.Equals(y.token) && compareAST(x.left, y.left) && compareAST(x.right, y.right)
39+
}
40+
return false
41+
42+
}

eval.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"strconv"
6+
)
7+
8+
type evaluator func(*astNode) (any, error)
9+
10+
func eval(ast *astNode) (any, error) {
11+
return ast.evalr(ast)
12+
}
13+
14+
func evalNumber(ast *astNode) (any, error) {
15+
return strconv.ParseFloat(string(ast.token.val), 64)
16+
}
17+
18+
func evalInfix(ast *astNode) (any, error) {
19+
var left, right any
20+
left, err := eval(ast.left)
21+
if err != nil {
22+
return nil, err
23+
}
24+
right, err = eval(ast.right)
25+
if err != nil {
26+
return 0, nil
27+
}
28+
leftF := left.(float64)
29+
rightF := right.(float64)
30+
31+
switch ast.token.tokenType {
32+
case Minus:
33+
return leftF - rightF, nil
34+
case Plus:
35+
return leftF + rightF, nil
36+
case Asterisk:
37+
return leftF * rightF, nil
38+
case Slash:
39+
if rightF == 0 {
40+
return 0, errors.New("divided by zero")
41+
}
42+
return leftF / rightF, nil
43+
case Mod:
44+
return float64(int64(leftF) % int64(rightF)), nil
45+
46+
case LT:
47+
return leftF < rightF, nil
48+
case GT:
49+
return leftF > rightF, nil
50+
case LTEQ:
51+
return leftF <= rightF, nil
52+
case GTEQ:
53+
return leftF >= rightF, nil
54+
case Eq:
55+
return leftF == rightF, nil
56+
57+
}
58+
return 0, nil // Need to take care of this
59+
}

eval_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
)
7+
8+
func TestEval(t *testing.T) {
9+
lexer := NewLexer(bytes.NewReader([]byte("(2*3*2+32) == 8")))
10+
11+
err := lexer.Analyze()
12+
if err != nil {
13+
t.Error("error while doing lexical analysis: ", err)
14+
t.Fail()
15+
return
16+
}
17+
t.Logf("Token List: %v", lexer.tokenList)
18+
ast := parseTokens(lexer.tokenList, parseExpr)
19+
t.Log("AST: ", ast.String())
20+
actualResult, err := eval(ast)
21+
if err != nil {
22+
t.Error("error occured while evaluting ast: ", err)
23+
return
24+
}
25+
expectedVal := false
26+
if actualResult != expectedVal {
27+
t.Errorf("Wrong Evaluation Actual: %v, Expected: %v", actualResult, expectedVal)
28+
}
29+
}

go.mod

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module github.com/devansh42/rad
2+
3+
go 1.21.1

0 commit comments

Comments
 (0)