Skip to content

Commit 370e626

Browse files
authored
ast: Parser recursion depth guard (#7568)
Add enter/leave helpers that bump a depth counter and fail with ErrMaxParsingRecursionDepthExceeded once the limit (default 100k) is exceeded. Every recursive parse helper now calls them. Expose WithMaxRecursionDepth for callers that need higher limits. If the limit is set to 0 (or a negative value), recursion tracking is effectively disabled. Tests utilise a much lower depth limit. Signed-off-by: Ville Vesilehto <[email protected]>
1 parent 8a75563 commit 370e626

File tree

3 files changed

+252
-6
lines changed

3 files changed

+252
-6
lines changed

v1/ast/parser.go

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ import (
2828
"github.com/open-policy-agent/opa/v1/ast/location"
2929
)
3030

31+
// DefaultMaxParsingRecursionDepth is the default maximum recursion
// depth for the parser.
const DefaultMaxParsingRecursionDepth = 100000

var (
	// maxParsingRecursionDepth is the limit that NewParser copies into each
	// new Parser. It starts out equal to DefaultMaxParsingRecursionDepth.
	// NOTE(review): presumably kept as a variable so tests can lower it;
	// confirm against the test files.
	maxParsingRecursionDepth = DefaultMaxParsingRecursionDepth

	// ErrMaxParsingRecursionDepthExceeded is returned when the parser
	// recursion exceeds the maximum allowed depth.
	ErrMaxParsingRecursionDepthExceeded = errors.New("max parsing recursion depth exceeded")
)
40+
3141
var RegoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")}
3242

3343
// RegoVersion defines the Rego syntax requirements for a module.
@@ -114,10 +124,12 @@ func (s *state) Text(offset, end int) []byte {
114124

115125
// Parser is used to parse Rego statements.
116126
type Parser struct {
117-
r io.Reader
118-
s *state
119-
po ParserOptions
120-
cache parsedTermCache
127+
r io.Reader
128+
s *state
129+
po ParserOptions
130+
cache parsedTermCache
131+
recursionDepth int
132+
maxRecursionDepth int
121133
}
122134

123135
type parsedTermCacheItem struct {
@@ -169,12 +181,19 @@ func (po *ParserOptions) EffectiveRegoVersion() RegoVersion {
169181
// NewParser creates and initializes a Parser.
170182
func NewParser() *Parser {
171183
p := &Parser{
172-
s: &state{},
173-
po: ParserOptions{},
184+
s: &state{},
185+
po: ParserOptions{},
186+
maxRecursionDepth: maxParsingRecursionDepth,
174187
}
175188
return p
176189
}
177190

191+
// WithMaxRecursionDepth sets the maximum recursion depth for the parser.
192+
func (p *Parser) WithMaxRecursionDepth(depth int) *Parser {
193+
p.maxRecursionDepth = depth
194+
return p
195+
}
196+
178197
// WithFilename provides the filename for Location details
179198
// on parsed statements.
180199
func (p *Parser) WithFilename(filename string) *Parser {
@@ -1031,6 +1050,10 @@ func (p *Parser) parseHead(defaultRule bool) (*Head, bool) {
10311050
}
10321051

10331052
func (p *Parser) parseBody(end tokens.Token) Body {
1053+
if !p.enter() {
1054+
return nil
1055+
}
1056+
defer p.leave()
10341057
return p.parseQuery(false, end)
10351058
}
10361059

@@ -1356,10 +1379,20 @@ func (p *Parser) parseExpr() *Expr {
13561379
// other binary operators (|, &, arithmetics), it constitutes the binding
13571380
// precedence.
13581381
func (p *Parser) parseTermInfixCall() *Term {
1382+
if !p.enter() {
1383+
return nil
1384+
}
1385+
defer p.leave()
1386+
13591387
return p.parseTermIn(nil, true, p.s.loc.Offset)
13601388
}
13611389

13621390
func (p *Parser) parseTermInfixCallInList() *Term {
1391+
if !p.enter() {
1392+
return nil
1393+
}
1394+
defer p.leave()
1395+
13631396
return p.parseTermIn(nil, false, p.s.loc.Offset)
13641397
}
13651398

@@ -1369,6 +1402,11 @@ var memberWithKeyRef = MemberWithKey.Ref()
13691402
var memberRef = Member.Ref()
13701403

13711404
func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
1405+
if !p.enter() {
1406+
return nil
1407+
}
1408+
defer p.leave()
1409+
13721410
// NOTE(sr): `in` is a bit special: besides `lhs in rhs`, it also
13731411
// supports `key, val in rhs`, so it can have an optional second lhs.
13741412
// `keyVal` triggers if we attempt to parse a second lhs argument (`mhs`).
@@ -1411,6 +1449,11 @@ func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
14111449
}
14121450

14131451
func (p *Parser) parseTermRelation(lhs *Term, offset int) *Term {
1452+
if !p.enter() {
1453+
return nil
1454+
}
1455+
defer p.leave()
1456+
14141457
if lhs == nil {
14151458
lhs = p.parseTermOr(nil, offset)
14161459
}
@@ -1431,6 +1474,11 @@ func (p *Parser) parseTermRelation(lhs *Term, offset int) *Term {
14311474
}
14321475

14331476
func (p *Parser) parseTermOr(lhs *Term, offset int) *Term {
1477+
if !p.enter() {
1478+
return nil
1479+
}
1480+
defer p.leave()
1481+
14341482
if lhs == nil {
14351483
lhs = p.parseTermAnd(nil, offset)
14361484
}
@@ -1452,6 +1500,11 @@ func (p *Parser) parseTermOr(lhs *Term, offset int) *Term {
14521500
}
14531501

14541502
func (p *Parser) parseTermAnd(lhs *Term, offset int) *Term {
1503+
if !p.enter() {
1504+
return nil
1505+
}
1506+
defer p.leave()
1507+
14551508
if lhs == nil {
14561509
lhs = p.parseTermArith(nil, offset)
14571510
}
@@ -1473,6 +1526,11 @@ func (p *Parser) parseTermAnd(lhs *Term, offset int) *Term {
14731526
}
14741527

14751528
func (p *Parser) parseTermArith(lhs *Term, offset int) *Term {
1529+
if !p.enter() {
1530+
return nil
1531+
}
1532+
defer p.leave()
1533+
14761534
if lhs == nil {
14771535
lhs = p.parseTermFactor(nil, offset)
14781536
}
@@ -1493,6 +1551,11 @@ func (p *Parser) parseTermArith(lhs *Term, offset int) *Term {
14931551
}
14941552

14951553
func (p *Parser) parseTermFactor(lhs *Term, offset int) *Term {
1554+
if !p.enter() {
1555+
return nil
1556+
}
1557+
defer p.leave()
1558+
14961559
if lhs == nil {
14971560
lhs = p.parseTerm()
14981561
}
@@ -1513,6 +1576,11 @@ func (p *Parser) parseTermFactor(lhs *Term, offset int) *Term {
15131576
}
15141577

15151578
func (p *Parser) parseTerm() *Term {
1579+
if !p.enter() {
1580+
return nil
1581+
}
1582+
defer p.leave()
1583+
15161584
if term, s := p.parsedTermCacheLookup(); s != nil {
15171585
p.restore(s)
15181586
return term
@@ -1665,6 +1733,10 @@ func (p *Parser) parseRawString() *Term {
16651733
var setConstructor = RefTerm(VarTerm("set"))
16661734

16671735
func (p *Parser) parseCall(operator *Term, offset int) (term *Term) {
1736+
if !p.enter() {
1737+
return nil
1738+
}
1739+
defer p.leave()
16681740

16691741
loc := operator.Location
16701742
var end int
@@ -1694,6 +1766,10 @@ func (p *Parser) parseCall(operator *Term, offset int) (term *Term) {
16941766
}
16951767

16961768
func (p *Parser) parseRef(head *Term, offset int) (term *Term) {
1769+
if !p.enter() {
1770+
return nil
1771+
}
1772+
defer p.leave()
16971773

16981774
loc := head.Location
16991775
var end int
@@ -1759,6 +1835,10 @@ func (p *Parser) parseRef(head *Term, offset int) (term *Term) {
17591835
}
17601836

17611837
func (p *Parser) parseArray() (term *Term) {
1838+
if !p.enter() {
1839+
return nil
1840+
}
1841+
defer p.leave()
17621842

17631843
loc := p.s.Loc()
17641844
offset := p.s.loc.Offset
@@ -1830,6 +1910,11 @@ func (p *Parser) parseArray() (term *Term) {
18301910
}
18311911

18321912
func (p *Parser) parseSetOrObject() (term *Term) {
1913+
if !p.enter() {
1914+
return nil
1915+
}
1916+
defer p.leave()
1917+
18331918
loc := p.s.Loc()
18341919
offset := p.s.loc.Offset
18351920

@@ -1896,6 +1981,11 @@ func (p *Parser) parseSetOrObject() (term *Term) {
18961981
}
18971982

18981983
func (p *Parser) parseSet(s *state, head *Term, potentialComprehension bool) *Term {
1984+
if !p.enter() {
1985+
return nil
1986+
}
1987+
defer p.leave()
1988+
18991989
switch p.s.tok {
19001990
case tokens.RBrace:
19011991
return SetTerm(head)
@@ -1925,6 +2015,11 @@ func (p *Parser) parseSet(s *state, head *Term, potentialComprehension bool) *Te
19252015
}
19262016

19272017
func (p *Parser) parseObject(k *Term, potentialComprehension bool) *Term {
2018+
if !p.enter() {
2019+
return nil
2020+
}
2021+
defer p.leave()
2022+
19282023
// NOTE(tsandall): Assumption: this function is called after parsing the key
19292024
// of the head element and then receiving a colon token from the scanner.
19302025
// Advance beyond the colon and attempt to parse an object.
@@ -1978,6 +2073,11 @@ func (p *Parser) parseObject(k *Term, potentialComprehension bool) *Term {
19782073
}
19792074

19802075
func (p *Parser) parseObjectFinish(key, val *Term, potentialComprehension bool) *Term {
2076+
if !p.enter() {
2077+
return nil
2078+
}
2079+
defer p.leave()
2080+
19812081
switch p.s.tok {
19822082
case tokens.RBrace:
19832083
return ObjectTerm([2]*Term{key, val})
@@ -2759,3 +2859,21 @@ func init() {
27592859
maps.Copy(allFutureKeywords, futureKeywords)
27602860
maps.Copy(allFutureKeywords, futureKeywordsV0)
27612861
}
2862+
2863+
// enter increments the recursion depth counter and checks if it exceeds the maximum.
2864+
// Returns false if the maximum is exceeded, true otherwise.
2865+
// If p.maxRecursionDepth is 0 or negative, the check is effectively disabled.
2866+
func (p *Parser) enter() bool {
2867+
p.recursionDepth++
2868+
if p.maxRecursionDepth > 0 && p.recursionDepth > p.maxRecursionDepth {
2869+
p.error(p.s.Loc(), ErrMaxParsingRecursionDepthExceeded.Error())
2870+
p.recursionDepth--
2871+
return false
2872+
}
2873+
return true
2874+
}
2875+
2876+
// leave decrements the recursion depth counter.
2877+
func (p *Parser) leave() {
2878+
p.recursionDepth--
2879+
}

v1/ast/parser_bench_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,31 @@ func BenchmarkParseStatementNestedObjects(b *testing.B) {
7676
}
7777
}
7878

79+
// BenchmarkParseDeepNesting tests the impact of recursion depth tracking
80+
// on parsing performance with deeply nested structures (arrays and objects).
81+
// Different depths are used to measure the overhead at various nesting levels.
82+
func BenchmarkParseDeepNesting(b *testing.B) {
83+
depths := []int{10, 50, 100, 500, 2500, 12500}
84+
85+
b.Run("NestedArrays", func(b *testing.B) {
86+
for _, depth := range depths {
87+
b.Run(fmt.Sprintf("depth-%d", depth), func(b *testing.B) {
88+
stmt := generateDeeplyNestedArray(depth)
89+
runParseStatementBenchmark(b, stmt)
90+
})
91+
}
92+
})
93+
94+
b.Run("NestedObjects", func(b *testing.B) {
95+
for _, depth := range depths {
96+
b.Run(fmt.Sprintf("depth-%d", depth), func(b *testing.B) {
97+
stmt := generateDeeplyNestedObject(depth)
98+
runParseStatementBenchmark(b, stmt)
99+
})
100+
}
101+
})
102+
}
103+
79104
func BenchmarkParseStatementNestedObjectsOrSets(b *testing.B) {
80105
sizes := []int{1, 5, 10, 15, 20}
81106
for _, size := range sizes {

0 commit comments

Comments
 (0)