Skip to content

Commit 4d06dd6

Browse files
authored
Refactor how generated result types are handled (#3564)
* Refactor how generated result types are handled Make the generated result type root a global variable similar to the expression root. Remove the dependencies on the go-diff package, use testify instead. * Fix linter issues
1 parent c1a4639 commit 4d06dd6

31 files changed

+222
-359
lines changed

codegen/doc.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ In particular package codegen defines the data structure that represents a
66
generated file (see File) which is composed of sections, each corresponding to a
77
Go text template and accompanying data used to render the final code.
88
9-
THe package also include functions that can generate code that transforms a
9+
The package also includes functions that generate code to transform a
1010
given type into another (see GoTransform).
1111
*/
1212
package codegen

codegen/go_transform_helpers_test.go

+14-24
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package codegen
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
69
"goa.design/goa/v3/codegen/testdata"
710
"goa.design/goa/v3/expr"
811
)
@@ -29,40 +32,27 @@ func TestGoTransformHelpers(t *testing.T) {
2932
Type expr.DataType
3033
HelperNames []string
3134
}{
32-
{"simple", simple, []string{}},
35+
{"simple", simple, nil},
3336
{"recursive", recursive, []string{"transformRecursiveToRecursive"}},
3437
{"composite", composite, []string{"transformSimpleToSimple"}},
3538
{"deep", deep, []string{"transformCompositeToComposite", "transformSimpleToSimple"}},
3639
{"deep-array", deepArray, []string{"transformCompositeToComposite", "transformSimpleToSimple"}},
37-
{"simple-alias", simpleAlias, []string{}},
38-
{"nested-map-alias", mapAlias, []string{}},
39-
{"array-map-alias", arrayMapAlias, []string{}},
40+
{"simple-alias", simpleAlias, nil},
41+
{"nested-map-alias", mapAlias, nil},
42+
{"array-map-alias", arrayMapAlias, nil},
4043
{"result-type-collection", collection, []string{"transformResultTypeToResultType"}},
4144
}
4245
for _, c := range tc {
4346
t.Run(c.Name, func(t *testing.T) {
44-
if c.Type == nil {
45-
t.Fatal("source type not found in testdata")
46-
}
47+
require.NotNil(t, c.Type, "source type not found in testdata")
4748
_, funcs, err := GoTransform(&expr.AttributeExpr{Type: c.Type}, &expr.AttributeExpr{Type: c.Type}, "source", "target", defaultCtx, defaultCtx, "", true)
48-
if err != nil {
49-
t.Fatal(err)
50-
}
51-
if len(funcs) != len(c.HelperNames) {
52-
t.Errorf("invalid helpers count, got: %d, expected %d", len(funcs), len(c.HelperNames))
53-
} else {
54-
var diffs []string
55-
actual := make([]string, len(funcs))
56-
for i, f := range funcs {
57-
actual[i] = f.Name
58-
if c.HelperNames[i] != f.Name {
59-
diffs = append(diffs, f.Name)
60-
}
61-
}
62-
if len(diffs) > 0 {
63-
t.Errorf("invalid helper names, got: %v, expected: %v", actual, c.HelperNames)
64-
}
49+
require.NoError(t, err)
50+
assert.Equal(t, len(c.HelperNames), len(funcs), "invalid helpers count")
51+
var actual []string
52+
for _, f := range funcs {
53+
actual = append(actual, f.Name)
6554
}
55+
assert.Equal(t, c.HelperNames, actual, "invalid helper names")
6656
})
6757
}
6858
}

codegen/go_transform_test.go

+7-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package codegen
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
69
"goa.design/goa/v3/codegen/testdata"
710
"goa.design/goa/v3/expr"
811
)
@@ -182,20 +185,12 @@ func TestGoTransform(t *testing.T) {
182185
t.Run(name, func(t *testing.T) {
183186
for _, c := range cases {
184187
t.Run(c.Name, func(t *testing.T) {
185-
if c.Source == nil {
186-
t.Fatal("source type not found in testdata")
187-
}
188-
if c.Target == nil {
189-
t.Fatal("target type not found in testdata")
190-
}
188+
require.NotNil(t, c.Source)
189+
require.NotNil(t, c.Target)
191190
code, _, err := GoTransform(&expr.AttributeExpr{Type: c.Source}, &expr.AttributeExpr{Type: c.Target}, "source", "target", c.SourceCtx, c.TargetCtx, "", true)
192-
if err != nil {
193-
t.Fatal(err)
194-
}
191+
require.NoError(t, err)
195192
code = FormatTestCode(t, "package foo\nfunc transform(){\n"+code+"}")
196-
if code != c.Code {
197-
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Code))
198-
}
193+
assert.Equal(t, c.Code, code)
199194
})
200195
}
201196
})

codegen/go_transform_union_test.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package codegen
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
69
"goa.design/goa/v3/codegen/testdata"
710
"goa.design/goa/v3/expr"
811
)
@@ -41,14 +44,9 @@ func TestGoTransformUnion(t *testing.T) {
4144
for _, c := range tc {
4245
t.Run(c.Name, func(t *testing.T) {
4346
code, _, err := GoTransform(c.Source, c.Target, "source", "target", defaultCtx, defaultCtx, "", true)
44-
if err != nil {
45-
t.Errorf("unexpected error %s", err)
46-
return
47-
}
47+
require.NoError(t, err)
4848
code = FormatTestCode(t, "package foo\nfunc transform(){\n"+code+"}")
49-
if code != c.Expected {
50-
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Expected))
51-
}
49+
assert.Equal(t, c.Expected, code)
5250
})
5351
}
5452
}

codegen/service/convert_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,17 @@ func runDSL(t *testing.T, dsl func()) *expr.RootExpr {
261261
Services = make(ServicesData)
262262
eval.Reset()
263263
expr.Root = new(expr.RootExpr)
264-
err := eval.Register(expr.Root)
265-
require.NoError(t, err)
264+
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
265+
require.NoError(t, eval.Register(expr.Root))
266+
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
266267
expr.Root.API = expr.NewAPIExpr("test api", func() {})
267268
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}
268269

269270
// run DSL (first pass)
270271
require.True(t, eval.Execute(dsl, nil))
271272

272273
// run DSL (second pass)
273-
err = eval.RunDSL()
274-
require.NoError(t, err)
274+
require.NoError(t, eval.RunDSL())
275275

276276
// return generated root
277277
return expr.Root

codegen/service/endpoint_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ func TestEndpoint(t *testing.T) {
3838
for _, c := range cases {
3939
t.Run(c.Name, func(t *testing.T) {
4040
codegen.RunDSL(t, c.DSL)
41-
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
4241
require.Len(t, expr.Root.Services, 1)
4342
fs := EndpointFile("goa.design/goa/example", expr.Root.Services[0])
4443
require.NotNil(t, fs)

codegen/service/example_svc_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ func TestExampleServiceFiles(t *testing.T) {
3333
for _, c := range cases {
3434
t.Run(c.Name, func(t *testing.T) {
3535
codegen.RunDSL(t, c.DSL)
36-
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
3736
require.Len(t, expr.Root.Services, 3)
3837
fs := ExampleServiceFiles("", expr.Root)
3938
require.Len(t, fs, 3)

codegen/testing.go

+13-56
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,26 @@ import (
44
"bytes"
55
"fmt"
66
"os"
7-
"os/exec"
87
"strings"
98
"testing"
109

10+
"github.com/stretchr/testify/require"
1111
"goa.design/goa/v3/eval"
1212
"goa.design/goa/v3/expr"
13-
14-
"github.com/sergi/go-diff/diffmatchpatch"
1513
)
1614

1715
// RunDSL returns the DSL root resulting from running the given DSL.
1816
func RunDSL(t *testing.T, dsl func()) *expr.RootExpr {
1917
t.Helper()
2018
eval.Reset()
2119
expr.Root = new(expr.RootExpr)
22-
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
23-
if err := eval.Register(expr.Root); err != nil {
24-
t.Fatal(err)
25-
}
26-
if err := eval.Register(expr.Root.GeneratedTypes); err != nil {
27-
t.Fatal(err)
28-
}
20+
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
21+
require.NoError(t, eval.Register(expr.Root))
22+
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
2923
expr.Root.API = expr.NewAPIExpr("test api", func() {})
3024
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}
31-
if !eval.Execute(dsl, nil) {
32-
t.Fatal(eval.Context.Error())
33-
}
34-
if err := eval.RunDSL(); err != nil {
35-
t.Fatal(err)
36-
}
25+
require.True(t, eval.Execute(dsl, nil), eval.Context.Error())
26+
require.NoError(t, eval.RunDSL())
3727
return expr.Root
3828
}
3929

@@ -52,22 +42,16 @@ func SectionsCode(t *testing.T, sections []*SectionTemplate) string {
5242
}
5343

5444
// SectionCodeFromImportsAndMethods generates and formats the code for given import and method definition sections.
55-
func SectionCodeFromImportsAndMethods(t *testing.T, importSection *SectionTemplate, methodSection *SectionTemplate) string {
45+
func SectionCodeFromImportsAndMethods(t *testing.T, importSection, methodSection *SectionTemplate) string {
5646
t.Helper()
5747
var code bytes.Buffer
58-
if err := importSection.Write(&code); err != nil {
59-
t.Fatal(err)
60-
}
61-
48+
require.NoError(t, importSection.Write(&code))
6249
return sectionCodeWithPrefix(t, methodSection, code.String())
6350
}
6451

6552
func sectionCodeWithPrefix(t *testing.T, section *SectionTemplate, prefix string) string {
6653
var code bytes.Buffer
67-
if err := section.Write(&code); err != nil {
68-
t.Fatal(err)
69-
}
70-
54+
require.NoError(t, section.Write(&code))
7155
codestr := code.String()
7256

7357
if len(prefix) > 0 {
@@ -83,50 +67,23 @@ func FormatTestCode(t *testing.T, code string) string {
8367
t.Helper()
8468
tmp := CreateTempFile(t, code)
8569
defer os.Remove(tmp)
86-
if err := finalizeGoSource(tmp); err != nil {
87-
t.Fatal(err)
88-
}
70+
require.NoError(t, finalizeGoSource(tmp))
8971
content, err := os.ReadFile(tmp)
90-
if err != nil {
91-
t.Fatal(err)
92-
}
72+
require.NoError(t, err)
9373
return strings.Join(strings.Split(string(content), "\n")[2:], "\n")
9474
}
9575

96-
// Diff returns a diff between s1 and s2. It uses the diff tool if installed
97-
// otherwise degrades to using the dmp package.
98-
func Diff(t *testing.T, s1, s2 string) string {
99-
_, err := exec.LookPath("diff")
100-
supportsDiff := (err == nil)
101-
if !supportsDiff {
102-
dmp := diffmatchpatch.New()
103-
diffs := dmp.DiffMain(s1, s2, false)
104-
return dmp.DiffPrettyText(diffs)
105-
}
106-
left := CreateTempFile(t, s1)
107-
right := CreateTempFile(t, s2)
108-
defer os.Remove(left)
109-
defer os.Remove(right)
110-
cmd := exec.Command("diff", left, right)
111-
diffb, _ := cmd.CombinedOutput()
112-
return strings.ReplaceAll(string(diffb), "\t", " ␉ ")
113-
}
114-
11576
// CreateTempFile creates a temporary file and writes the given content.
11677
// It is used only for testing.
11778
func CreateTempFile(t *testing.T, content string) string {
11879
t.Helper()
11980
f, err := os.CreateTemp("", "")
120-
if err != nil {
121-
t.Fatal(err)
122-
}
81+
require.NoError(t, err)
12382
_, err = f.WriteString(content)
12483
if err != nil {
12584
os.Remove(f.Name())
12685
t.Fatal(err)
12786
}
128-
if err := f.Close(); err != nil {
129-
t.Fatal(err)
130-
}
87+
require.NoError(t, f.Close())
13188
return f.Name()
13289
}

codegen/validation_test.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package codegen
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
7+
68
"goa.design/goa/v3/codegen/testdata"
79
"goa.design/goa/v3/expr"
810
)
@@ -66,18 +68,14 @@ func TestRecursiveValidationCode(t *testing.T) {
6668
ctx := NewAttributeContext(c.Pointer, false, c.UseDefault, "", scope)
6769
code := ValidationCode(&expr.AttributeExpr{Type: c.Type}, nil, ctx, c.Required, expr.IsAlias(c.Type), false, "target")
6870
code = FormatTestCode(t, "package foo\nfunc Validate() (err error){\n"+code+"}")
69-
if code != c.Code {
70-
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Code))
71-
}
71+
assert.Equal(t, c.Code, code)
7272
})
7373
}
7474
// Special case of unions with views
7575
t.Run("union-with-view", func(t *testing.T) {
7676
ctx := NewAttributeContext(false, false, false, "", scope)
7777
code := ValidationCode(&expr.AttributeExpr{Type: unionT}, nil, ctx, true, false, true, "target")
7878
code = FormatTestCode(t, "package foo\nfunc Validate() (err error){\n"+code+"}")
79-
if code != testdata.UnionWithViewValidationCode {
80-
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, testdata.UnionWithViewValidationCode))
81-
}
79+
assert.Equal(t, testdata.UnionWithViewValidationCode, code)
8280
})
8381
}

dsl/http.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ func Body(args ...any) {
962962
if rt, ok := attr.Type.(*expr.ResultTypeExpr); ok && expr.IsArray(rt.Type) {
963963
// If the attribute type is a result type collection add the type to the
964964
// GeneratedTypes so that the type's DSLFunc is executed.
965-
*expr.Root.GeneratedTypes = append(*expr.Root.GeneratedTypes, rt)
965+
expr.GeneratedResultTypes.Append(rt)
966966
}
967967
if len(args) > 1 {
968968
var ok bool

dsl/result_type.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,19 @@ func ResultType(identifier string, args ...any) *expr.ResultTypeExpr {
106106
}
107107
canonicalID := expr.CanonicalIdentifier(identifier)
108108
// Validate that result type identifier doesn't clash
109-
for _, rt := range expr.Root.ResultTypes {
110-
if re := rt.(*expr.ResultTypeExpr); re.Identifier == canonicalID {
109+
for _, rt := range *expr.GeneratedResultTypes {
110+
if rt.Identifier == canonicalID {
111111
eval.ReportError(
112112
"result type %#v with canonical identifier %#v is defined twice",
113113
identifier, canonicalID)
114114
return nil
115115
}
116116
}
117-
// Now save the type in the API result types map
118-
mt := expr.NewResultTypeExpr(typeName, identifier, fn)
119-
expr.Root.ResultTypes = append(expr.Root.ResultTypes, mt)
117+
// Add the type to the generated types root for later evaluation.
118+
rt := expr.NewResultTypeExpr(typeName, identifier, fn)
119+
expr.Root.ResultTypes = append(expr.Root.ResultTypes, rt)
120120

121-
return mt
121+
return rt
122122
}
123123

124124
// TypeName makes it possible to set the Go struct name for a type or result
@@ -201,7 +201,7 @@ func View(name string, adsl ...func()) {
201201
switch e := eval.Current().(type) {
202202
case *expr.ResultTypeExpr:
203203
if e.View(name) != nil {
204-
eval.ReportError("multiple expressions for view %#v in result type %#v", name, e.TypeName)
204+
eval.ReportError("view %q is defined multiple times in result type %q", name, e.TypeName)
205205
return
206206
}
207207
at := &expr.AttributeExpr{}
@@ -340,11 +340,11 @@ func CollectionOf(v any, adsl ...func()) *expr.ResultTypeExpr {
340340
}
341341
id = mime.FormatMediaType(rtype, params)
342342
canonical := expr.CanonicalIdentifier(id)
343-
if mt := expr.Root.GeneratedResultType(canonical); mt != nil {
343+
if mt := expr.GeneratedResultType(canonical); mt != nil {
344344
// Already have a type for this collection, reuse it.
345345
return mt
346346
}
347-
mt := expr.NewResultTypeExpr("", id, func() {
347+
rt := expr.NewResultTypeExpr("", id, func() {
348348
rt, ok := eval.Current().(*expr.ResultTypeExpr)
349349
if !ok {
350350
eval.IncompatibleDSL()
@@ -371,8 +371,8 @@ func CollectionOf(v any, adsl ...func()) *expr.ResultTypeExpr {
371371
})
372372
// do not execute the DSL right away, will be done last to make sure
373373
// the element DSL has run first.
374-
*expr.Root.GeneratedTypes = append(*expr.Root.GeneratedTypes, mt)
375-
return mt
374+
expr.GeneratedResultTypes.Append(rt)
375+
return rt
376376
}
377377

378378
// Reference sets a type or result type reference. The value itself can be a

0 commit comments

Comments
 (0)