@@ -120,7 +120,7 @@ func getEnumDefault(enum *descriptor.Enum) string {
120
120
// messageToQueryParameters converts a message to a list of OpenAPI query parameters.
121
121
func messageToQueryParameters (message * descriptor.Message , reg * descriptor.Registry , pathParams []descriptor.Parameter , body * descriptor.Body ) (params []openapiParameterObject , err error ) {
122
122
for _ , field := range message .Fields {
123
- p , err := queryParams (message , field , "" , reg , pathParams , body )
123
+ p , err := queryParams (message , field , "" , reg , pathParams , body , reg . GetRecursiveDepth () )
124
124
if err != nil {
125
125
return nil , err
126
126
}
@@ -130,17 +130,64 @@ func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Regis
130
130
}
131
131
132
132
// queryParams converts a field to a list of OpenAPI query parameters recursively through the use of nestedQueryParams.
133
- func queryParams (message * descriptor.Message , field * descriptor.Field , prefix string , reg * descriptor.Registry , pathParams []descriptor.Parameter , body * descriptor.Body ) (params []openapiParameterObject , err error ) {
134
- return nestedQueryParams (message , field , prefix , reg , pathParams , body , map [string ]bool {})
133
+ func queryParams (message * descriptor.Message , field * descriptor.Field , prefix string , reg * descriptor.Registry , pathParams []descriptor.Parameter , body * descriptor.Body , recursiveCount int ) (params []openapiParameterObject , err error ) {
134
+ return nestedQueryParams (message , field , prefix , reg , pathParams , body , newCycleChecker (recursiveCount ))
135
+ }
136
+
137
+ type cycleChecker struct {
138
+ m map [string ]int
139
+ count int
140
+ }
141
+
142
+ func newCycleChecker (recursive int ) * cycleChecker {
143
+ return & cycleChecker {
144
+ m : make (map [string ]int ),
145
+ count : recursive ,
146
+ }
147
+ }
148
+
149
+ // Check returns whether name is still within recursion
150
+ // toleration
151
+ func (c * cycleChecker ) Check (name string ) bool {
152
+ count , ok := c .m [name ]
153
+ count = count + 1
154
+ isCycle := count > c .count
155
+
156
+ if isCycle {
157
+ return false
158
+ }
159
+
160
+ // provision map entry if not available
161
+ if ! ok {
162
+ c .m [name ] = 1
163
+ return true
164
+ }
165
+
166
+ c .m [name ] = count
167
+
168
+ return true
169
+ }
170
+
171
+ func (c * cycleChecker ) Branch () * cycleChecker {
172
+ copy := & cycleChecker {
173
+ count : c .count ,
174
+ m : map [string ]int {},
175
+ }
176
+
177
+ for k , v := range c .m {
178
+ copy .m [k ] = v
179
+ }
180
+
181
+ return copy
135
182
}
136
183
137
184
// nestedQueryParams converts a field to a list of OpenAPI query parameters recursively.
138
185
// This function is a helper function for queryParams, that keeps track of cyclical message references
139
186
// through the use of
140
- // touched map[string]bool
141
- // If a cycle is discovered, an error is returned, as cyclical data structures aren't allowed
187
+ // touched map[string]int
188
+ // If a cycle is discovered, an error is returned, as cyclical data structures are dangerous
142
189
// in query parameters.
143
- func nestedQueryParams (message * descriptor.Message , field * descriptor.Field , prefix string , reg * descriptor.Registry , pathParams []descriptor.Parameter , body * descriptor.Body , touchedIn map [ string ] bool ) (params []openapiParameterObject , err error ) {
190
+ func nestedQueryParams (message * descriptor.Message , field * descriptor.Field , prefix string , reg * descriptor.Registry , pathParams []descriptor.Parameter , body * descriptor.Body , cycle * cycleChecker ) (params []openapiParameterObject , err error ) {
144
191
// make sure the parameter is not already listed as a path parameter
145
192
for _ , pathParam := range pathParams {
146
193
if pathParam .Target == field {
@@ -248,19 +295,15 @@ func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, pre
248
295
}
249
296
250
297
// Check for cyclical message reference:
251
- isCycle := touchedIn [ * msg .Name ]
252
- if isCycle {
253
- return nil , fmt .Errorf ("recursive types are not allowed for query parameters, cycle found on %q" , fieldType )
298
+ isOK := cycle . Check ( * msg .Name )
299
+ if ! isOK {
300
+ return nil , fmt .Errorf ("exceeded recursive count (%d) for query parameter %q" , cycle . count , fieldType )
254
301
}
255
302
256
303
// Construct a new map with the message name so a cycle further down the recursive path can be detected.
257
304
// Do not keep anything in the original touched reference and do not pass that reference along. This will
258
305
// prevent clobbering adjacent records while recursing.
259
- touchedOut := make (map [string ]bool )
260
- for k , v := range touchedIn {
261
- touchedOut [k ] = v
262
- }
263
- touchedOut [* msg .Name ] = true
306
+ touchedOut := cycle .Branch ()
264
307
265
308
for _ , nestedField := range msg .Fields {
266
309
var fieldName string
0 commit comments