Skip to content

Commit 48b4c88

Browse files
feat: add method parameter generation
1 parent f4a4001 commit 48b4c88

File tree

1 file changed

+162
-95
lines changed

1 file changed

+162
-95
lines changed

protoc-gen-openapiv3/internal/genopenapiv3/file_generator.go

Lines changed: 162 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package genopenapiv3
22

33
import (
4+
"fmt"
45
"maps"
56
"slices"
7+
"strings"
68

79
"github.com/getkin/kin-openapi/openapi3"
810
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
11+
"github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv3/options"
912
"google.golang.org/grpc/grpclog"
1013
"google.golang.org/protobuf/types/descriptorpb"
1114
)
@@ -37,7 +40,7 @@ func (fg *fileGenerator) generateFileDoc(file *descriptor.File) *openapi3.T {
3740
}
3841

3942
for _, msg := range file.Messages {
40-
fg.getMessageSchema(msg)
43+
fg.getMessageSchemaRef(msg)
4144
}
4245

4346
for _, enum := range file.Enums {
@@ -47,20 +50,21 @@ func (fg *fileGenerator) generateFileDoc(file *descriptor.File) *openapi3.T {
4750
return fg.doc
4851
}
4952

50-
func (fg *fileGenerator) getMessageSchema(msg *descriptor.Message) *openapi3.SchemaRef {
53+
func (fg *fileGenerator) getMessageSchemaRef(msg *descriptor.Message) *openapi3.SchemaRef {
5154
name := fg.resolveName(msg.FQMN())
52-
schemaRef, ok := fg.doc.Components.Schemas[name]
55+
resultRef := openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", name), nil)
56+
57+
_, ok := fg.doc.Components.Schemas[name]
5358
if ok {
54-
return schemaRef.Value.NewRef()
59+
return resultRef
5560
}
5661

57-
schemaRef = fg.generateMessageSchema(msg).NewRef()
58-
fg.doc.Components.Schemas[name] = schemaRef
62+
fg.doc.Components.Schemas[name] = fg.generateMessageSchema(msg, nil).NewRef()
5963

60-
return schemaRef
64+
return resultRef
6165
}
6266

63-
func (fg *fileGenerator) generateMessageSchema(msg *descriptor.Message) *openapi3.Schema {
67+
func (fg *fileGenerator) generateMessageSchema(msg *descriptor.Message, excludeFields []string) *openapi3.Schema {
6468
msgName := fg.resolveName(msg.FQMN())
6569
if scheme, ok := wktSchemas[msgName]; ok {
6670
return scheme
@@ -193,9 +197,7 @@ func (fg *fileGenerator) generateFieldDoc(field *descriptor.Field) *openapi3.Sch
193197

194198
func (fg *fileGenerator) generateFieldTypeSchema(fd *descriptorpb.FieldDescriptorProto, location string) *openapi3.SchemaRef {
195199
if schema, ok := primitiveTypeSchemas[fd.GetType()]; ok {
196-
return &openapi3.SchemaRef{
197-
Value: schema,
198-
}
200+
return schema.NewRef()
199201
}
200202

201203
switch ft := fd.GetType(); ft {
@@ -216,7 +218,7 @@ func (fg *fileGenerator) generateFieldTypeSchema(fd *descriptorpb.FieldDescripto
216218
if err != nil {
217219
panic(err)
218220
}
219-
return fg.getMessageSchema(fieldTypeMsg)
221+
return fg.getMessageSchemaRef(fieldTypeMsg)
220222
}
221223
}
222224
default:
@@ -259,95 +261,160 @@ func (fg *fileGenerator) generateServiceDoc(svc *descriptor.Service) {
259261
}
260262

261263
func (fg *fileGenerator) generateMethodDoc(meth *descriptor.Method) error {
264+
for bindingIdx, binding := range meth.Bindings {
265+
opOpts, err := extractOperationOptionFromMethodDescriptor(meth.MethodDescriptorProto)
266+
if err != nil {
267+
return fmt.Errorf("error extracting method %s operations: %v", meth.GetName(), err)
268+
}
269+
270+
var params openapi3.Parameters
271+
var requestBody *openapi3.RequestBodyRef
272+
273+
if meth.RequestType != nil {
274+
tmpParams, err := fg.messageToParameters(meth.RequestType, binding.PathParams, binding.Body,
275+
binding.HTTPMethod, "")
276+
if err != nil {
277+
grpclog.Errorf("error generating query parameters for method %s: %v", meth.GetName(), err)
278+
} else {
279+
params = append(params, tmpParams...)
280+
}
281+
282+
switch binding.HTTPMethod {
283+
case "POST", "PUT", "PATCH":
284+
// For POST, PUT, PATCH, add request body
285+
requestBody = &openapi3.RequestBodyRef{Value: openapi3.NewRequestBody().WithContent(
286+
openapi3.NewContentWithJSONSchemaRef(fg.getMessageSchemaRef(meth.RequestType)))}
287+
288+
}
289+
}
290+
291+
responseSchema := &openapi3.ResponseRef{
292+
Ref: fg.getMessageSchemaRef(meth.ResponseType).RefString(),
293+
}
294+
295+
operation := &openapi3.Operation{
296+
Tags: []string{meth.Service.GetName()},
297+
Summary: opOpts.GetSummary(),
298+
Description: opOpts.GetDescription(),
299+
OperationID: fg.getOperationName(meth.Service.GetName(), meth.GetName(), bindingIdx),
300+
RequestBody: requestBody,
301+
Responses: openapi3.NewResponses(openapi3.WithStatus(200, responseSchema)),
302+
Parameters: params,
303+
}
304+
305+
if opOpts.GetSecurity() != nil {
306+
operation.Security = fg.generateSecurity(opOpts.GetSecurity())
307+
}
308+
309+
path := fg.convertPathTemplate(binding.PathTmpl.Template)
310+
pathItem := fg.doc.Paths.Find(path)
311+
if pathItem == nil {
312+
pathItem = &openapi3.PathItem{}
313+
fg.doc.Paths.Set(path, pathItem)
314+
}
315+
316+
switch binding.HTTPMethod {
317+
case "GET":
318+
pathItem.Get = operation
319+
case "POST":
320+
pathItem.Post = operation
321+
case "PUT":
322+
pathItem.Put = operation
323+
case "PATCH":
324+
pathItem.Patch = operation
325+
case "DELETE":
326+
pathItem.Delete = operation
327+
case "HEAD":
328+
pathItem.Head = operation
329+
case "OPTIONS":
330+
pathItem.Options = operation
331+
}
332+
}
333+
262334
return nil
263-
// for bindingIdx, binding := range meth.Bindings {
264-
// opOpts, err := extractOperationOptionFromMethodDescriptor(meth.MethodDescriptorProto)
265-
// if err != nil {
266-
// return fmt.Errorf("error extracting method %s operations: %v", meth.GetName(), err)
267-
// }
268-
//
269-
// pathParams, err := fg.generatePathParameters(binding.PathParams)
270-
// if err != nil {
271-
// return fmt.Errorf("error generating path parameters for method %s: %v", meth.GetName(), err)
272-
// }
273-
//
274-
// if meth.RequestType != nil {
275-
// switch binding.HTTPMethod {
276-
// case "GET", "DELETE":
277-
// queryParams, err := fg.messageToQueryParameters(meth.RequestType, binding.PathParams, binding.Body, binding.HTTPMethod)
278-
// if err != nil {
279-
// grpclog.Errorf("error generating query parameters for method %s: %v", meth.GetName(), err)
280-
// } else {
281-
// pathParams = append(pathParams, queryParams...)
282-
// }
283-
// case "POST", "PUT", "PATCH":
284-
// // For POST, PUT, PATCH, add request body
285-
// operation.RequestBody = fg.generateRequestBody(binding, meth.RequestType)
286-
//
287-
// queryParams, err := fg.messageToQueryParameters(meth.RequestType, binding.PathParams, binding.Body, binding.HTTPMethod)
288-
// if err != nil {
289-
// grpclog.Errorf("error generating query parameters for method %s: %v", meth.GetName(), err)
290-
// } else {
291-
// pathParams = append(pathParams, queryParams...)
292-
// }
293-
// }
294-
// }
295-
//
296-
// var responses *openapi3.Responses
297-
//
298-
// operation := &openapi3.Operation{
299-
// Tags: []string{meth.Service.GetName()},
300-
// Summary: opOpts.GetSummary(),
301-
// Description: opOpts.GetDescription(),
302-
// OperationID: fg.getOperationName(meth.Service.GetName(), meth.GetName(), bindingIdx),
303-
// Parameters: pathParams,
304-
// Responses: openapi3.NewResponses(),
305-
// }
306-
//
307-
// fg.addMethodResponses(operation, meth)
308-
//
309-
// if opOpts.GetSecurity() != nil {
310-
// operation.Security = fg.convertSecurity(opOpts.GetSecurity())
311-
// }
312-
//
313-
// pathTemplate := fg.convertPathTemplate(binding.PathTmpl.Template)
314-
// pathItem := doc.Paths.Find(pathTemplate)
315-
// if pathItem == nil {
316-
// pathItem = &openapi3.PathItem{}
317-
// doc.Paths.Set(pathTemplate, pathItem)
318-
// }
319-
//
320-
// switch binding.HTTPMethod {
321-
// case "GET":
322-
// pathItem.Get = operation
323-
// case "POST":
324-
// pathItem.Post = operation
325-
// case "PUT":
326-
// pathItem.Put = operation
327-
// case "PATCH":
328-
// pathItem.Patch = operation
329-
// case "DELETE":
330-
// pathItem.Delete = operation
331-
// case "HEAD":
332-
// pathItem.Head = operation
333-
// case "OPTIONS":
334-
// pathItem.Options = operation
335-
// }
336-
// }
337-
//
338-
// return nil
339335
}
340336

341-
func (fg *fileGenerator) generatePathParameters(params []descriptor.Parameter) (any, error) {
342-
panic("unimplemented")
337+
func (fg *fileGenerator) generateSecurity(requirements []*options.SecurityRequirement) *openapi3.SecurityRequirements {
338+
res := openapi3.NewSecurityRequirements()
339+
340+
for _, req := range requirements {
341+
oAPISecReq := openapi3.NewSecurityRequirement()
342+
for authenticator, scopes := range req.GetAdditionalProperties() {
343+
oAPISecReq.Authenticate(authenticator, scopes.GetScopes()...)
344+
}
345+
346+
res.With(oAPISecReq)
347+
}
348+
349+
return res
343350
}
344351

345-
func (fg *fileGenerator) generateResponseSchema(responseType *descriptor.Message) *openapi3.SchemaRef {
346-
if responseType == nil {
347-
return &openapi3.SchemaRef{
348-
Value: &openapi3.Schema{Type: &openapi3.Types{openapi3.TypeObject}},
352+
func (fg *fileGenerator) convertPathTemplate(template string) string {
353+
// TODO: handle /{args=foo/*}
354+
return template
355+
}
356+
357+
func (fg *fileGenerator) messageToParameters(message *descriptor.Message,
358+
pathParams []descriptor.Parameter, body *descriptor.Body,
359+
httpMethod string, paramPrefix string) (openapi3.Parameters, error) {
360+
361+
params := openapi3.NewParameters()
362+
for _, field := range message.Fields {
363+
paramType, isParam := fg.getParamType(field, pathParams, body, message, httpMethod)
364+
if !isParam {
365+
// TODO: handle nested path parameter reference
366+
continue
367+
}
368+
369+
schema := fg.generateFieldTypeSchema(field.FieldDescriptorProto, fg.fqmnToLocation(field.FQFN()))
370+
371+
switch paramType {
372+
case openapi3.ParameterInPath:
373+
param := openapi3.NewPathParameter(field.GetJsonName())
374+
param.Schema = schema
375+
params = append(params, &openapi3.ParameterRef{
376+
Value: param,
377+
})
378+
case openapi3.ParameterInQuery:
379+
param := openapi3.NewQueryParameter(field.GetJsonName())
380+
param.Schema = schema
381+
params = append(params, &openapi3.ParameterRef{
382+
Value: param,
383+
})
349384
}
350385
}
351386

352-
return fg.getMessageSchema(responseType)
387+
return params, nil
388+
}
389+
390+
func (fg *fileGenerator) getParamType(field *descriptor.Field, pathParams []descriptor.Parameter, body *descriptor.Body,
391+
message *descriptor.Message, httpMethod string) (string, bool) {
392+
393+
for _, pathParam := range pathParams {
394+
if pathParam.Target.FQFN() == field.FQFN() {
395+
return openapi3.ParameterInPath, true
396+
}
397+
398+
if strings.HasSuffix(pathParam.Target.FQFN(), message.FQMN()) {
399+
return "", false
400+
}
401+
}
402+
403+
if httpMethod == "GET" || httpMethod == "DELETE" {
404+
return openapi3.ParameterInQuery, true
405+
}
406+
407+
if body == nil {
408+
return openapi3.ParameterInQuery, true
409+
}
410+
411+
if len(body.FieldPath) == 0 {
412+
return "", false
413+
}
414+
415+
if body.FieldPath[len(body.FieldPath)-1].Target.FQFN() == field.FQFN() {
416+
return "", false
417+
}
418+
419+
return openapi3.ParameterInQuery, true
353420
}

0 commit comments

Comments
 (0)