Skip to content

Commit e797004

Browse files
committed
added relationship detection code
1 parent ecdedb3 commit e797004

File tree

2 files changed

+211
-0
lines changed

2 files changed

+211
-0
lines changed

sourcecode-parser/antlr/evaluator.go

+76
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,82 @@ import (
77
"github.com/expr-lang/expr"
88
)
99

10+
// RelationshipMap represents relationships between entities and their attributes
11+
type RelationshipMap struct {
12+
// map[EntityName]map[AttributeName][]RelatedEntity
13+
// Example: {"class": {"methods": ["method", "function"]}}
14+
Relationships map[string]map[string][]string
15+
}
16+
17+
// NewRelationshipMap creates a new RelationshipMap
18+
func NewRelationshipMap() *RelationshipMap {
19+
return &RelationshipMap{
20+
Relationships: make(map[string]map[string][]string),
21+
}
22+
}
23+
24+
// AddRelationship adds a relationship between an entity and its related entities through an attribute
25+
func (rm *RelationshipMap) AddRelationship(entity, attribute string, relatedEntities []string) {
26+
if rm.Relationships[entity] == nil {
27+
rm.Relationships[entity] = make(map[string][]string)
28+
}
29+
rm.Relationships[entity][attribute] = relatedEntities
30+
}
31+
32+
// HasRelationship checks if two entities are related through any attribute
33+
func (rm *RelationshipMap) HasRelationship(entity1, entity2 string) bool {
34+
// Check direct relationships from entity1 to entity2
35+
if attrs, ok := rm.Relationships[entity1]; ok {
36+
for _, relatedEntities := range attrs {
37+
for _, related := range relatedEntities {
38+
if related == entity2 {
39+
return true
40+
}
41+
}
42+
}
43+
}
44+
45+
// Check direct relationships from entity2 to entity1
46+
if attrs, ok := rm.Relationships[entity2]; ok {
47+
for _, relatedEntities := range attrs {
48+
for _, related := range relatedEntities {
49+
if related == entity1 {
50+
return true
51+
}
52+
}
53+
}
54+
}
55+
56+
return false
57+
}
58+
59+
// CheckExpressionRelationship checks if a binary expression involves related entities
60+
func CheckExpressionRelationship(node *ExpressionNode, relationshipMap *RelationshipMap) (bool, error) {
61+
// First check if it's a dual entity comparison
62+
compType, err := DetectComparisonType(node)
63+
if err != nil {
64+
return false, fmt.Errorf("failed to detect comparison type: %w", err)
65+
}
66+
67+
if compType != DUAL_ENTITY {
68+
return false, nil // Not a dual entity comparison
69+
}
70+
71+
// Get entity names from both sides
72+
leftEntity, err := getEntityName(node.Left)
73+
if err != nil {
74+
return false, fmt.Errorf("failed to get left entity: %w", err)
75+
}
76+
77+
rightEntity, err := getEntityName(node.Right)
78+
if err != nil {
79+
return false, fmt.Errorf("failed to get right entity: %w", err)
80+
}
81+
82+
// Check if entities are related
83+
return relationshipMap.HasRelationship(leftEntity, rightEntity), nil
84+
}
85+
1086
// ComparisonType represents the type of comparison in an expression
1187
type ComparisonType string
1288

sourcecode-parser/antlr/evaluator_test.go

+135
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,141 @@ func TestEvaluateExpressionTree(t *testing.T) {
175175
}
176176
}
177177

178+
func TestRelationshipMap(t *testing.T) {
179+
// Create a relationship map
180+
rm := NewRelationshipMap()
181+
182+
// Add some relationships
183+
rm.AddRelationship("class", "methods", []string{"method", "function"})
184+
rm.AddRelationship("method", "parameters", []string{"parameter", "variable"})
185+
rm.AddRelationship("function", "returns", []string{"type", "class"})
186+
187+
tests := []struct {
188+
name string
189+
entity1 string
190+
entity2 string
191+
expected bool
192+
}{
193+
{
194+
name: "direct relationship exists",
195+
entity1: "class",
196+
entity2: "method",
197+
expected: true,
198+
},
199+
{
200+
name: "reverse relationship exists",
201+
entity1: "function",
202+
entity2: "class",
203+
expected: true,
204+
},
205+
{
206+
name: "no relationship exists",
207+
entity1: "class",
208+
entity2: "parameter",
209+
expected: false,
210+
},
211+
{
212+
name: "unknown entity",
213+
entity1: "unknown",
214+
entity2: "class",
215+
expected: false,
216+
},
217+
}
218+
219+
for _, tt := range tests {
220+
t.Run(tt.name, func(t *testing.T) {
221+
got := rm.HasRelationship(tt.entity1, tt.entity2)
222+
assert.Equal(t, tt.expected, got)
223+
})
224+
}
225+
}
226+
227+
func TestCheckExpressionRelationship(t *testing.T) {
228+
// Create a relationship map
229+
rm := NewRelationshipMap()
230+
rm.AddRelationship("class", "methods", []string{"method"})
231+
232+
tests := []struct {
233+
name string
234+
node *ExpressionNode
235+
expected bool
236+
wantErr bool
237+
}{
238+
{
239+
name: "related entities comparison",
240+
node: &ExpressionNode{
241+
Type: "binary",
242+
Operator: "==",
243+
Left: &ExpressionNode{
244+
Type: "variable",
245+
Value: "class",
246+
},
247+
Right: &ExpressionNode{
248+
Type: "variable",
249+
Value: "method",
250+
},
251+
},
252+
expected: true,
253+
wantErr: false,
254+
},
255+
{
256+
name: "unrelated entities comparison",
257+
node: &ExpressionNode{
258+
Type: "binary",
259+
Operator: "==",
260+
Left: &ExpressionNode{
261+
Type: "variable",
262+
Value: "class",
263+
},
264+
Right: &ExpressionNode{
265+
Type: "variable",
266+
Value: "unrelated",
267+
},
268+
},
269+
expected: false,
270+
wantErr: false,
271+
},
272+
{
273+
name: "single entity comparison",
274+
node: &ExpressionNode{
275+
Type: "binary",
276+
Operator: ">",
277+
Left: &ExpressionNode{
278+
Type: "variable",
279+
Value: "class",
280+
},
281+
Right: &ExpressionNode{
282+
Type: "literal",
283+
Value: "10",
284+
},
285+
},
286+
expected: false,
287+
wantErr: false,
288+
},
289+
{
290+
name: "non-binary node",
291+
node: &ExpressionNode{
292+
Type: "literal",
293+
Value: "25",
294+
},
295+
expected: false,
296+
wantErr: true,
297+
},
298+
}
299+
300+
for _, tt := range tests {
301+
t.Run(tt.name, func(t *testing.T) {
302+
got, err := CheckExpressionRelationship(tt.node, rm)
303+
if tt.wantErr {
304+
assert.Error(t, err)
305+
return
306+
}
307+
assert.NoError(t, err)
308+
assert.Equal(t, tt.expected, got)
309+
})
310+
}
311+
}
312+
178313
func TestDetectComparisonType(t *testing.T) {
179314
tests := []struct {
180315
name string

0 commit comments

Comments
 (0)