Skip to content

Commit 57a4260

Browse files
authored
Add typed nil validation to dsl.Security (#3574)
* Add a test case for eval.InvalidArgError() * Add typed nil validation to dsl.Security * Remove unnecessary blocks
1 parent f0108a7 commit 57a4260

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

dsl/security.go

+25-26
Original file line numberDiff line numberDiff line change
@@ -228,35 +228,34 @@ func JWTSecurity(name string, fn ...func()) *expr.SchemeExpr {
228228
// })
229229
func Security(args ...any) {
230230
var dsl func()
231-
{
232-
if d, ok := args[len(args)-1].(func()); ok {
233-
args = args[:len(args)-1]
234-
dsl = d
235-
}
236-
}
237-
238-
var schemes []*expr.SchemeExpr
239-
{
240-
schemes = make([]*expr.SchemeExpr, len(args))
241-
for i, arg := range args {
242-
switch val := arg.(type) {
243-
case string:
244-
for _, s := range expr.Root.Schemes {
245-
if s.SchemeName == val {
246-
schemes[i] = expr.DupScheme(s)
247-
break
248-
}
249-
}
250-
if schemes[i] == nil {
251-
eval.ReportError("security scheme %q not found", val)
252-
return
231+
if d, ok := args[len(args)-1].(func()); ok {
232+
args = args[:len(args)-1]
233+
dsl = d
234+
}
235+
236+
schemes := make([]*expr.SchemeExpr, len(args))
237+
for i, arg := range args {
238+
switch val := arg.(type) {
239+
case string:
240+
for _, s := range expr.Root.Schemes {
241+
if s.SchemeName == val {
242+
schemes[i] = expr.DupScheme(s)
243+
break
253244
}
254-
case *expr.SchemeExpr:
255-
schemes[i] = expr.DupScheme(val)
256-
default:
257-
eval.InvalidArgError("security scheme or security scheme name", val)
245+
}
246+
if schemes[i] == nil {
247+
eval.ReportError("security scheme %q not found", val)
248+
return
249+
}
250+
case *expr.SchemeExpr:
251+
if val == nil {
252+
eval.InvalidArgError("security scheme", val)
258253
return
259254
}
255+
schemes[i] = expr.DupScheme(val)
256+
default:
257+
eval.InvalidArgError("security scheme or security scheme name", val)
258+
return
260259
}
261260
}
262261

eval/eval_test.go

+12-11
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ func TestInvalidArgError(t *testing.T) {
1414
dsl func()
1515
want string
1616
}{
17-
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
18-
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
19-
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
20-
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
21-
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
22-
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
23-
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
24-
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
25-
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
26-
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
27-
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
17+
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
18+
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
19+
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
20+
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
21+
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
22+
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
23+
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
24+
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
25+
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
26+
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
27+
"Security (typed nil)": {func() { Security((*expr.SchemeExpr)(nil)) }, "cannot use (*expr.SchemeExpr)(nil) (type *expr.SchemeExpr) as type security scheme"},
28+
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
2829
}
2930
for name, tc := range dsls {
3031
t.Run(name, func(t *testing.T) {

0 commit comments

Comments
 (0)