Skip to content

Commit 9e25b2b

Browse files
authored
protoc-gen-swagger: fix infinite loop on circular references in query parameters (#1266)
* Added function nestedQueryParams with map[string] parameter for keeping track of and detecting circular references. Added test TestMessageToQueryParametersRecursive for testing gracefully handling of circular references between messages. See issue #1167 * Code-review change requests accepted * More missed circle references changed to cycle Fixes #1167
1 parent 01598a7 commit 9e25b2b

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

protoc-gen-swagger/genswagger/template.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,18 @@ func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Regis
120120
return params, nil
121121
}
122122

123-
// queryParams converts a field to a list of swagger query parameters recursively.
123+
// queryParams converts a field to a list of swagger query parameters recursively through the use of nestedQueryParams.
124124
func queryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter) (params []swaggerParameterObject, err error) {
125+
return nestedQueryParams(message, field, prefix, reg, pathParams, map[string]bool{})
126+
}
127+
128+
// nestedQueryParams converts a field to a list of swagger query parameters recursively.
129+
// This function is a helper function for queryParams, that keeps track of cyclical message references
130+
// through the use of
131+
// touched map[string]bool
132+
// If a cycle is discovered, an error is returned, as cyclical data structures aren't allowed
133+
// in query parameters.
134+
func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, touched map[string]bool) (params []swaggerParameterObject, err error) {
125135
// make sure the parameter is not already listed as a path parameter
126136
for _, pathParam := range pathParams {
127137
if pathParam.Target == field {
@@ -216,14 +226,22 @@ func queryParams(message *descriptor.Message, field *descriptor.Field, prefix st
216226
if err != nil {
217227
return nil, fmt.Errorf("unknown message type %s", fieldType)
218228
}
229+
// Check for cyclical message reference:
230+
isCycle := touched[*msg.Name]
231+
if isCycle {
232+
return nil, fmt.Errorf("Recursive types are not allowed for query parameters, cycle found on %q", fieldType)
233+
}
234+
// Update map with the massage name so a cycle further down the recursive path can be detected.
235+
touched[*msg.Name] = true
236+
219237
for _, nestedField := range msg.Fields {
220238
var fieldName string
221239
if reg.GetUseJSONNamesForFields() {
222240
fieldName = field.GetJsonName()
223241
} else {
224242
fieldName = field.GetName()
225243
}
226-
p, err := queryParams(msg, nestedField, prefix+fieldName+".", reg, pathParams)
244+
p, err := nestedQueryParams(msg, nestedField, prefix+fieldName+".", reg, pathParams, touched)
227245
if err != nil {
228246
return nil, err
229247
}

protoc-gen-swagger/genswagger/template_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,124 @@ func TestMessageToQueryParameters(t *testing.T) {
407407
}
408408
}
409409

410+
// TestMessagetoQueryParametersRecursive, is a check that cyclical references between messages
411+
// are handled gracefully. The goal is to insure that attempts to add messages with cyclical
412+
// references to query-parameters returns an error message.
413+
func TestMessageToQueryParametersRecursive(t *testing.T) {
414+
type test struct {
415+
MsgDescs []*protodescriptor.DescriptorProto
416+
Message string
417+
}
418+
419+
tests := []test{
420+
// First test:
421+
// Here we test that a message that references it self through a field will return an error.
422+
// Example proto:
423+
// message DirectRecursiveMessage {
424+
// DirectRecursiveMessage nested = 1;
425+
// }
426+
{
427+
MsgDescs: []*protodescriptor.DescriptorProto{
428+
&protodescriptor.DescriptorProto{
429+
Name: proto.String("DirectRecursiveMessage"),
430+
Field: []*protodescriptor.FieldDescriptorProto{
431+
{
432+
Name: proto.String("nested"),
433+
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
434+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
435+
TypeName: proto.String(".example.DirectRecursiveMessage"),
436+
Number: proto.Int32(1),
437+
},
438+
},
439+
},
440+
},
441+
Message: "DirectRecursiveMessage",
442+
},
443+
// Second test:
444+
// Here we test that a cycle through multiple messages is detected and that an error is returned.
445+
// Sample:
446+
// message Root { NodeMessage nested = 1; }
447+
// message NodeMessage { CycleMessage nested = 1; }
448+
// message CycleMessage { Root nested = 1; }
449+
{
450+
MsgDescs: []*protodescriptor.DescriptorProto{
451+
&protodescriptor.DescriptorProto{
452+
Name: proto.String("RootMessage"),
453+
Field: []*protodescriptor.FieldDescriptorProto{
454+
{
455+
Name: proto.String("nested"),
456+
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
457+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
458+
TypeName: proto.String(".example.NodeMessage"),
459+
Number: proto.Int32(1),
460+
},
461+
},
462+
},
463+
&protodescriptor.DescriptorProto{
464+
Name: proto.String("NodeMessage"),
465+
Field: []*protodescriptor.FieldDescriptorProto{
466+
{
467+
Name: proto.String("nested"),
468+
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
469+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
470+
TypeName: proto.String(".example.CycleMessage"),
471+
Number: proto.Int32(1),
472+
},
473+
},
474+
},
475+
&protodescriptor.DescriptorProto{
476+
Name: proto.String("CycleMessage"),
477+
Field: []*protodescriptor.FieldDescriptorProto{
478+
{
479+
Name: proto.String("nested"),
480+
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
481+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
482+
TypeName: proto.String(".example.RootMessage"),
483+
Number: proto.Int32(1),
484+
},
485+
},
486+
},
487+
},
488+
Message: "RootMessage",
489+
},
490+
}
491+
492+
for _, test := range tests {
493+
reg := descriptor.NewRegistry()
494+
msgs := []*descriptor.Message{}
495+
for _, msgdesc := range test.MsgDescs {
496+
msgs = append(msgs, &descriptor.Message{DescriptorProto: msgdesc})
497+
}
498+
file := descriptor.File{
499+
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
500+
SourceCodeInfo: &protodescriptor.SourceCodeInfo{},
501+
Name: proto.String("example.proto"),
502+
Package: proto.String("example"),
503+
Dependency: []string{},
504+
MessageType: test.MsgDescs,
505+
Service: []*protodescriptor.ServiceDescriptorProto{},
506+
},
507+
GoPkg: descriptor.GoPackage{
508+
Path: "example.com/path/to/example/example.pb",
509+
Name: "example_pb",
510+
},
511+
Messages: msgs,
512+
}
513+
reg.Load(&plugin.CodeGeneratorRequest{
514+
ProtoFile: []*protodescriptor.FileDescriptorProto{file.FileDescriptorProto},
515+
})
516+
517+
message, err := reg.LookupMsg("", ".example."+test.Message)
518+
if err != nil {
519+
t.Fatalf("failed to lookup message: %s", err)
520+
}
521+
_, err = messageToQueryParameters(message, reg, []descriptor.Parameter{})
522+
if err == nil {
523+
t.Fatalf("It should not be allowed to have recursive query parameters")
524+
}
525+
}
526+
}
527+
410528
func TestMessageToQueryParametersWithJsonName(t *testing.T) {
411529
type test struct {
412530
MsgDescs []*protodescriptor.DescriptorProto

0 commit comments

Comments
 (0)