diff --git a/field_parser_v3_test.go b/field_parser_v3_test.go index 593238c79..85176ec92 100644 --- a/field_parser_v3_test.go +++ b/field_parser_v3_test.go @@ -184,7 +184,7 @@ func TestDefaultFieldParserV3(t *testing.T) { t.Parallel() schema := spec.NewSchemaSpec() - schema.Spec.Type = []string{"string"} + schema.Spec.Type = &spec.SingleOrArray[string]{"string"} parser := &Parser{} fieldParser := newTagBaseFieldParserV3( parser, diff --git a/parser.go b/parser.go index b7c81fd6b..f327441bd 100644 --- a/parser.go +++ b/parser.go @@ -1395,8 +1395,13 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) if parser.isInStructStack(typeSpecDef) { parser.debug.Printf("Skipping '%s', recursion detected.", typeName) + schemaName := typeName + if typeSpecDef.SchemaName != "" { + schemaName = typeSpecDef.SchemaName + } + return &Schema{ - Name: typeName, + Name: schemaName, PkgPath: typeSpecDef.PkgPath, Schema: PrimitiveSchema(OBJECT), }, diff --git a/parserv3.go b/parserv3.go index 256153da9..0d98a15e7 100644 --- a/parserv3.go +++ b/parserv3.go @@ -702,12 +702,25 @@ func (p *Parser) ParseDefinitionV3(typeSpecDef *TypeSpecDef) (*SchemaV3, error) if p.isInStructStack(typeSpecDef) { p.debug.Printf("Skipping '%s', recursion detected.", typeName) - return &SchemaV3{ - Name: typeName, - PkgPath: typeSpecDef.PkgPath, - Schema: PrimitiveSchemaV3(OBJECT).Spec, - }, - ErrRecursiveParseStruct + schemaName := typeName + if typeSpecDef.SchemaName != "" { + schemaName = typeSpecDef.SchemaName + } + + schema := &SchemaV3{ + Name: schemaName, + PkgPath: typeSpecDef.PkgPath, + Schema: PrimitiveSchemaV3(OBJECT).Spec, + } + + p.parsedSchemasV3[typeSpecDef] = schema + + if p.openAPI.Components.Spec.Schemas == nil { + p.openAPI.Components.Spec.Schemas = make(map[string]*spec.RefOrSpec[spec.Schema]) + } + p.openAPI.Components.Spec.Schemas[schema.Name] = spec.NewRefOrSpec(nil, schema.Schema) + + return schema, ErrRecursiveParseStruct } p.structStack = append(p.structStack, typeSpecDef) @@ -1086,5 +1099,11 @@ func (p *Parser) GetSchemaTypePathV3(schema *spec.RefOrSpec[spec.Schema], depth func (p *Parser) getSchemaByRef(ref *spec.Ref) *spec.Schema { searchString := strings.ReplaceAll(ref.Ref, "#/components/schemas/", "") - return p.openAPI.Components.Spec.Schemas[searchString].Spec + schemaRef, exists := p.openAPI.Components.Spec.Schemas[searchString] + if !exists || schemaRef == nil { + println(fmt.Sprintf("Schema not found for ref: %s, returning any", ref.Ref)) + return &spec.Schema{} // return empty schema if not found + } + + return schemaRef.Spec } diff --git a/parserv3_test.go b/parserv3_test.go index f6c1a8771..130fdbd4e 100644 --- a/parserv3_test.go +++ b/parserv3_test.go @@ -515,3 +515,59 @@ func TestParseTypeAlias(t *testing.T) { assert.JSONEq(t, string(expected), string(result)) } + +func TestParseRecursionWithSchemaName(t *testing.T) { + t.Parallel() + + searchDir := "testdata/recursion_schema_name" + p := New(GenerateOpenAPI3Doc(true)) + + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + require.NoError(t, err) + + userSchema, exists := p.openAPI.Components.Spec.Schemas["User"] + require.True(t, exists, "User schema should exist") + require.NotNil(t, userSchema, "User schema should not be nil") + require.NotNil(t, userSchema.Spec, "User schema spec should not be nil") + + assert.Equal(t, "object", (*userSchema.Spec.Type)[0]) + + childrenProp, exists := userSchema.Spec.Properties["children"] + require.True(t, exists, "children property should exist") + require.NotNil(t, childrenProp.Spec, "children property spec should not be nil") + + assert.Equal(t, "array", (*childrenProp.Spec.Type)[0]) + + require.NotNil(t, childrenProp.Spec.Items, "children items should not be nil") + require.NotNil(t, childrenProp.Spec.Items.Schema, "children items schema should not be nil") + + expectedRef := "#/components/schemas/User" + assert.Equal(t, expectedRef, childrenProp.Spec.Items.Schema.Ref.Ref) +} + +func TestGetSchemaByRef(t *testing.T) { + t.Parallel() + + p := New(GenerateOpenAPI3Doc(true)) + p.openAPI.Components.Spec.Schemas = make(map[string]*spec.RefOrSpec[spec.Schema]) + + t.Run("Existing schema", func(t *testing.T) { + testSchema := &spec.Schema{} + testSchema.Type = &spec.SingleOrArray[string]{"string"} + p.openAPI.Components.Spec.Schemas["TestSchema"] = spec.NewRefOrSpec(nil, testSchema) + + ref := &spec.Ref{Ref: "#/components/schemas/TestSchema"} + result := p.getSchemaByRef(ref) + + require.NotNil(t, result) + assert.Equal(t, testSchema, result) + }) + + t.Run("Non-existing schema returns empty schema", func(t *testing.T) { + ref := &spec.Ref{Ref: "#/components/schemas/NonExistentSchema"} + result := p.getSchemaByRef(ref) + + require.NotNil(t, result) + assert.Equal(t, &spec.Schema{}, result) + }) +} diff --git a/testdata/recursion_schema_name/main.go b/testdata/recursion_schema_name/main.go new file mode 100644 index 000000000..406a921bc --- /dev/null +++ b/testdata/recursion_schema_name/main.go @@ -0,0 +1,24 @@ +package main + +// User represents a user with self-references +type User struct { + ID int `json:"id"` + Name string `json:"name"` + Children []*User `json:"children,omitempty"` +} // @name User + +// @title Test API +// @version 1.0 +// @description Test API for recursion with schema name +// @BasePath / +func main() {} + +// GetUser returns a user +// @Summary Get user +// @Description Get user by ID +// @Tags users +// @Accept json +// @Produce json +// @Success 200 {object} User +// @Router /user [get] +func GetUser() {}