Skip to content

Commit f912d21

Browse files
feat: add visibility level filtering
1 parent 92afaa3 commit f912d21

File tree

3 files changed

+129
-38
lines changed

3 files changed

+129
-38
lines changed

protoc-gen-openapiv3/internal/genopenapiv3/file_generator.go

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,49 +15,55 @@ import (
1515
type fileGenerator struct {
1616
*generator
1717

18-
doc *openapi3.T
18+
spec *openapi3.T
1919
}
2020

21-
func (fg *fileGenerator) generateFileDoc(file *descriptor.File) *openapi3.T {
21+
func (fg *fileGenerator) generateFileSpec(file *descriptor.File) *openapi3.T {
2222

23-
fg.doc = convertFileOptions(file)
23+
fg.spec = convertFileOptions(file)
2424

25-
fg.doc.Components = &openapi3.Components{}
26-
fg.doc.Components.Schemas = make(openapi3.Schemas)
27-
fg.doc.Components.RequestBodies = make(openapi3.RequestBodies)
25+
fg.spec.Components = &openapi3.Components{}
26+
fg.spec.Components.Schemas = make(openapi3.Schemas)
27+
fg.spec.Components.RequestBodies = make(openapi3.RequestBodies)
2828

29-
if fg.doc.Paths == nil {
30-
fg.doc.Paths = &openapi3.Paths{}
29+
if fg.spec.Paths == nil {
30+
fg.spec.Paths = &openapi3.Paths{}
3131
}
3232

3333
for _, svc := range file.Services {
34-
err := fg.generateServiceDoc(svc)
35-
if err != nil {
36-
grpclog.Errorf("could not generate service document: %v", err)
34+
if fg.IsServiceVisible(svc) {
35+
err := fg.generateServiceSpec(svc)
36+
if err != nil {
37+
grpclog.Errorf("could not generate service document: %v", err)
38+
}
3739
}
3840
}
3941

4042
for _, msg := range file.Messages {
41-
fg.getMessageSchemaRef(msg)
43+
if fg.IsMessageVisible(msg) {
44+
fg.generateMessageSchemaRef(msg)
45+
}
4246
}
4347

4448
for _, enum := range file.Enums {
45-
fg.getEnumSchema(enum)
49+
if fg.IsEnumVisible(enum) {
50+
fg.generateEnumSchemaRef(enum)
51+
}
4652
}
4753

48-
return fg.doc
54+
return fg.spec
4955
}
5056

51-
func (fg *fileGenerator) getMessageSchemaRef(msg *descriptor.Message) *openapi3.SchemaRef {
57+
func (fg *fileGenerator) generateMessageSchemaRef(msg *descriptor.Message) *openapi3.SchemaRef {
5258
name := fg.resolveName(msg.FQMN())
5359
resultRef := openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", name), nil)
5460

55-
_, ok := fg.doc.Components.Schemas[name]
61+
_, ok := fg.spec.Components.Schemas[name]
5662
if ok {
5763
return resultRef
5864
}
5965

60-
fg.doc.Components.Schemas[name] = fg.generateMessageSchema(msg, nil).NewRef()
66+
fg.spec.Components.Schemas[name] = fg.generateMessageSchema(msg, nil).NewRef()
6167

6268
return resultRef
6369
}
@@ -110,7 +116,7 @@ func (fg *fileGenerator) generateMessageSchema(msg *descriptor.Message, excludeF
110116
switch fg.reg.GetOneOfStrategy() {
111117
case "oneOf":
112118
return &openapi3.Schema{
113-
OneOf: fg.generateMessageWithOneOfsSchemas(allOneOfsProperties, properties, msg.GetOneofDecl()),
119+
OneOf: fg.generateMessageWithOneOfsSchemas(allOneOfsProperties, properties, msg.GetOneofDecl(), ""),
114120
}
115121
default:
116122
grpclog.Fatal("unknown oneof strategy")
@@ -139,7 +145,7 @@ e.g.: if you have a proto like this:
139145
2 * 2 = 4, object schemas will be generate for each combination of set {field_one, field_two} and {field_three, field_four}
140146
*/
141147
func (fg *fileGenerator) generateMessageWithOneOfsSchemas(allOneOfsProperties map[int32]openapi3.Schemas, properties openapi3.Schemas,
142-
oneOfs []*descriptorpb.OneofDescriptorProto) openapi3.SchemaRefs {
148+
oneOfs []*descriptorpb.OneofDescriptorProto, namePrefix string) openapi3.SchemaRefs {
143149
if len(oneOfs) == 0 {
144150
return openapi3.SchemaRefs{&openapi3.SchemaRef{
145151
Value: &openapi3.Schema{
@@ -162,7 +168,7 @@ func (fg *fileGenerator) generateMessageWithOneOfsSchemas(allOneOfsProperties ma
162168
for fieldName, fieldSchema := range oneOfProperties {
163169
newProperties := maps.Clone(properties)
164170
newProperties[fieldName] = fieldSchema
165-
res = append(res, fg.generateMessageWithOneOfsSchemas(newAllOneOfsProperties, newProperties, newOneOfs)...)
171+
res = append(res, fg.generateMessageWithOneOfsSchemas(newAllOneOfsProperties, newProperties, newOneOfs, namePrefix + fieldName)...)
166172
}
167173

168174
return res
@@ -205,7 +211,7 @@ func (fg *fileGenerator) generateFieldTypeSchema(fd *descriptorpb.FieldDescripto
205211
switch ft := fd.GetType(); ft {
206212
case descriptorpb.FieldDescriptorProto_TYPE_ENUM, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP:
207213
openAPIRef := fg.resolveType(fd.GetTypeName())
208-
if schema, ok := fg.doc.Components.Schemas[openAPIRef]; ok {
214+
if schema, ok := fg.spec.Components.Schemas[openAPIRef]; ok {
209215
return schema
210216
} else {
211217
if fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM {
@@ -214,13 +220,13 @@ func (fg *fileGenerator) generateFieldTypeSchema(fd *descriptorpb.FieldDescripto
214220
panic(err)
215221
}
216222

217-
return fg.getEnumSchema(fieldTypeEnum)
223+
return fg.generateEnumSchemaRef(fieldTypeEnum)
218224
} else {
219225
fieldTypeMsg, err := fg.reg.LookupMsg(location, fd.GetTypeName())
220226
if err != nil {
221227
panic(err)
222228
}
223-
return fg.getMessageSchemaRef(fieldTypeMsg)
229+
return fg.generateMessageSchemaRef(fieldTypeMsg)
224230
}
225231
}
226232
default:
@@ -230,16 +236,16 @@ func (fg *fileGenerator) generateFieldTypeSchema(fd *descriptorpb.FieldDescripto
230236
}
231237
}
232238

233-
func (fg *fileGenerator) getEnumSchema(enum *descriptor.Enum) *openapi3.SchemaRef {
239+
func (fg *fileGenerator) generateEnumSchemaRef(enum *descriptor.Enum) *openapi3.SchemaRef {
234240
name := fg.resolveName(enum.FQEN())
235241

236-
schemaRef, ok := fg.doc.Components.Schemas[name]
242+
schemaRef, ok := fg.spec.Components.Schemas[name]
237243
if ok {
238244
return schemaRef
239245
}
240246

241247
schemaRef = fg.generateEnumSchema(enum).NewRef()
242-
fg.doc.Components.Schemas[name] = schemaRef
248+
fg.spec.Components.Schemas[name] = schemaRef
243249

244250
return schemaRef
245251
}
@@ -256,11 +262,13 @@ func (fg *fileGenerator) generateEnumSchema(enum *descriptor.Enum) *openapi3.Sch
256262
}
257263
}
258264

259-
func (fg *fileGenerator) generateServiceDoc(svc *descriptor.Service) error {
265+
func (fg *fileGenerator) generateServiceSpec(svc *descriptor.Service) error {
260266
for _, meth := range svc.Methods {
261-
err := fg.generateMethodDoc(meth)
262-
if err != nil {
263-
return fmt.Errorf("could not generate method %s doc: %w", meth.GetName(), err)
267+
if fg.generator.IsMethodVisible(meth) {
268+
err := fg.generateMethodDoc(meth)
269+
if err != nil {
270+
return fmt.Errorf("could not generate method %s doc: %w", meth.GetName(), err)
271+
}
264272
}
265273
}
266274

@@ -300,13 +308,13 @@ func (fg *fileGenerator) generateMethodDoc(meth *descriptor.Method) error {
300308
name := fg.resolveName(meth.RequestType.FQMN())
301309
resultRef := fmt.Sprintf("#/components/requestBodies/%s", name)
302310

303-
fg.doc.Components.RequestBodies[name] = &openapi3.RequestBodyRef{Value: openapi3.NewRequestBody().WithContent(
311+
fg.spec.Components.RequestBodies[name] = &openapi3.RequestBodyRef{Value: openapi3.NewRequestBody().WithContent(
304312
openapi3.NewContentWithJSONSchemaRef(messageSchema.NewRef()))}
305313

306314
requestBody = &openapi3.RequestBodyRef{Ref: resultRef}
307315
} else {
308316
requestBody = &openapi3.RequestBodyRef{Value: openapi3.NewRequestBody().
309-
WithJSONSchemaRef(fg.getMessageSchemaRef(meth.RequestType))}
317+
WithJSONSchemaRef(fg.generateMessageSchemaRef(meth.RequestType))}
310318
}
311319
}
312320
}
@@ -317,10 +325,10 @@ func (fg *fileGenerator) generateMethodDoc(meth *descriptor.Method) error {
317325
}
318326

319327
successResponseSchema := openapi3.NewResponse().
320-
WithJSONSchemaRef(fg.getMessageSchemaRef(meth.ResponseType))
328+
WithJSONSchemaRef(fg.generateMessageSchemaRef(meth.ResponseType))
321329

322330
defaultResponseSchema := openapi3.NewResponse().
323-
WithJSONSchemaRef(fg.getMessageSchemaRef(defaultResponse))
331+
WithJSONSchemaRef(fg.generateMessageSchemaRef(defaultResponse))
324332

325333
responses := openapi3.NewResponses(openapi3.WithStatus(200, &openapi3.ResponseRef{Value: successResponseSchema}),
326334
openapi3.WithName("default", defaultResponseSchema))
@@ -340,10 +348,10 @@ func (fg *fileGenerator) generateMethodDoc(meth *descriptor.Method) error {
340348
}
341349

342350
path := fg.convertPathTemplate(binding.PathTmpl.Template)
343-
pathItem := fg.doc.Paths.Find(path)
351+
pathItem := fg.spec.Paths.Find(path)
344352
if pathItem == nil {
345353
pathItem = &openapi3.PathItem{}
346-
fg.doc.Paths.Set(path, pathItem)
354+
fg.spec.Paths.Set(path, pathItem)
347355
}
348356

349357
switch binding.HTTPMethod {

protoc-gen-openapiv3/internal/genopenapiv3/generator.go

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
1010
gen "github.com/grpc-ecosystem/grpc-gateway/v2/internal/generator"
1111
"github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv3/options"
12+
"google.golang.org/genproto/googleapis/api/visibility"
1213
statuspb "google.golang.org/genproto/googleapis/rpc/status"
1314
"google.golang.org/protobuf/proto"
1415
"google.golang.org/protobuf/reflect/protodesc"
@@ -22,6 +23,7 @@ type generator struct {
2223
format Format
2324
}
2425

26+
2527
func NewGenerator(reg *descriptor.Registry, format Format) gen.Generator {
2628
return &generator{
2729
reg: reg,
@@ -38,8 +40,8 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
3840
respFiles := make([]*descriptor.ResponseFile, 0, len(targets))
3941
docs := make([]*openapi3.T, 0, len(targets))
4042
for _, t := range targets {
41-
fileGenerator := &fileGenerator{generator: g, doc: &openapi3.T{}}
42-
doc := fileGenerator.generateFileDoc(t)
43+
fileGenerator := &fileGenerator{generator: g, spec: &openapi3.T{}}
44+
doc := fileGenerator.generateFileSpec(t)
4345
docs = append(docs, doc)
4446

4547
contentBytes, err := g.format.MarshalOpenAPIDoc(doc)
@@ -80,6 +82,84 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
8082
return respFiles, nil
8183
}
8284

85+
func (g *generator) IsMessageVisible(msg *descriptor.Message) bool {
86+
if !proto.HasExtension(msg, visibility.E_MessageVisibility) {
87+
return true
88+
}
89+
90+
ext := proto.GetExtension(msg.DescriptorProto, visibility.E_MessageVisibility)
91+
visibilityOpt, ok := ext.(*visibility.VisibilityRule)
92+
if ok {
93+
return g.isVisible(visibilityOpt)
94+
}
95+
96+
return true
97+
}
98+
99+
func (g *generator) IsFieldVisible(field *descriptor.Field) bool {
100+
if !proto.HasExtension(field, visibility.E_MessageVisibility) {
101+
return true
102+
}
103+
104+
ext := proto.GetExtension(field.FieldDescriptorProto, visibility.E_MessageVisibility)
105+
visibilityOpt, ok := ext.(*visibility.VisibilityRule)
106+
if ok {
107+
return g.isVisible(visibilityOpt)
108+
}
109+
110+
return true
111+
}
112+
func (g *generator) IsEnumVisible(enum *descriptor.Enum) bool {
113+
114+
if !proto.HasExtension(enum, visibility.E_MessageVisibility) {
115+
return true
116+
}
117+
118+
ext := proto.GetExtension(enum.EnumDescriptorProto, visibility.E_EnumVisibility)
119+
visibilityOpt, ok := ext.(*visibility.VisibilityRule)
120+
if ok {
121+
return g.isVisible(visibilityOpt)
122+
}
123+
124+
return true
125+
}
126+
127+
func (g *generator) IsMethodVisible(meth *descriptor.Method) bool {
128+
if !proto.HasExtension(meth, visibility.E_MessageVisibility) {
129+
return true
130+
}
131+
132+
ext := proto.GetExtension(meth.MethodDescriptorProto, visibility.E_MethodVisibility)
133+
visibilityOpt, ok := ext.(*visibility.VisibilityRule)
134+
if ok {
135+
return g.isVisible(visibilityOpt)
136+
}
137+
138+
return true
139+
}
140+
141+
func (g *generator) IsServiceVisible(svc *descriptor.Service) bool {
142+
if !proto.HasExtension(svc, visibility.E_ApiVisibility) {
143+
return true
144+
}
145+
146+
ext := proto.GetExtension(svc.ServiceDescriptorProto, visibility.E_ApiVisibility)
147+
visibilityOpt, ok := ext.(*visibility.VisibilityRule)
148+
if ok {
149+
return g.isVisible(visibilityOpt)
150+
}
151+
152+
return true
153+
}
154+
155+
func (g *generator) isVisible(rule *visibility.VisibilityRule) bool {
156+
if rule == nil {
157+
return true
158+
}
159+
160+
return g.reg.GetVisibilityRestrictionSelectors()[rule.GetRestriction()]
161+
}
162+
83163
func (g *generator) getOperationName(serviceName, methodName string, bindingIdx int) string {
84164
if bindingIdx == 0 {
85165
return fmt.Sprintf("%s_%s", serviceName, methodName)

protoc-gen-openapiv3/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/codegenerator"
1010
"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
1111
"github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv3/internal/genopenapiv3"
12+
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
1213
"google.golang.org/grpc/grpclog"
1314
"google.golang.org/protobuf/proto"
1415
"google.golang.org/protobuf/types/pluginpb"
@@ -19,6 +20,7 @@ var (
1920
openAPIConfiguration = flag.String("openapi_configuration", "", "path to file which describes the OpenAPI Configuration in YAML format")
2021
oneOfStrategy = flag.String("oneof_strategy", "oneOf", "how to handle oneofs")
2122
outputFormat = flag.String("output_format", string(genopenapiv3.FormatJSON), fmt.Sprintf("output content format. Allowed values are: `%s`, `%s`", genopenapiv3.FormatJSON, genopenapiv3.FormatYAML))
23+
visibilityRestrictionSelectors = utilities.StringArrayFlag(flag.CommandLine, "visibility_restriction_selectors", "list of `google.api.VisibilityRule` visibility labels to include in the generated output when a visibility annotation is defined. Repeat this option to supply multiple values. Elements without visibility annotations are unaffected by this setting.")
2224
)
2325

2426
// Variables set by goreleaser at build time
@@ -66,6 +68,7 @@ func main() {
6668
}
6769

6870
reg.SetOneOfStrategy(*oneOfStrategy)
71+
reg.SetVisibilityRestrictionSelectors(*visibilityRestrictionSelectors)
6972

7073
if err := reg.Load(req); err != nil {
7174
emitError(err)

0 commit comments

Comments
 (0)