Skip to content

Commit d081471

Browse files
authored
Merge pull request #1817 from masnax/multi-pkg
Extra `generate-database` features
2 parents 4080aae + a9fb4ad commit d081471

37 files changed

+869
-485
lines changed

cmd/generate-database/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ This will initiate a call to `generate-database db mapper generate`,
3535
which will then search for `//generate-database:mapper` directives in the same file
3636
and process those.
3737

38+
The following flags are available:
39+
* `--package` / `-p`: Package import paths to search for structs to parse. Defaults to the caller package. Can be used more than once.
40+
3841
#### File
3942

4043
Generally the first thing we will want to do for any newly generated file is to
@@ -83,6 +86,8 @@ Type | Description
8386
:--- | :----
8487
`objects` | Creates a basic SELECT statement of the form `SELECT <columns> FROM <table> ORDER BY <columns>`.
8588
`objects-by-<FIELD>-and-<FIELD>...` | Parses a pre-existing SELECT statement variable declaration of the form produced by`objects`, and appends a `WHERE` clause with the given fields located in the associated struct. Specifically looks for a variable declaration of the form `var <entity>Objects = RegisterStmt("SQL String")`
89+
`names` | Creates a basic SELECT statement of the form `SELECT <primary key> FROM <table> ORDER BY <primary key>`.
90+
`names-by-<FIELD>-and-<FIELD>...` | Parses a pre-existing SELECT statement variable declaration of the form produced by`names`, and appends a `WHERE` clause with the given fields located in the associated struct. Specifically looks for a variable declaration of the form `var <entity>Objects = RegisterStmt("SQL String")`
8691
`create` | Creates a basic INSERT statement of the form `INSERT INTO <table> VALUES`.
8792
`create-or-replace` | Creates a basic INSERT statement of the form `INSERT OR REPLACE INTO <table> VALUES`.
8893
`delete-by-<FIELD>-and-<FIELD>...` | Creates a DELETE statement of the form `DELETE FROM <table> WHERE <constraint>` where the constraint is based on the given fields of the associated struct.
@@ -123,6 +128,7 @@ Go function generation supports the following types:
123128

124129
Type | Description
125130
:--- | :----
131+
`GetNames` | Return a slice of primary keys for all rows in a table matching the filter. Cannot be used with composite keys.
126132
`GetMany` | Return a slice of structs for all rows in a table matching the filter.
127133
`GetOne` | Return a single struct corresponding to a row with the given primary keys. Depends on `GetMany`.
128134
`ID` | Return the ID column from the table corresponding to the given primary keys.

cmd/generate-database/db.go

+103-71
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func newDbMapper() *cobra.Command {
6565
}
6666

6767
func newDbMapperGenerate() *cobra.Command {
68-
var pkg string
68+
var pkgs *[]string
6969
var boilerplateFilename string
7070

7171
cmd := &cobra.Command{
@@ -76,21 +76,37 @@ func newDbMapperGenerate() *cobra.Command {
7676
return errors.New("GOPACKAGE environment variable is not set")
7777
}
7878

79-
return generate(pkg, boilerplateFilename)
79+
return generate(*pkgs, boilerplateFilename)
8080
},
8181
}
8282

8383
flags := cmd.Flags()
84-
flags.StringVarP(&pkg, "package", "p", "", "Go package where the entity struct is declared")
84+
pkgs = flags.StringArrayP("package", "p", []string{}, "Go package where the entity struct is declared")
8585
flags.StringVarP(&boilerplateFilename, "boilerplate-file", "b", "-", "Filename of the file where the mapper boilerplate is written to")
8686

8787
return cmd
8888
}
8989

9090
const prefix = "//generate-database:mapper "
9191

92-
func generate(pkg string, boilerplateFilename string) error {
93-
parsedPkg, err := packageLoad(pkg)
92+
func generate(pkgs []string, boilerplateFilename string) error {
93+
localPath, err := os.Getwd()
94+
if err != nil {
95+
return err
96+
}
97+
98+
localPkg, err := packages.Load(&packages.Config{Mode: packages.NeedName}, localPath)
99+
if err != nil {
100+
return err
101+
}
102+
103+
localPkgPath := localPkg[0].PkgPath
104+
105+
if len(pkgs) == 0 {
106+
pkgs = []string{localPkgPath}
107+
}
108+
109+
parsedPkgs, err := packageLoad(pkgs)
94110
if err != nil {
95111
return err
96112
}
@@ -101,59 +117,61 @@ func generate(pkg string, boilerplateFilename string) error {
101117
}
102118

103119
registeredSQLStmts := map[string]string{}
104-
for _, goFile := range parsedPkg.CompiledGoFiles {
105-
body, err := os.ReadFile(goFile)
106-
if err != nil {
107-
return err
108-
}
120+
for _, parsedPkg := range parsedPkgs {
121+
for _, goFile := range parsedPkg.CompiledGoFiles {
122+
body, err := os.ReadFile(goFile)
123+
if err != nil {
124+
return err
125+
}
109126

110-
// Reset target to stdout
111-
target := "-"
112-
113-
lines := strings.Split(string(body), "\n")
114-
for _, line := range lines {
115-
// Lazy matching for prefix, does not consider Go syntax and therefore
116-
// lines starting with prefix, that are part of e.g. multiline strings
117-
// match as well. This is highly unlikely to cause false positives.
118-
if strings.HasPrefix(line, prefix) {
119-
line = strings.TrimPrefix(line, prefix)
120-
121-
// Use csv parser to properly handle arguments surrounded by double quotes.
122-
r := csv.NewReader(strings.NewReader(line))
123-
r.Comma = ' ' // space
124-
args, err := r.Read()
125-
if err != nil {
126-
return err
127-
}
127+
// Reset target to stdout
128+
target := "-"
129+
130+
lines := strings.Split(string(body), "\n")
131+
for _, line := range lines {
132+
// Lazy matching for prefix, does not consider Go syntax and therefore
133+
// lines starting with prefix, that are part of e.g. multiline strings
134+
// match as well. This is highly unlikely to cause false positives.
135+
if strings.HasPrefix(line, prefix) {
136+
line = strings.TrimPrefix(line, prefix)
137+
138+
// Use csv parser to properly handle arguments surrounded by double quotes.
139+
r := csv.NewReader(strings.NewReader(line))
140+
r.Comma = ' ' // space
141+
args, err := r.Read()
142+
if err != nil {
143+
return err
144+
}
128145

129-
if len(args) == 0 {
130-
return fmt.Errorf("command missing")
131-
}
146+
if len(args) == 0 {
147+
return fmt.Errorf("command missing")
148+
}
132149

133-
command := args[0]
150+
command := args[0]
134151

135-
switch command {
136-
case "target":
137-
if len(args) != 2 {
138-
return fmt.Errorf("invalid arguments for command target, one argument for the target filename: %s", line)
139-
}
152+
switch command {
153+
case "target":
154+
if len(args) != 2 {
155+
return fmt.Errorf("invalid arguments for command target, one argument for the target filename: %s", line)
156+
}
140157

141-
target = args[1]
142-
case "reset":
143-
err = commandReset(args[1:], target)
158+
target = args[1]
159+
case "reset":
160+
err = commandReset(args[1:], parsedPkgs, target, localPkgPath)
144161

145-
case "stmt":
146-
err = commandStmt(args[1:], target, parsedPkg, registeredSQLStmts)
162+
case "stmt":
163+
err = commandStmt(args[1:], target, parsedPkgs, registeredSQLStmts, localPkgPath)
147164

148-
case "method":
149-
err = commandMethod(args[1:], target, parsedPkg, registeredSQLStmts)
165+
case "method":
166+
err = commandMethod(args[1:], target, parsedPkgs, registeredSQLStmts, localPkgPath)
150167

151-
default:
152-
err = fmt.Errorf("unknown command: %s", command)
153-
}
168+
default:
169+
err = fmt.Errorf("unknown command: %s", command)
170+
}
154171

155-
if err != nil {
156-
return err
172+
if err != nil {
173+
return err
174+
}
157175
}
158176
}
159177
}
@@ -162,7 +180,7 @@ func generate(pkg string, boilerplateFilename string) error {
162180
return nil
163181
}
164182

165-
func commandReset(commandLine []string, target string) error {
183+
func commandReset(commandLine []string, parsedPkgs []*packages.Package, target string, localPkgPath string) error {
166184
var err error
167185

168186
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
@@ -174,15 +192,24 @@ func commandReset(commandLine []string, target string) error {
174192
return err
175193
}
176194

177-
err = file.Reset(target, db.Imports, *buildComment, *iface)
195+
imports := db.Imports
196+
for _, pkg := range parsedPkgs {
197+
if pkg.PkgPath == localPkgPath {
198+
continue
199+
}
200+
201+
imports = append(imports, pkg.PkgPath)
202+
}
203+
204+
err = file.Reset(target, imports, *buildComment, *iface)
178205
if err != nil {
179206
return err
180207
}
181208

182209
return nil
183210
}
184211

185-
func commandStmt(commandLine []string, target string, parsedPkg *packages.Package, registeredSQLStmts map[string]string) error {
212+
func commandStmt(commandLine []string, target string, parsedPkgs []*packages.Package, registeredSQLStmts map[string]string, localPkgPath string) error {
186213
var err error
187214

188215
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
@@ -203,15 +230,15 @@ func commandStmt(commandLine []string, target string, parsedPkg *packages.Packag
203230
return err
204231
}
205232

206-
stmt, err := db.NewStmt(parsedPkg, *entity, kind, config, registeredSQLStmts)
233+
stmt, err := db.NewStmt(localPkgPath, parsedPkgs, *entity, kind, config, registeredSQLStmts)
207234
if err != nil {
208235
return err
209236
}
210237

211238
return file.Append(*entity, target, stmt, false)
212239
}
213240

214-
func commandMethod(commandLine []string, target string, parsedPkg *packages.Package, registeredSQLStmts map[string]string) error {
241+
func commandMethod(commandLine []string, target string, parsedPkgs []*packages.Package, registeredSQLStmts map[string]string, localPkgPath string) error {
215242
var err error
216243

217244
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
@@ -233,39 +260,44 @@ func commandMethod(commandLine []string, target string, parsedPkg *packages.Pack
233260
return err
234261
}
235262

236-
method, err := db.NewMethod(parsedPkg, *entity, kind, config, registeredSQLStmts)
263+
method, err := db.NewMethod(localPkgPath, parsedPkgs, *entity, kind, config, registeredSQLStmts)
237264
if err != nil {
238265
return err
239266
}
240267

241268
return file.Append(*entity, target, method, *iface)
242269
}
243270

244-
func packageLoad(pkg string) (*packages.Package, error) {
245-
var pkgPath string
246-
if pkg != "" {
247-
importPkg, err := build.Import(pkg, "", build.FindOnly)
248-
if err != nil {
249-
return nil, fmt.Errorf("Invalid import path %q: %w", pkg, err)
250-
}
271+
func packageLoad(pkgs []string) ([]*packages.Package, error) {
272+
pkgPaths := []string{}
251273

252-
pkgPath = importPkg.Dir
253-
} else {
254-
var err error
255-
pkgPath, err = os.Getwd()
256-
if err != nil {
257-
return nil, err
274+
for _, pkg := range pkgs {
275+
if pkg == "" {
276+
var err error
277+
localPath, err := os.Getwd()
278+
if err != nil {
279+
return nil, err
280+
}
281+
282+
pkgPaths = append(pkgPaths, localPath)
283+
} else {
284+
importPkg, err := build.Import(pkg, "", build.FindOnly)
285+
if err != nil {
286+
return nil, fmt.Errorf("Invalid import path %q: %w", pkg, err)
287+
}
288+
289+
pkgPaths = append(pkgPaths, importPkg.Dir)
258290
}
259291
}
260292

261-
parsedPkg, err := packages.Load(&packages.Config{
293+
parsedPkgs, err := packages.Load(&packages.Config{
262294
Mode: packages.LoadTypes | packages.NeedTypesInfo,
263-
}, pkgPath)
295+
}, pkgPaths...)
264296
if err != nil {
265297
return nil, err
266298
}
267299

268-
return parsedPkg[0], nil
300+
return parsedPkgs, nil
269301
}
270302

271303
func parseParams(args []string) (map[string]string, error) {

cmd/generate-database/db/lex.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ func activeCriteria(filter []string, ignoredFilter []string) string {
8282

8383
// Return the code for a "dest" function, to be passed as parameter to
8484
// selectObjects in order to scan a single row.
85-
func destFunc(slice string, typ string, fields []*Field) string {
85+
func destFunc(slice string, entity string, importType string, fields []*Field) string {
8686
var builder strings.Builder
8787
writeLine := func(line string) { builder.WriteString(fmt.Sprintf("%s\n", line)) }
8888

8989
writeLine(`func(scan func(dest ...any) error) error {`)
9090

91-
varName := lex.Minuscule(string(typ[0]))
92-
writeLine(fmt.Sprintf("%s := %s{}", varName, typ))
91+
varName := lex.Minuscule(string(entity[0]))
92+
writeLine(fmt.Sprintf("%s := %s{}", varName, importType))
9393

9494
checkErr := func() {
9595
writeLine("if err != nil {\nreturn err\n}")

cmd/generate-database/db/mapping.go

+29-7
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ import (
1414

1515
// Mapping holds information for mapping database tables to a Go structure.
1616
type Mapping struct {
17-
Package string // Package of the Go struct
18-
Name string // Name of the Go struct.
19-
Fields []*Field // Metadata about the Go struct.
20-
Filterable bool // Whether the Go struct has a Filter companion struct for filtering queries.
21-
Filters []*Field // Metadata about the Go struct used for filter fields.
22-
Type TableType // Type of table structure for this Go struct.
17+
Local bool // Whether the entity is in the same package as the generated code.
18+
FilterLocal bool // Whether the entity is in the same package as the generated code.
19+
Package string // Package of the Go struct
20+
Name string // Name of the Go struct.
21+
Fields []*Field // Metadata about the Go struct.
22+
Filterable bool // Whether the Go struct has a Filter companion struct for filtering queries.
23+
Filters []*Field // Metadata about the Go struct used for filter fields.
24+
Type TableType // Type of table structure for this Go struct.
2325
}
2426

2527
// TableType represents the logical type of the table defined by the Go struct.
@@ -250,6 +252,26 @@ func (m *Mapping) FieldParamsMarshal(fields []*Field) string {
250252
return strings.Join(args, ", ")
251253
}
252254

255+
// ImportType returns the type of the entity for the mapping, prefixing the import package if necessary.
256+
func (m *Mapping) ImportType() string {
257+
name := lex.PascalCase(m.Name)
258+
if m.Local {
259+
return name
260+
}
261+
262+
return m.Package + "." + lex.PascalCase(name)
263+
}
264+
265+
// ImportFilterType returns the Filter type of the entity for the mapping, prefixing the import package if necessary.
266+
func (m *Mapping) ImportFilterType() string {
267+
name := lex.PascalCase(entityFilter(m.Name))
268+
if m.FilterLocal {
269+
return name
270+
}
271+
272+
return m.Package + "." + name
273+
}
274+
253275
// Field holds all information about a field in a Go struct that is relevant
254276
// for database code generation.
255277
type Field struct {
@@ -434,7 +456,7 @@ func (f *Field) JoinClause(mapping *Mapping, table string) (string, error) {
434456
// to select the ID to insert into this table.
435457
// - If a 'joinon' tag is present, but this table is not among the conditions, then the join will be considered indirect,
436458
// and an empty string will be returned.
437-
func (f *Field) InsertColumn(pkg *types.Package, mapping *Mapping, primaryTable string, defs map[*ast.Ident]types.Object, registeredSQLStmts map[string]string) (string, string, error) {
459+
func (f *Field) InsertColumn(mapping *Mapping, primaryTable string, defs map[*ast.Ident]types.Object, registeredSQLStmts map[string]string) (string, string, error) {
438460
var column string
439461
var value string
440462
var err error

0 commit comments

Comments
 (0)