Skip to content

Commit ecdedb3

Browse files
committed
added entity detection in conditions
1 parent bc59d7a commit ecdedb3

File tree

2 files changed

+176
-5
lines changed

2 files changed

+176
-5
lines changed

sourcecode-parser/antlr/evaluator.go

+71-5
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,20 @@ package parser
33
import (
44
"fmt"
55
"strings"
6+
67
"github.com/expr-lang/expr"
78
)
89

10+
// ComparisonType represents the type of comparison in an expression
11+
type ComparisonType string
12+
13+
const (
14+
// SINGLE_ENTITY represents comparison between one entity and a static value
15+
SINGLE_ENTITY ComparisonType = "SINGLE_ENTITY"
16+
// DUAL_ENTITY represents comparison between two different entities
17+
DUAL_ENTITY ComparisonType = "DUAL_ENTITY"
18+
)
19+
920
// EvaluateExpressionTree evaluates the expression tree against input data
1021
// and returns filtered data based on the expression conditions
1122
func EvaluateExpressionTree(tree *ExpressionNode, data []map[string]interface{}) ([]map[string]interface{}, error) {
@@ -19,7 +30,7 @@ func EvaluateExpressionTree(tree *ExpressionNode, data []map[string]interface{})
1930
if err != nil {
2031
return nil, fmt.Errorf("evaluation error: %w", err)
2132
}
22-
33+
2334
// Only include items that match the expression
2435
if matches.(bool) {
2536
result = append(result, item)
@@ -57,16 +68,73 @@ func evaluateNode(node *ExpressionNode, data map[string]interface{}) (interface{
5768
}
5869

5970
// nodeToExprString converts an ExpressionNode to an expr-lang expression string
71+
// DetectComparisonType analyzes a binary expression node and determines if it's comparing
72+
// a single entity with a static value or comparing two different entities
73+
func DetectComparisonType(node *ExpressionNode) (ComparisonType, error) {
74+
if node == nil {
75+
return "", fmt.Errorf("nil node")
76+
}
77+
78+
// Only analyze binary nodes
79+
if node.Type != "binary" {
80+
return "", fmt.Errorf("not a binary node")
81+
}
82+
83+
// Get entity names from left and right sides
84+
leftEntity, err := getEntityName(node.Left)
85+
if err != nil {
86+
return "", fmt.Errorf("failed to get left entity: %w", err)
87+
}
88+
89+
rightEntity, err := getEntityName(node.Right)
90+
if err != nil {
91+
return "", fmt.Errorf("failed to get right entity: %w", err)
92+
}
93+
94+
// If either side is empty (literal/static value) or they're the same entity,
95+
// it's a SINGLE_ENTITY comparison
96+
if leftEntity == "" || rightEntity == "" || leftEntity == rightEntity {
97+
return SINGLE_ENTITY, nil
98+
}
99+
100+
// Different entities are being compared
101+
return DUAL_ENTITY, nil
102+
}
103+
104+
// getEntityName extracts the entity name from a node.
105+
// Returns empty string for literals and static values.
106+
func getEntityName(node *ExpressionNode) (string, error) {
107+
if node == nil {
108+
return "", fmt.Errorf("nil node")
109+
}
110+
111+
switch node.Type {
112+
case "variable":
113+
return node.Value, nil
114+
case "method_call":
115+
// For method calls, consider the target object as the entity
116+
parts := strings.Split(node.Value, ".")
117+
if len(parts) > 0 {
118+
return parts[0], nil
119+
}
120+
return "", nil
121+
case "literal":
122+
return "", nil // Literals are static values
123+
default:
124+
return "", fmt.Errorf("unsupported node type: %s", node.Type)
125+
}
126+
}
127+
60128
func nodeToExprString(node *ExpressionNode) (string, error) {
61129
switch node.Type {
62130
case "binary":
63131
left, err := nodeToExprString(node.Left)
64132
if err != nil {
65-
return "", err
133+
return "", err
66134
}
67135
right, err := nodeToExprString(node.Right)
68136
if err != nil {
69-
return "", err
137+
return "", err
70138
}
71139
return fmt.Sprintf("(%s %s %s)", left, node.Operator, right), nil
72140

@@ -111,5 +179,3 @@ func nodeToExprString(node *ExpressionNode) (string, error) {
111179
return "", fmt.Errorf("unknown node type: %s", node.Type)
112180
}
113181
}
114-
115-

sourcecode-parser/antlr/evaluator_test.go

+105
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,111 @@ func TestEvaluateExpressionTree(t *testing.T) {
175175
}
176176
}
177177

178+
func TestDetectComparisonType(t *testing.T) {
179+
tests := []struct {
180+
name string
181+
node *ExpressionNode
182+
expected ComparisonType
183+
wantErr bool
184+
}{
185+
{
186+
name: "single entity with literal",
187+
node: &ExpressionNode{
188+
Type: "binary",
189+
Operator: ">",
190+
Left: &ExpressionNode{
191+
Type: "variable",
192+
Value: "age",
193+
},
194+
Right: &ExpressionNode{
195+
Type: "literal",
196+
Value: "25",
197+
},
198+
},
199+
expected: SINGLE_ENTITY,
200+
wantErr: false,
201+
},
202+
{
203+
name: "dual entity comparison",
204+
node: &ExpressionNode{
205+
Type: "binary",
206+
Operator: "==",
207+
Left: &ExpressionNode{
208+
Type: "variable",
209+
Value: "age",
210+
},
211+
Right: &ExpressionNode{
212+
Type: "variable",
213+
Value: "count",
214+
},
215+
},
216+
expected: DUAL_ENTITY,
217+
wantErr: false,
218+
},
219+
{
220+
name: "single entity method call",
221+
node: &ExpressionNode{
222+
Type: "binary",
223+
Operator: ">",
224+
Left: &ExpressionNode{
225+
Type: "method_call",
226+
Value: "method.complexity",
227+
},
228+
Right: &ExpressionNode{
229+
Type: "literal",
230+
Value: "10",
231+
},
232+
},
233+
expected: SINGLE_ENTITY,
234+
wantErr: false,
235+
},
236+
{
237+
name: "dual entity method calls",
238+
node: &ExpressionNode{
239+
Type: "binary",
240+
Operator: "==",
241+
Left: &ExpressionNode{
242+
Type: "method_call",
243+
Value: "method1.complexity",
244+
},
245+
Right: &ExpressionNode{
246+
Type: "method_call",
247+
Value: "method2.complexity",
248+
},
249+
},
250+
expected: DUAL_ENTITY,
251+
wantErr: false,
252+
},
253+
{
254+
name: "non-binary node",
255+
node: &ExpressionNode{
256+
Type: "literal",
257+
Value: "25",
258+
},
259+
expected: "",
260+
wantErr: true,
261+
},
262+
{
263+
name: "nil node",
264+
node: nil,
265+
expected: "",
266+
wantErr: true,
267+
},
268+
}
269+
270+
for _, tt := range tests {
271+
t.Run(tt.name, func(t *testing.T) {
272+
got, err := DetectComparisonType(tt.node)
273+
if tt.wantErr {
274+
assert.Error(t, err)
275+
return
276+
}
277+
assert.NoError(t, err)
278+
assert.Equal(t, tt.expected, got)
279+
})
280+
}
281+
}
282+
178283
func TestEvaluateNode(t *testing.T) {
179284
// Mock data with method and predicate functions
180285
testData := map[string]interface{}{

0 commit comments

Comments
 (0)