Skip to content

Commit 92f28cb

Browse files
authored
refactor: moves code related to AST from rule.utils into astutils package (#1380)
Modifications summary: * Moves AST-related functions from rule/utils.go to astutils/ast_utils.go (+ modifies function calls) * Renames some of these AST-related functions * Avoids instantiating a printer config at each call to astutils.GoFmt * Uses astutils.IsIdent and astutils.IsPkgDotName when possible
1 parent 87b146c commit 92f28cb

35 files changed

+164
-148
lines changed

internal/astutils/ast_utils.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
package astutils
33

44
import (
5+
"bytes"
6+
"fmt"
57
"go/ast"
8+
"go/printer"
69
"go/token"
10+
"regexp"
711
"slices"
812
)
913

@@ -78,9 +82,80 @@ func getFieldTypeName(typ ast.Expr) string {
7882
}
7983
}
8084

81-
// IsStringLiteral returns true if the given expression is a string literal, false otherwise
85+
// IsStringLiteral returns true if the given expression is a string literal, false otherwise.
8286
func IsStringLiteral(e ast.Expr) bool {
8387
sl, ok := e.(*ast.BasicLit)
8488

8589
return ok && sl.Kind == token.STRING
8690
}
91+
92+
// IsCgoExported returns true if the given function declaration is exported as Cgo function, false otherwise.
93+
func IsCgoExported(f *ast.FuncDecl) bool {
94+
if f.Recv != nil || f.Doc == nil {
95+
return false
96+
}
97+
98+
cgoExport := regexp.MustCompile(fmt.Sprintf("(?m)^//export %s$", regexp.QuoteMeta(f.Name.Name)))
99+
for _, c := range f.Doc.List {
100+
if cgoExport.MatchString(c.Text) {
101+
return true
102+
}
103+
}
104+
return false
105+
}
106+
107+
// IsIdent returns true if the given expression is the identifier with name ident, false otherwise.
108+
func IsIdent(expr ast.Expr, ident string) bool {
109+
id, ok := expr.(*ast.Ident)
110+
return ok && id.Name == ident
111+
}
112+
113+
// IsPkgDotName returns true if the given expression is a selector expression of the form <pkg>.<name>, false otherwise.
114+
func IsPkgDotName(expr ast.Expr, pkg, name string) bool {
115+
sel, ok := expr.(*ast.SelectorExpr)
116+
return ok && IsIdent(sel.X, pkg) && IsIdent(sel.Sel, name)
117+
}
118+
119+
// PickNodes yields a list of nodes by picking them from a sub-ast with root node n.
120+
// Nodes are selected by applying the selector function
121+
func PickNodes(n ast.Node, selector func(n ast.Node) bool) []ast.Node {
122+
var result []ast.Node
123+
124+
if n == nil {
125+
return result
126+
}
127+
128+
onSelect := func(n ast.Node) {
129+
result = append(result, n)
130+
}
131+
p := picker{selector: selector, onSelect: onSelect}
132+
ast.Walk(p, n)
133+
return result
134+
}
135+
136+
type picker struct {
137+
selector func(n ast.Node) bool
138+
onSelect func(n ast.Node)
139+
}
140+
141+
func (p picker) Visit(node ast.Node) ast.Visitor {
142+
if p.selector == nil {
143+
return nil
144+
}
145+
146+
if p.selector(node) {
147+
p.onSelect(node)
148+
}
149+
150+
return p
151+
}
152+
153+
var gofmtConfig = &printer.Config{Tabwidth: 8}
154+
155+
// GoFmt returns a string representation of an AST subtree.
156+
func GoFmt(x any) string {
157+
buf := bytes.Buffer{}
158+
fs := token.NewFileSet()
159+
gofmtConfig.Fprint(&buf, fs, x)
160+
return buf.String()
161+
}

rule/atomic.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"go/token"
66
"go/types"
77

8+
"github.com/mgechev/revive/internal/astutils"
89
"github.com/mgechev/revive/lint"
910
)
1011

@@ -76,9 +77,9 @@ func (w atomic) Visit(node ast.Node) ast.Visitor {
7677
broken := false
7778

7879
if uarg, ok := arg.(*ast.UnaryExpr); ok && uarg.Op == token.AND {
79-
broken = gofmt(left) == gofmt(uarg.X)
80+
broken = astutils.GoFmt(left) == astutils.GoFmt(uarg.X)
8081
} else if star, ok := left.(*ast.StarExpr); ok {
81-
broken = gofmt(star.X) == gofmt(arg)
82+
broken = astutils.GoFmt(star.X) == astutils.GoFmt(arg)
8283
}
8384

8485
if broken {

rule/confusing_naming.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strings"
77
"sync"
88

9+
"github.com/mgechev/revive/internal/astutils"
910
"github.com/mgechev/revive/lint"
1011
)
1112

@@ -190,7 +191,7 @@ func (w *lintConfusingNames) Visit(n ast.Node) ast.Visitor {
190191
// Exclude naming warnings for functions that are exported to C but
191192
// not exported in the Go API.
192193
// See https://github.com/golang/lint/issues/144.
193-
if ast.IsExported(v.Name.Name) || !isCgoExported(v) {
194+
if ast.IsExported(v.Name.Name) || !astutils.IsCgoExported(v) {
194195
checkMethodName(getStructName(v.Recv), v.Name, w)
195196
}
196197
case *ast.TypeSpec:

rule/confusing_results.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package rule
33
import (
44
"go/ast"
55

6+
"github.com/mgechev/revive/internal/astutils"
67
"github.com/mgechev/revive/lint"
78
)
89

@@ -28,7 +29,7 @@ func (*ConfusingResultsRule) Apply(file *lint.File, _ lint.Arguments) []lint.Fai
2829

2930
lastType := ""
3031
for _, result := range funcDecl.Type.Results.List {
31-
resultTypeName := gofmt(result.Type)
32+
resultTypeName := astutils.GoFmt(result.Type)
3233

3334
if resultTypeName == lastType {
3435
failures = append(failures, lint.Failure{

rule/constant_logical_expr.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"go/ast"
55
"go/token"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -40,7 +41,7 @@ func (w *lintConstantLogicalExpr) Visit(node ast.Node) ast.Visitor {
4041
return w
4142
}
4243

43-
subExpressionsAreNotEqual := gofmt(n.X) != gofmt(n.Y)
44+
subExpressionsAreNotEqual := astutils.GoFmt(n.X) != astutils.GoFmt(n.Y)
4445
if subExpressionsAreNotEqual {
4546
return w // nothing to say
4647
}

rule/context_as_argument.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"go/ast"
66
"strings"
77

8+
"github.com/mgechev/revive/internal/astutils"
89
"github.com/mgechev/revive/lint"
910
)
1011

@@ -28,7 +29,7 @@ func (r *ContextAsArgumentRule) Apply(file *lint.File, _ lint.Arguments) []lint.
2829
// Flag any that show up after the first.
2930
isCtxStillAllowed := true
3031
for _, arg := range fnArgs {
31-
argIsCtx := isPkgDot(arg.Type, "context", "Context")
32+
argIsCtx := astutils.IsPkgDotName(arg.Type, "context", "Context")
3233
if argIsCtx && !isCtxStillAllowed {
3334
failures = append(failures, lint.Failure{
3435
Node: arg,
@@ -40,7 +41,7 @@ func (r *ContextAsArgumentRule) Apply(file *lint.File, _ lint.Arguments) []lint.
4041
break // only flag one
4142
}
4243

43-
typeName := gofmt(arg.Type)
44+
typeName := astutils.GoFmt(arg.Type)
4445
// a parameter of type context.Context is still allowed if the current arg type is in the allow types LookUpTable
4546
_, isCtxStillAllowed = r.allowTypes[typeName]
4647
}

rule/context_keys_type.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"go/ast"
66
"go/types"
77

8+
"github.com/mgechev/revive/internal/astutils"
89
"github.com/mgechev/revive/lint"
910
)
1011

@@ -51,15 +52,7 @@ func (w lintContextKeyTypes) Visit(n ast.Node) ast.Visitor {
5152

5253
func checkContextKeyType(w lintContextKeyTypes, x *ast.CallExpr) {
5354
f := w.file
54-
sel, ok := x.Fun.(*ast.SelectorExpr)
55-
if !ok {
56-
return
57-
}
58-
pkg, ok := sel.X.(*ast.Ident)
59-
if !ok || pkg.Name != "context" {
60-
return
61-
}
62-
if sel.Sel.Name != "WithValue" {
55+
if !astutils.IsPkgDotName(x.Fun, "context", "WithValue") {
6356
return
6457
}
6558

rule/datarace.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"go/ast"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -111,7 +112,7 @@ func (w lintFunctionForDataRaces) Visit(node ast.Node) ast.Visitor {
111112
return ok
112113
}
113114

114-
ids := pick(funcLit.Body, selectIDs)
115+
ids := astutils.PickNodes(funcLit.Body, selectIDs)
115116
for _, id := range ids {
116117
id := id.(*ast.Ident)
117118
_, isRangeID := w.rangeIDs[id.Obj]

rule/defer.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"go/ast"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -106,7 +107,7 @@ func (w lintDeferRule) Visit(node ast.Node) ast.Visitor {
106107
w.newFailure("return in a defer function has no effect", n, 1.0, lint.FailureCategoryLogic, deferOptionReturn)
107108
}
108109
case *ast.CallExpr:
109-
isCallToRecover := isIdent(n.Fun, "recover")
110+
isCallToRecover := astutils.IsIdent(n.Fun, "recover")
110111
switch {
111112
case !w.inADefer && isCallToRecover:
112113
// func fn() { recover() }
@@ -122,7 +123,7 @@ func (w lintDeferRule) Visit(node ast.Node) ast.Visitor {
122123
}
123124
return nil // no need to analyze the arguments of the function call
124125
case *ast.DeferStmt:
125-
if isIdent(n.Call.Fun, "recover") {
126+
if astutils.IsIdent(n.Call.Fun, "recover") {
126127
// defer recover()
127128
//
128129
// confidence is not truly 1 because this could be in a correctly-deferred func,

rule/enforce_map_style.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"go/ast"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -101,8 +102,7 @@ func (r *EnforceMapStyleRule) Apply(file *lint.File, _ lint.Arguments) []lint.Fa
101102
return true
102103
}
103104

104-
ident, ok := v.Fun.(*ast.Ident)
105-
if !ok || ident.Name != "make" {
105+
if !astutils.IsIdent(v.Fun, "make") {
106106
return true
107107
}
108108

rule/enforce_repeated_arg_type_style.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"go/ast"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -130,8 +131,8 @@ func (r *EnforceRepeatedArgTypeStyleRule) Apply(file *lint.File, _ lint.Argument
130131
if fn.Type.Params != nil {
131132
var prevType ast.Expr
132133
for _, field := range fn.Type.Params.List {
133-
prevTypeStr := gofmt(prevType)
134-
currentTypeStr := gofmt(field.Type)
134+
prevTypeStr := astutils.GoFmt(prevType)
135+
currentTypeStr := astutils.GoFmt(field.Type)
135136
if currentTypeStr == prevTypeStr {
136137
failures = append(failures, lint.Failure{
137138
Confidence: 1,
@@ -163,8 +164,8 @@ func (r *EnforceRepeatedArgTypeStyleRule) Apply(file *lint.File, _ lint.Argument
163164
if fn.Type.Results != nil {
164165
var prevType ast.Expr
165166
for _, field := range fn.Type.Results.List {
166-
prevTypeStr := gofmt(prevType)
167-
currentTypeStr := gofmt(field.Type)
167+
prevTypeStr := astutils.GoFmt(prevType)
168+
currentTypeStr := astutils.GoFmt(field.Type)
168169
if field.Names != nil && currentTypeStr == prevTypeStr {
169170
failures = append(failures, lint.Failure{
170171
Confidence: 1,

rule/enforce_slice_style.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"go/ast"
66

7+
"github.com/mgechev/revive/internal/astutils"
78
"github.com/mgechev/revive/lint"
89
)
910

@@ -117,8 +118,7 @@ func (r *EnforceSliceStyleRule) Apply(file *lint.File, _ lint.Arguments) []lint.
117118
return true
118119
}
119120

120-
ident, ok := v.Fun.(*ast.Ident)
121-
if !ok || ident.Name != "make" {
121+
if !astutils.IsIdent(v.Fun, "make") {
122122
return true
123123
}
124124

rule/error_naming.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"go/token"
77
"strings"
88

9+
"github.com/mgechev/revive/internal/astutils"
910
"github.com/mgechev/revive/lint"
1011
)
1112

@@ -56,7 +57,7 @@ func (w lintErrors) Visit(_ ast.Node) ast.Visitor {
5657
if !ok {
5758
continue
5859
}
59-
if !isPkgDot(ce.Fun, "errors", "New") && !isPkgDot(ce.Fun, "fmt", "Errorf") {
60+
if !astutils.IsPkgDotName(ce.Fun, "errors", "New") && !astutils.IsPkgDotName(ce.Fun, "fmt", "Errorf") {
6061
continue
6162
}
6263

rule/error_return.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package rule
33
import (
44
"go/ast"
55

6+
"github.com/mgechev/revive/internal/astutils"
67
"github.com/mgechev/revive/lint"
78
)
89

@@ -21,15 +22,15 @@ func (*ErrorReturnRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure
2122
}
2223

2324
funcResults := funcDecl.Type.Results.List
24-
isLastResultError := isIdent(funcResults[len(funcResults)-1].Type, "error")
25+
isLastResultError := astutils.IsIdent(funcResults[len(funcResults)-1].Type, "error")
2526
if isLastResultError {
2627
continue
2728
}
2829

2930
// An error return parameter should be the last parameter.
3031
// Flag any error parameters found before the last.
3132
for _, r := range funcResults[:len(funcResults)-1] {
32-
if isIdent(r.Type, "error") {
33+
if astutils.IsIdent(r.Type, "error") {
3334
failures = append(failures, lint.Failure{
3435
Category: lint.FailureCategoryStyle,
3536
Confidence: 0.9,

rule/errorf.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"regexp"
77
"strings"
88

9+
"github.com/mgechev/revive/internal/astutils"
910
"github.com/mgechev/revive/lint"
1011
)
1112

@@ -47,7 +48,7 @@ func (w lintErrorf) Visit(n ast.Node) ast.Visitor {
4748
if !ok || len(ce.Args) != 1 {
4849
return w
4950
}
50-
isErrorsNew := isPkgDot(ce.Fun, "errors", "New")
51+
isErrorsNew := astutils.IsPkgDotName(ce.Fun, "errors", "New")
5152
var isTestingError bool
5253
se, ok := ce.Fun.(*ast.SelectorExpr)
5354
if ok && se.Sel.Name == "Error" {
@@ -60,7 +61,7 @@ func (w lintErrorf) Visit(n ast.Node) ast.Visitor {
6061
}
6162
arg := ce.Args[0]
6263
ce, ok = arg.(*ast.CallExpr)
63-
if !ok || !isPkgDot(ce.Fun, "fmt", "Sprintf") {
64+
if !ok || !astutils.IsPkgDotName(ce.Fun, "fmt", "Sprintf") {
6465
return w
6566
}
6667
errorfPrefix := "fmt"

0 commit comments

Comments
 (0)