Skip to content

Commit 67d0a61

Browse files
authored
fix: error-strings custom funcs overwrites defaults (#1249)
1 parent 9177f50 commit 67d0a61

File tree

4 files changed

+98
-7
lines changed

4 files changed

+98
-7
lines changed

rule/error_strings.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@ func (r *ErrorStringsRule) Configure(arguments lint.Arguments) error {
3737

3838
var invalidCustomFunctions []string
3939
for _, argument := range arguments {
40-
if functionName, ok := argument.(string); ok {
41-
fields := strings.Split(strings.TrimSpace(functionName), ".")
42-
if len(fields) != 2 || len(fields[0]) == 0 || len(fields[1]) == 0 {
43-
invalidCustomFunctions = append(invalidCustomFunctions, functionName)
44-
continue
45-
}
46-
r.errorFunctions[fields[0]] = map[string]struct{}{fields[1]: {}}
40+
pkgFunction, ok := argument.(string)
41+
if !ok {
42+
continue
4743
}
44+
pkg, function, ok := strings.Cut(strings.TrimSpace(pkgFunction), ".")
45+
if !ok || pkg == "" || function == "" {
46+
invalidCustomFunctions = append(invalidCustomFunctions, pkgFunction)
47+
continue
48+
}
49+
if _, ok := r.errorFunctions[pkg]; !ok {
50+
r.errorFunctions[pkg] = map[string]struct{}{}
51+
}
52+
r.errorFunctions[pkg][function] = struct{}{}
4853
}
4954
if len(invalidCustomFunctions) != 0 {
5055
return fmt.Errorf("found invalid custom function: %s", strings.Join(invalidCustomFunctions, ","))

rule/error_strings_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package rule_test
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/mgechev/revive/lint"
8+
"github.com/mgechev/revive/rule"
9+
)
10+
11+
func TestErrorStringsRule_Configure(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
arguments lint.Arguments
15+
wantErr error
16+
}{
17+
{
18+
name: "Default configuration",
19+
arguments: lint.Arguments{},
20+
},
21+
{
22+
name: "Valid custom functions",
23+
arguments: lint.Arguments{"mypkg.MyErrorFunc", "errors.New"},
24+
},
25+
{
26+
name: "Argument not a string",
27+
arguments: lint.Arguments{123},
28+
},
29+
{
30+
name: "Invalid package",
31+
arguments: lint.Arguments{".MyErrorFunc"},
32+
wantErr: errors.New("found invalid custom function: .MyErrorFunc"),
33+
},
34+
{
35+
name: "Invalid function",
36+
arguments: lint.Arguments{"errors."},
37+
wantErr: errors.New("found invalid custom function: errors."),
38+
},
39+
{
40+
name: "Invalid custom function",
41+
arguments: lint.Arguments{"invalidFunction"},
42+
wantErr: errors.New("found invalid custom function: invalidFunction"),
43+
},
44+
{
45+
name: "Mixed valid and invalid custom functions",
46+
arguments: lint.Arguments{"mypkg.MyErrorFunc", "invalidFunction", "invalidFunction2"},
47+
wantErr: errors.New("found invalid custom function: invalidFunction,invalidFunction2"),
48+
},
49+
}
50+
51+
for _, tt := range tests {
52+
t.Run(tt.name, func(t *testing.T) {
53+
var r rule.ErrorStringsRule
54+
55+
err := r.Configure(tt.arguments)
56+
57+
if tt.wantErr == nil {
58+
if err != nil {
59+
t.Errorf("Configure() unexpected non-nil error %q", err)
60+
}
61+
return
62+
}
63+
if err == nil || err.Error() != tt.wantErr.Error() {
64+
t.Errorf("Configure() unexpected error: got %q, want %q", err, tt.wantErr)
65+
}
66+
})
67+
}
68+
}

test/error_strings_custom_functions_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ func TestErrorStringsWithCustomFunctions(t *testing.T) {
1313
Arguments: args,
1414
})
1515
}
16+
17+
func TestErrorStringsIssue1243(t *testing.T) {
18+
args := []any{"errors.Wrap"}
19+
testRule(t, "error_strings_issue_1243", &rule.ErrorStringsRule{}, &lint.RuleConfig{
20+
Arguments: args,
21+
})
22+
}

testdata/error_strings_issue_1243.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package fixtures
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
)
7+
8+
func issue1243() {
9+
err := errors.New("An error occurred!") // MATCH /error strings should not be capitalized or end with punctuation or a newline/
10+
fmt.Println(err)
11+
}

0 commit comments

Comments
 (0)