Skip to content

Commit 7c9107a

Browse files
authored
Remove ban on recursive query parameters (#2022)
* implements a max recursive depth check * bazel is just at root and update build.bazel * review comments
1 parent 6adc4fe commit 7c9107a

File tree

7 files changed

+124
-16
lines changed

7 files changed

+124
-16
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ bazel-genfiles
77
bazel-grpc-gateway
88
bazel-out
99
bazel-testlogs
10-
.bazelrc
10+
/.bazelrc
1111

1212
# Go vendor directory
1313
vendor

internal/descriptor/registry.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ type Registry struct {
109109

110110
// omitPackageDoc, if false, causes a package comment to be included in the generated code.
111111
omitPackageDoc bool
112+
113+
// recursiveDepth sets the maximum depth of a field parameter
114+
recursiveDepth int
112115
}
113116

114117
type repeatedFieldSeparator struct {
@@ -134,6 +137,7 @@ func NewRegistry() *Registry {
134137
messageOptions: make(map[string]*options.Schema),
135138
serviceOptions: make(map[string]*options.Tag),
136139
fieldOptions: make(map[string]*options.JSONSchema),
140+
recursiveDepth: 1000,
137141
}
138142
}
139143

@@ -356,6 +360,16 @@ func (r *Registry) SetStandalone(standalone bool) {
356360
r.standalone = standalone
357361
}
358362

363+
// SetRecursiveDepth records the max recursion count
364+
func (r *Registry) SetRecursiveDepth(count int) {
365+
r.recursiveDepth = count
366+
}
367+
368+
// GetRecursiveDepth returns the max recursion count
369+
func (r *Registry) GetRecursiveDepth() int {
370+
return r.recursiveDepth
371+
}
372+
359373
// ReserveGoPackageAlias reserves the unique alias of go package.
360374
// If succeeded, the alias will be never used for other packages in generated go files.
361375
// If failed, the alias is already taken by another package, so you need to use another

protoc-gen-openapiv2/internal/genopenapi/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ go_library(
3434
go_test(
3535
name = "go_default_test",
3636
size = "small",
37-
srcs = ["template_test.go"],
37+
srcs = [
38+
"cycle_test.go",
39+
"template_test.go",
40+
],
3841
embed = [":go_default_library"],
3942
deps = [
4043
"//internal/descriptor:go_default_library",
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package genopenapi
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestCycle(t *testing.T) {
8+
for _, tt := range []struct {
9+
max int
10+
attempt int
11+
e bool
12+
}{
13+
{
14+
max: 3,
15+
attempt: 3,
16+
e: true,
17+
},
18+
{
19+
max: 5,
20+
attempt: 6,
21+
},
22+
{
23+
max: 1000,
24+
attempt: 1001,
25+
},
26+
} {
27+
28+
c := newCycleChecker(tt.max)
29+
var final bool
30+
for i := 0; i < tt.attempt; i++ {
31+
final = c.Check("a")
32+
if !final {
33+
break
34+
}
35+
}
36+
37+
if final != tt.e {
38+
t.Errorf("got: %t wanted: %t", final, tt.e)
39+
}
40+
}
41+
}

protoc-gen-openapiv2/internal/genopenapi/generator.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ type wrapper struct {
3636
swagger *openapiSwaggerObject
3737
}
3838

39+
type GeneratorOptions struct {
40+
Registry *descriptor.Registry
41+
RecursiveDepth int
42+
}
43+
3944
// New returns a new generator which generates grpc gateway files.
4045
func New(reg *descriptor.Registry) gen.Generator {
4146
return &generator{reg: reg}

protoc-gen-openapiv2/internal/genopenapi/template.go

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func getEnumDefault(enum *descriptor.Enum) string {
120120
// messageToQueryParameters converts a message to a list of OpenAPI query parameters.
121121
func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body) (params []openapiParameterObject, err error) {
122122
for _, field := range message.Fields {
123-
p, err := queryParams(message, field, "", reg, pathParams, body)
123+
p, err := queryParams(message, field, "", reg, pathParams, body, reg.GetRecursiveDepth())
124124
if err != nil {
125125
return nil, err
126126
}
@@ -130,17 +130,64 @@ func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Regis
130130
}
131131

132132
// queryParams converts a field to a list of OpenAPI query parameters recursively through the use of nestedQueryParams.
133-
func queryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body) (params []openapiParameterObject, err error) {
134-
return nestedQueryParams(message, field, prefix, reg, pathParams, body, map[string]bool{})
133+
func queryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, recursiveCount int) (params []openapiParameterObject, err error) {
134+
return nestedQueryParams(message, field, prefix, reg, pathParams, body, newCycleChecker(recursiveCount))
135+
}
136+
137+
type cycleChecker struct {
138+
m map[string]int
139+
count int
140+
}
141+
142+
func newCycleChecker(recursive int) *cycleChecker {
143+
return &cycleChecker{
144+
m: make(map[string]int),
145+
count: recursive,
146+
}
147+
}
148+
149+
// Check returns whether name is still within recursion
150+
// toleration
151+
func (c *cycleChecker) Check(name string) bool {
152+
count, ok := c.m[name]
153+
count = count + 1
154+
isCycle := count > c.count
155+
156+
if isCycle {
157+
return false
158+
}
159+
160+
// provision map entry if not available
161+
if !ok {
162+
c.m[name] = 1
163+
return true
164+
}
165+
166+
c.m[name] = count
167+
168+
return true
169+
}
170+
171+
func (c *cycleChecker) Branch() *cycleChecker {
172+
copy := &cycleChecker{
173+
count: c.count,
174+
m: map[string]int{},
175+
}
176+
177+
for k, v := range c.m {
178+
copy.m[k] = v
179+
}
180+
181+
return copy
135182
}
136183

137184
// nestedQueryParams converts a field to a list of OpenAPI query parameters recursively.
138185
// This function is a helper function for queryParams, that keeps track of cyclical message references
139186
// through the use of
140-
// touched map[string]bool
141-
// If a cycle is discovered, an error is returned, as cyclical data structures aren't allowed
187+
// touched map[string]int
188+
// If a cycle is discovered, an error is returned, as cyclical data structures are dangerous
142189
// in query parameters.
143-
func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, touchedIn map[string]bool) (params []openapiParameterObject, err error) {
190+
func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, cycle *cycleChecker) (params []openapiParameterObject, err error) {
144191
// make sure the parameter is not already listed as a path parameter
145192
for _, pathParam := range pathParams {
146193
if pathParam.Target == field {
@@ -248,19 +295,15 @@ func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, pre
248295
}
249296

250297
// Check for cyclical message reference:
251-
isCycle := touchedIn[*msg.Name]
252-
if isCycle {
253-
return nil, fmt.Errorf("recursive types are not allowed for query parameters, cycle found on %q", fieldType)
298+
isOK := cycle.Check(*msg.Name)
299+
if !isOK {
300+
return nil, fmt.Errorf("exceeded recursive count (%d) for query parameter %q", cycle.count, fieldType)
254301
}
255302

256303
// Construct a new map with the message name so a cycle further down the recursive path can be detected.
257304
// Do not keep anything in the original touched reference and do not pass that reference along. This will
258305
// prevent clobbering adjacent records while recursing.
259-
touchedOut := make(map[string]bool)
260-
for k, v := range touchedIn {
261-
touchedOut[k] = v
262-
}
263-
touchedOut[*msg.Name] = true
306+
touchedOut := cycle.Branch()
264307

265308
for _, nestedField := range msg.Fields {
266309
var fieldName string

protoc-gen-openapiv2/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ var (
3333
simpleOperationIDs = flag.Bool("simple_operation_ids", false, "whether to remove the service prefix in the operationID generation. Can introduce duplicate operationIDs, use with caution.")
3434
openAPIConfiguration = flag.String("openapi_configuration", "", "path to file which describes the OpenAPI Configuration in YAML format")
3535
generateUnboundMethods = flag.Bool("generate_unbound_methods", false, "generate swagger metadata even for RPC methods that have no HttpRule annotation")
36+
recursiveDepth = flag.Int("recursive-depth", 1000, "maximum recursion count allowed for a field type")
3637
)
3738

3839
// Variables set by goreleaser at build time
@@ -89,6 +90,7 @@ func main() {
8990
reg.SetDisableDefaultErrors(*disableDefaultErrors)
9091
reg.SetSimpleOperationIDs(*simpleOperationIDs)
9192
reg.SetGenerateUnboundMethods(*generateUnboundMethods)
93+
reg.SetRecursiveDepth(*recursiveDepth)
9294
if err := reg.SetRepeatedPathParamSeparator(*repeatedPathParamSeparator); err != nil {
9395
emitError(err)
9496
return

0 commit comments

Comments
 (0)