Skip to content

Commit abe5874

Browse files
martskinsfindleyr
authored andcommitted
gopls/internal/analysis: add fill switch cases code action
This PR adds a code action to fill missing cases on type switches and switches on named types. Rules are defined here: golang/go#65411 (comment). Edit: I added some tests, but I'm sure there are still things to fix so sharing to get some feedback. Fixes golang/go#65411 https://github.com/golang/tools/assets/4250565/1e67c404-e24f-478e-a3df-60a3adfaa9b1 Change-Id: Ie4ef0955d0e7ca130af8980a488b738c812aae4d GitHub-Last-Rev: a04dc69 GitHub-Pull-Request: #476 Reviewed-on: https://go-review.googlesource.com/c/tools/+/561416 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]> Reviewed-by: Alan Donovan <[email protected]>
1 parent fc70354 commit abe5874

File tree

8 files changed

+750
-3
lines changed

8 files changed

+750
-3
lines changed
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// Package fillswitch identifies switches with missing cases.
6+
//
7+
// It reports a diagnostic for each type switch or 'enum' switch that
8+
// has missing cases, and suggests a fix to fill them in.
9+
//
10+
// The possible cases are: for a type switch, each accessible named
11+
// type T or pointer *T that is assignable to the interface type; and
12+
// for an 'enum' switch, each accessible named constant of the same
13+
// type as the switch value.
14+
//
15+
// For an 'enum' switch, it will suggest cases for all possible values of the
16+
// type.
17+
//
18+
// type Suit int8
19+
// const (
20+
// Spades Suit = iota
21+
// Hearts
22+
// Diamonds
23+
// Clubs
24+
// )
25+
//
26+
// var s Suit
27+
// switch s {
28+
// case Spades:
29+
// }
30+
//
31+
// It will report a diagnostic with a suggested fix to fill in the remaining
32+
// cases:
33+
//
34+
// var s Suit
35+
// switch s {
36+
// case Spades:
37+
// case Hearts:
38+
// case Diamonds:
39+
// case Clubs:
40+
// default:
41+
// panic(fmt.Sprintf("unexpected Suit: %v", s))
42+
// }
43+
//
44+
// For a type switch, it will suggest cases for all types that implement the
45+
// interface.
46+
//
47+
// var stmt ast.Stmt
48+
// switch stmt.(type) {
49+
// case *ast.IfStmt:
50+
// }
51+
//
52+
// It will report a diagnostic with a suggested fix to fill in the remaining
53+
// cases:
54+
//
55+
// var stmt ast.Stmt
56+
// switch stmt.(type) {
57+
// case *ast.IfStmt:
58+
// case *ast.ForStmt:
59+
// case *ast.RangeStmt:
60+
// case *ast.AssignStmt:
61+
// case *ast.GoStmt:
62+
// ...
63+
// default:
64+
// panic(fmt.Sprintf("unexpected ast.Stmt: %T", stmt))
65+
// }
66+
package fillswitch
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package fillswitch
6+
7+
import (
8+
"bytes"
9+
"fmt"
10+
"go/ast"
11+
"go/token"
12+
"go/types"
13+
14+
"golang.org/x/tools/go/analysis"
15+
"golang.org/x/tools/go/ast/inspector"
16+
)
17+
18+
// Diagnose computes diagnostics for switch statements with missing cases
19+
// overlapping with the provided start and end position.
20+
//
21+
// If either start or end is invalid, the entire package is inspected.
22+
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
23+
var diags []analysis.Diagnostic
24+
nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)}
25+
inspect.Preorder(nodeFilter, func(n ast.Node) {
26+
if start.IsValid() && n.End() < start ||
27+
end.IsValid() && n.Pos() > end {
28+
return // non-overlapping
29+
}
30+
31+
var fix *analysis.SuggestedFix
32+
switch n := n.(type) {
33+
case *ast.SwitchStmt:
34+
fix = suggestedFixSwitch(n, pkg, info)
35+
case *ast.TypeSwitchStmt:
36+
fix = suggestedFixTypeSwitch(n, pkg, info)
37+
}
38+
39+
if fix == nil {
40+
return
41+
}
42+
43+
diags = append(diags, analysis.Diagnostic{
44+
Message: fix.Message,
45+
Pos: n.Pos(),
46+
End: n.Pos() + token.Pos(len("switch")),
47+
SuggestedFixes: []analysis.SuggestedFix{*fix},
48+
})
49+
})
50+
51+
return diags
52+
}
53+
54+
func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix {
55+
if hasDefaultCase(stmt.Body) {
56+
return nil
57+
}
58+
59+
namedType := namedTypeFromTypeSwitch(stmt, info)
60+
if namedType == nil {
61+
return nil
62+
}
63+
64+
existingCases := caseTypes(stmt.Body, info)
65+
// Gather accessible package-level concrete types
66+
// that implement the switch interface type.
67+
scope := namedType.Obj().Pkg().Scope()
68+
var buf bytes.Buffer
69+
for _, name := range scope.Names() {
70+
obj := scope.Lookup(name)
71+
if tname, ok := obj.(*types.TypeName); !ok || tname.IsAlias() {
72+
continue // not a defined type
73+
}
74+
75+
if types.IsInterface(obj.Type()) {
76+
continue
77+
}
78+
79+
samePkg := obj.Pkg() == pkg
80+
if !samePkg && !obj.Exported() {
81+
continue // inaccessible
82+
}
83+
84+
var key caseType
85+
if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
86+
key.named = obj.Type().(*types.Named)
87+
} else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) {
88+
key.named = obj.Type().(*types.Named)
89+
key.ptr = true
90+
}
91+
92+
if key.named != nil {
93+
if existingCases[key] {
94+
continue
95+
}
96+
97+
if buf.Len() > 0 {
98+
buf.WriteString("\t")
99+
}
100+
101+
buf.WriteString("case ")
102+
if key.ptr {
103+
buf.WriteByte('*')
104+
}
105+
106+
if p := key.named.Obj().Pkg(); p != pkg {
107+
// TODO: use the correct package name when the import is renamed
108+
buf.WriteString(p.Name())
109+
buf.WriteByte('.')
110+
}
111+
buf.WriteString(key.named.Obj().Name())
112+
buf.WriteString(":\n")
113+
}
114+
}
115+
116+
if buf.Len() == 0 {
117+
return nil
118+
}
119+
120+
switch assign := stmt.Assign.(type) {
121+
case *ast.AssignStmt:
122+
addDefaultCase(&buf, namedType, assign.Lhs[0])
123+
case *ast.ExprStmt:
124+
if assert, ok := assign.X.(*ast.TypeAssertExpr); ok {
125+
addDefaultCase(&buf, namedType, assert.X)
126+
}
127+
}
128+
129+
return &analysis.SuggestedFix{
130+
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
131+
TextEdits: []analysis.TextEdit{{
132+
Pos: stmt.End() - token.Pos(len("}")),
133+
End: stmt.End() - token.Pos(len("}")),
134+
NewText: buf.Bytes(),
135+
}},
136+
}
137+
}
138+
139+
func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix {
140+
if hasDefaultCase(stmt.Body) {
141+
return nil
142+
}
143+
144+
namedType, ok := info.TypeOf(stmt.Tag).(*types.Named)
145+
if !ok {
146+
return nil
147+
}
148+
149+
existingCases := caseConsts(stmt.Body, info)
150+
// Gather accessible named constants of the same type as the switch value.
151+
scope := namedType.Obj().Pkg().Scope()
152+
var buf bytes.Buffer
153+
for _, name := range scope.Names() {
154+
obj := scope.Lookup(name)
155+
if c, ok := obj.(*types.Const); ok &&
156+
(obj.Pkg() == pkg || obj.Exported()) && // accessible
157+
types.Identical(obj.Type(), namedType.Obj().Type()) &&
158+
!existingCases[c] {
159+
160+
if buf.Len() > 0 {
161+
buf.WriteString("\t")
162+
}
163+
164+
buf.WriteString("case ")
165+
if c.Pkg() != pkg {
166+
buf.WriteString(c.Pkg().Name())
167+
buf.WriteByte('.')
168+
}
169+
buf.WriteString(c.Name())
170+
buf.WriteString(":\n")
171+
}
172+
}
173+
174+
if buf.Len() == 0 {
175+
return nil
176+
}
177+
178+
addDefaultCase(&buf, namedType, stmt.Tag)
179+
180+
return &analysis.SuggestedFix{
181+
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
182+
TextEdits: []analysis.TextEdit{{
183+
Pos: stmt.End() - token.Pos(len("}")),
184+
End: stmt.End() - token.Pos(len("}")),
185+
NewText: buf.Bytes(),
186+
}},
187+
}
188+
}
189+
190+
func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) {
191+
var dottedBuf bytes.Buffer
192+
// writeDotted emits a dotted path a.b.c.
193+
var writeDotted func(e ast.Expr) bool
194+
writeDotted = func(e ast.Expr) bool {
195+
switch e := e.(type) {
196+
case *ast.SelectorExpr:
197+
if !writeDotted(e.X) {
198+
return false
199+
}
200+
dottedBuf.WriteByte('.')
201+
dottedBuf.WriteString(e.Sel.Name)
202+
return true
203+
case *ast.Ident:
204+
dottedBuf.WriteString(e.Name)
205+
return true
206+
}
207+
return false
208+
}
209+
210+
buf.WriteString("\tdefault:\n")
211+
typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name())
212+
if writeDotted(expr) {
213+
// Switch tag expression is a dotted path.
214+
// It is safe to re-evaluate it in the default case.
215+
format := fmt.Sprintf("unexpected %s: %%#v", typeName)
216+
fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, dottedBuf.String())
217+
} else {
218+
// Emit simpler message, without re-evaluating tag expression.
219+
fmt.Fprintf(buf, "\t\tpanic(%q)\n\t", "unexpected "+typeName)
220+
}
221+
}
222+
223+
func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named {
224+
switch assign := stmt.Assign.(type) {
225+
case *ast.ExprStmt:
226+
if typ, ok := assign.X.(*ast.TypeAssertExpr); ok {
227+
if named, ok := info.TypeOf(typ.X).(*types.Named); ok {
228+
return named
229+
}
230+
}
231+
232+
case *ast.AssignStmt:
233+
if typ, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok {
234+
if named, ok := info.TypeOf(typ.X).(*types.Named); ok {
235+
return named
236+
}
237+
}
238+
}
239+
240+
return nil
241+
}
242+
243+
func hasDefaultCase(body *ast.BlockStmt) bool {
244+
for _, clause := range body.List {
245+
if len(clause.(*ast.CaseClause).List) == 0 {
246+
return true
247+
}
248+
}
249+
250+
return false
251+
}
252+
253+
func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool {
254+
out := map[*types.Const]bool{}
255+
for _, stmt := range body.List {
256+
for _, e := range stmt.(*ast.CaseClause).List {
257+
if info.Types[e].Value == nil {
258+
continue // not a constant
259+
}
260+
261+
if sel, ok := e.(*ast.SelectorExpr); ok {
262+
e = sel.Sel // replace pkg.C with C
263+
}
264+
265+
if e, ok := e.(*ast.Ident); ok {
266+
if c, ok := info.Uses[e].(*types.Const); ok {
267+
out[c] = true
268+
}
269+
}
270+
}
271+
}
272+
273+
return out
274+
}
275+
276+
type caseType struct {
277+
named *types.Named
278+
ptr bool
279+
}
280+
281+
func caseTypes(body *ast.BlockStmt, info *types.Info) map[caseType]bool {
282+
out := map[caseType]bool{}
283+
for _, stmt := range body.List {
284+
for _, e := range stmt.(*ast.CaseClause).List {
285+
if tv, ok := info.Types[e]; ok && tv.IsType() {
286+
t := tv.Type
287+
ptr := false
288+
if p, ok := t.(*types.Pointer); ok {
289+
t = p.Elem()
290+
ptr = true
291+
}
292+
293+
if named, ok := t.(*types.Named); ok {
294+
out[caseType{named, ptr}] = true
295+
}
296+
}
297+
}
298+
}
299+
300+
return out
301+
}

0 commit comments

Comments
 (0)