|
| 1 | +// Copyright 2024 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package fillswitch |
| 6 | + |
| 7 | +import ( |
| 8 | + "bytes" |
| 9 | + "fmt" |
| 10 | + "go/ast" |
| 11 | + "go/token" |
| 12 | + "go/types" |
| 13 | + |
| 14 | + "golang.org/x/tools/go/analysis" |
| 15 | + "golang.org/x/tools/go/ast/inspector" |
| 16 | +) |
| 17 | + |
| 18 | +// Diagnose computes diagnostics for switch statements with missing cases |
| 19 | +// overlapping with the provided start and end position. |
| 20 | +// |
| 21 | +// If either start or end is invalid, the entire package is inspected. |
| 22 | +func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic { |
| 23 | + var diags []analysis.Diagnostic |
| 24 | + nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} |
| 25 | + inspect.Preorder(nodeFilter, func(n ast.Node) { |
| 26 | + if start.IsValid() && n.End() < start || |
| 27 | + end.IsValid() && n.Pos() > end { |
| 28 | + return // non-overlapping |
| 29 | + } |
| 30 | + |
| 31 | + var fix *analysis.SuggestedFix |
| 32 | + switch n := n.(type) { |
| 33 | + case *ast.SwitchStmt: |
| 34 | + fix = suggestedFixSwitch(n, pkg, info) |
| 35 | + case *ast.TypeSwitchStmt: |
| 36 | + fix = suggestedFixTypeSwitch(n, pkg, info) |
| 37 | + } |
| 38 | + |
| 39 | + if fix == nil { |
| 40 | + return |
| 41 | + } |
| 42 | + |
| 43 | + diags = append(diags, analysis.Diagnostic{ |
| 44 | + Message: fix.Message, |
| 45 | + Pos: n.Pos(), |
| 46 | + End: n.Pos() + token.Pos(len("switch")), |
| 47 | + SuggestedFixes: []analysis.SuggestedFix{*fix}, |
| 48 | + }) |
| 49 | + }) |
| 50 | + |
| 51 | + return diags |
| 52 | +} |
| 53 | + |
| 54 | +func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { |
| 55 | + if hasDefaultCase(stmt.Body) { |
| 56 | + return nil |
| 57 | + } |
| 58 | + |
| 59 | + namedType := namedTypeFromTypeSwitch(stmt, info) |
| 60 | + if namedType == nil { |
| 61 | + return nil |
| 62 | + } |
| 63 | + |
| 64 | + existingCases := caseTypes(stmt.Body, info) |
| 65 | + // Gather accessible package-level concrete types |
| 66 | + // that implement the switch interface type. |
| 67 | + scope := namedType.Obj().Pkg().Scope() |
| 68 | + var buf bytes.Buffer |
| 69 | + for _, name := range scope.Names() { |
| 70 | + obj := scope.Lookup(name) |
| 71 | + if tname, ok := obj.(*types.TypeName); !ok || tname.IsAlias() { |
| 72 | + continue // not a defined type |
| 73 | + } |
| 74 | + |
| 75 | + if types.IsInterface(obj.Type()) { |
| 76 | + continue |
| 77 | + } |
| 78 | + |
| 79 | + samePkg := obj.Pkg() == pkg |
| 80 | + if !samePkg && !obj.Exported() { |
| 81 | + continue // inaccessible |
| 82 | + } |
| 83 | + |
| 84 | + var key caseType |
| 85 | + if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { |
| 86 | + key.named = obj.Type().(*types.Named) |
| 87 | + } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { |
| 88 | + key.named = obj.Type().(*types.Named) |
| 89 | + key.ptr = true |
| 90 | + } |
| 91 | + |
| 92 | + if key.named != nil { |
| 93 | + if existingCases[key] { |
| 94 | + continue |
| 95 | + } |
| 96 | + |
| 97 | + if buf.Len() > 0 { |
| 98 | + buf.WriteString("\t") |
| 99 | + } |
| 100 | + |
| 101 | + buf.WriteString("case ") |
| 102 | + if key.ptr { |
| 103 | + buf.WriteByte('*') |
| 104 | + } |
| 105 | + |
| 106 | + if p := key.named.Obj().Pkg(); p != pkg { |
| 107 | + // TODO: use the correct package name when the import is renamed |
| 108 | + buf.WriteString(p.Name()) |
| 109 | + buf.WriteByte('.') |
| 110 | + } |
| 111 | + buf.WriteString(key.named.Obj().Name()) |
| 112 | + buf.WriteString(":\n") |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + if buf.Len() == 0 { |
| 117 | + return nil |
| 118 | + } |
| 119 | + |
| 120 | + switch assign := stmt.Assign.(type) { |
| 121 | + case *ast.AssignStmt: |
| 122 | + addDefaultCase(&buf, namedType, assign.Lhs[0]) |
| 123 | + case *ast.ExprStmt: |
| 124 | + if assert, ok := assign.X.(*ast.TypeAssertExpr); ok { |
| 125 | + addDefaultCase(&buf, namedType, assert.X) |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + return &analysis.SuggestedFix{ |
| 130 | + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), |
| 131 | + TextEdits: []analysis.TextEdit{{ |
| 132 | + Pos: stmt.End() - token.Pos(len("}")), |
| 133 | + End: stmt.End() - token.Pos(len("}")), |
| 134 | + NewText: buf.Bytes(), |
| 135 | + }}, |
| 136 | + } |
| 137 | +} |
| 138 | + |
| 139 | +func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { |
| 140 | + if hasDefaultCase(stmt.Body) { |
| 141 | + return nil |
| 142 | + } |
| 143 | + |
| 144 | + namedType, ok := info.TypeOf(stmt.Tag).(*types.Named) |
| 145 | + if !ok { |
| 146 | + return nil |
| 147 | + } |
| 148 | + |
| 149 | + existingCases := caseConsts(stmt.Body, info) |
| 150 | + // Gather accessible named constants of the same type as the switch value. |
| 151 | + scope := namedType.Obj().Pkg().Scope() |
| 152 | + var buf bytes.Buffer |
| 153 | + for _, name := range scope.Names() { |
| 154 | + obj := scope.Lookup(name) |
| 155 | + if c, ok := obj.(*types.Const); ok && |
| 156 | + (obj.Pkg() == pkg || obj.Exported()) && // accessible |
| 157 | + types.Identical(obj.Type(), namedType.Obj().Type()) && |
| 158 | + !existingCases[c] { |
| 159 | + |
| 160 | + if buf.Len() > 0 { |
| 161 | + buf.WriteString("\t") |
| 162 | + } |
| 163 | + |
| 164 | + buf.WriteString("case ") |
| 165 | + if c.Pkg() != pkg { |
| 166 | + buf.WriteString(c.Pkg().Name()) |
| 167 | + buf.WriteByte('.') |
| 168 | + } |
| 169 | + buf.WriteString(c.Name()) |
| 170 | + buf.WriteString(":\n") |
| 171 | + } |
| 172 | + } |
| 173 | + |
| 174 | + if buf.Len() == 0 { |
| 175 | + return nil |
| 176 | + } |
| 177 | + |
| 178 | + addDefaultCase(&buf, namedType, stmt.Tag) |
| 179 | + |
| 180 | + return &analysis.SuggestedFix{ |
| 181 | + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), |
| 182 | + TextEdits: []analysis.TextEdit{{ |
| 183 | + Pos: stmt.End() - token.Pos(len("}")), |
| 184 | + End: stmt.End() - token.Pos(len("}")), |
| 185 | + NewText: buf.Bytes(), |
| 186 | + }}, |
| 187 | + } |
| 188 | +} |
| 189 | + |
| 190 | +func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { |
| 191 | + var dottedBuf bytes.Buffer |
| 192 | + // writeDotted emits a dotted path a.b.c. |
| 193 | + var writeDotted func(e ast.Expr) bool |
| 194 | + writeDotted = func(e ast.Expr) bool { |
| 195 | + switch e := e.(type) { |
| 196 | + case *ast.SelectorExpr: |
| 197 | + if !writeDotted(e.X) { |
| 198 | + return false |
| 199 | + } |
| 200 | + dottedBuf.WriteByte('.') |
| 201 | + dottedBuf.WriteString(e.Sel.Name) |
| 202 | + return true |
| 203 | + case *ast.Ident: |
| 204 | + dottedBuf.WriteString(e.Name) |
| 205 | + return true |
| 206 | + } |
| 207 | + return false |
| 208 | + } |
| 209 | + |
| 210 | + buf.WriteString("\tdefault:\n") |
| 211 | + typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name()) |
| 212 | + if writeDotted(expr) { |
| 213 | + // Switch tag expression is a dotted path. |
| 214 | + // It is safe to re-evaluate it in the default case. |
| 215 | + format := fmt.Sprintf("unexpected %s: %%#v", typeName) |
| 216 | + fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, dottedBuf.String()) |
| 217 | + } else { |
| 218 | + // Emit simpler message, without re-evaluating tag expression. |
| 219 | + fmt.Fprintf(buf, "\t\tpanic(%q)\n\t", "unexpected "+typeName) |
| 220 | + } |
| 221 | +} |
| 222 | + |
| 223 | +func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { |
| 224 | + switch assign := stmt.Assign.(type) { |
| 225 | + case *ast.ExprStmt: |
| 226 | + if typ, ok := assign.X.(*ast.TypeAssertExpr); ok { |
| 227 | + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { |
| 228 | + return named |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + case *ast.AssignStmt: |
| 233 | + if typ, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok { |
| 234 | + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { |
| 235 | + return named |
| 236 | + } |
| 237 | + } |
| 238 | + } |
| 239 | + |
| 240 | + return nil |
| 241 | +} |
| 242 | + |
| 243 | +func hasDefaultCase(body *ast.BlockStmt) bool { |
| 244 | + for _, clause := range body.List { |
| 245 | + if len(clause.(*ast.CaseClause).List) == 0 { |
| 246 | + return true |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + return false |
| 251 | +} |
| 252 | + |
| 253 | +func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { |
| 254 | + out := map[*types.Const]bool{} |
| 255 | + for _, stmt := range body.List { |
| 256 | + for _, e := range stmt.(*ast.CaseClause).List { |
| 257 | + if info.Types[e].Value == nil { |
| 258 | + continue // not a constant |
| 259 | + } |
| 260 | + |
| 261 | + if sel, ok := e.(*ast.SelectorExpr); ok { |
| 262 | + e = sel.Sel // replace pkg.C with C |
| 263 | + } |
| 264 | + |
| 265 | + if e, ok := e.(*ast.Ident); ok { |
| 266 | + if c, ok := info.Uses[e].(*types.Const); ok { |
| 267 | + out[c] = true |
| 268 | + } |
| 269 | + } |
| 270 | + } |
| 271 | + } |
| 272 | + |
| 273 | + return out |
| 274 | +} |
| 275 | + |
| 276 | +type caseType struct { |
| 277 | + named *types.Named |
| 278 | + ptr bool |
| 279 | +} |
| 280 | + |
| 281 | +func caseTypes(body *ast.BlockStmt, info *types.Info) map[caseType]bool { |
| 282 | + out := map[caseType]bool{} |
| 283 | + for _, stmt := range body.List { |
| 284 | + for _, e := range stmt.(*ast.CaseClause).List { |
| 285 | + if tv, ok := info.Types[e]; ok && tv.IsType() { |
| 286 | + t := tv.Type |
| 287 | + ptr := false |
| 288 | + if p, ok := t.(*types.Pointer); ok { |
| 289 | + t = p.Elem() |
| 290 | + ptr = true |
| 291 | + } |
| 292 | + |
| 293 | + if named, ok := t.(*types.Named); ok { |
| 294 | + out[caseType{named, ptr}] = true |
| 295 | + } |
| 296 | + } |
| 297 | + } |
| 298 | + } |
| 299 | + |
| 300 | + return out |
| 301 | +} |
0 commit comments