Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extra generate-database features #1817

Merged
merged 9 commits into from
Mar 20, 2025
6 changes: 6 additions & 0 deletions cmd/generate-database/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ This will initiate a call to `generate-database db mapper generate`,
which will then search for `//generate-database:mapper` directives in the same file
and process those.

The following flags are available:
* `--package` / `-p`: Package import paths to search for structs to parse. Defaults to the caller package. Can be used more than once.

#### File

Generally the first thing we will want to do for any newly generated file is to
Expand Down Expand Up @@ -83,6 +86,8 @@ Type | Description
:--- | :----
`objects` | Creates a basic SELECT statement of the form `SELECT <columns> FROM <table> ORDER BY <columns>`.
`objects-by-<FIELD>-and-<FIELD>...` | Parses a pre-existing SELECT statement variable declaration of the form produced by`objects`, and appends a `WHERE` clause with the given fields located in the associated struct. Specifically looks for a variable declaration of the form `var <entity>Objects = RegisterStmt("SQL String")`
`names` | Creates a basic SELECT statement of the form `SELECT <primary key> FROM <table> ORDER BY <primary key>`.
`names-by-<FIELD>-and-<FIELD>...` | Parses a pre-existing SELECT statement variable declaration of the form produced by`names`, and appends a `WHERE` clause with the given fields located in the associated struct. Specifically looks for a variable declaration of the form `var <entity>Objects = RegisterStmt("SQL String")`
`create` | Creates a basic INSERT statement of the form `INSERT INTO <table> VALUES`.
`create-or-replace` | Creates a basic INSERT statement of the form `INSERT OR REPLACE INTO <table> VALUES`.
`delete-by-<FIELD>-and-<FIELD>...` | Creates a DELETE statement of the form `DELETE FROM <table> WHERE <constraint>` where the constraint is based on the given fields of the associated struct.
Expand Down Expand Up @@ -123,6 +128,7 @@ Go function generation supports the following types:

Type | Description
:--- | :----
`GetNames` | Return a slice of primary keys for all rows in a table matching the filter. Cannot be used with composite keys.
`GetMany` | Return a slice of structs for all rows in a table matching the filter.
`GetOne` | Return a single struct corresponding to a row with the given primary keys. Depends on `GetMany`.
`ID` | Return the ID column from the table corresponding to the given primary keys.
Expand Down
174 changes: 103 additions & 71 deletions cmd/generate-database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func newDbMapper() *cobra.Command {
}

func newDbMapperGenerate() *cobra.Command {
var pkg string
var pkgs *[]string
var boilerplateFilename string

cmd := &cobra.Command{
Expand All @@ -76,21 +76,37 @@ func newDbMapperGenerate() *cobra.Command {
return errors.New("GOPACKAGE environment variable is not set")
}

return generate(pkg, boilerplateFilename)
return generate(*pkgs, boilerplateFilename)
},
}

flags := cmd.Flags()
flags.StringVarP(&pkg, "package", "p", "", "Go package where the entity struct is declared")
pkgs = flags.StringArrayP("package", "p", []string{}, "Go package where the entity struct is declared")
flags.StringVarP(&boilerplateFilename, "boilerplate-file", "b", "-", "Filename of the file where the mapper boilerplate is written to")

return cmd
}

const prefix = "//generate-database:mapper "

func generate(pkg string, boilerplateFilename string) error {
parsedPkg, err := packageLoad(pkg)
func generate(pkgs []string, boilerplateFilename string) error {
localPath, err := os.Getwd()
if err != nil {
return err
}

localPkg, err := packages.Load(&packages.Config{Mode: packages.NeedName}, localPath)
if err != nil {
return err
}

localPkgPath := localPkg[0].PkgPath

if len(pkgs) == 0 {
pkgs = []string{localPkgPath}
}

parsedPkgs, err := packageLoad(pkgs)
if err != nil {
return err
}
Expand All @@ -101,59 +117,61 @@ func generate(pkg string, boilerplateFilename string) error {
}

registeredSQLStmts := map[string]string{}
for _, goFile := range parsedPkg.CompiledGoFiles {
body, err := os.ReadFile(goFile)
if err != nil {
return err
}
for _, parsedPkg := range parsedPkgs {
for _, goFile := range parsedPkg.CompiledGoFiles {
body, err := os.ReadFile(goFile)
if err != nil {
return err
}

// Reset target to stdout
target := "-"

lines := strings.Split(string(body), "\n")
for _, line := range lines {
// Lazy matching for prefix, does not consider Go syntax and therefore
// lines starting with prefix, that are part of e.g. multiline strings
// match as well. This is highly unlikely to cause false positives.
if strings.HasPrefix(line, prefix) {
line = strings.TrimPrefix(line, prefix)

// Use csv parser to properly handle arguments surrounded by double quotes.
r := csv.NewReader(strings.NewReader(line))
r.Comma = ' ' // space
args, err := r.Read()
if err != nil {
return err
}
// Reset target to stdout
target := "-"

lines := strings.Split(string(body), "\n")
for _, line := range lines {
// Lazy matching for prefix, does not consider Go syntax and therefore
// lines starting with prefix, that are part of e.g. multiline strings
// match as well. This is highly unlikely to cause false positives.
if strings.HasPrefix(line, prefix) {
line = strings.TrimPrefix(line, prefix)

// Use csv parser to properly handle arguments surrounded by double quotes.
r := csv.NewReader(strings.NewReader(line))
r.Comma = ' ' // space
args, err := r.Read()
if err != nil {
return err
}

if len(args) == 0 {
return fmt.Errorf("command missing")
}
if len(args) == 0 {
return fmt.Errorf("command missing")
}

command := args[0]
command := args[0]

switch command {
case "target":
if len(args) != 2 {
return fmt.Errorf("invalid arguments for command target, one argument for the target filename: %s", line)
}
switch command {
case "target":
if len(args) != 2 {
return fmt.Errorf("invalid arguments for command target, one argument for the target filename: %s", line)
}

target = args[1]
case "reset":
err = commandReset(args[1:], target)
target = args[1]
case "reset":
err = commandReset(args[1:], parsedPkgs, target, localPkgPath)

case "stmt":
err = commandStmt(args[1:], target, parsedPkg, registeredSQLStmts)
case "stmt":
err = commandStmt(args[1:], target, parsedPkgs, registeredSQLStmts, localPkgPath)

case "method":
err = commandMethod(args[1:], target, parsedPkg, registeredSQLStmts)
case "method":
err = commandMethod(args[1:], target, parsedPkgs, registeredSQLStmts, localPkgPath)

default:
err = fmt.Errorf("unknown command: %s", command)
}
default:
err = fmt.Errorf("unknown command: %s", command)
}

if err != nil {
return err
if err != nil {
return err
}
}
}
}
Expand All @@ -162,7 +180,7 @@ func generate(pkg string, boilerplateFilename string) error {
return nil
}

func commandReset(commandLine []string, target string) error {
func commandReset(commandLine []string, parsedPkgs []*packages.Package, target string, localPkgPath string) error {
var err error

flags := pflag.NewFlagSet("", pflag.ContinueOnError)
Expand All @@ -174,15 +192,24 @@ func commandReset(commandLine []string, target string) error {
return err
}

err = file.Reset(target, db.Imports, *buildComment, *iface)
imports := db.Imports
for _, pkg := range parsedPkgs {
if pkg.PkgPath == localPkgPath {
continue
}

imports = append(imports, pkg.PkgPath)
}

err = file.Reset(target, imports, *buildComment, *iface)
if err != nil {
return err
}

return nil
}

func commandStmt(commandLine []string, target string, parsedPkg *packages.Package, registeredSQLStmts map[string]string) error {
func commandStmt(commandLine []string, target string, parsedPkgs []*packages.Package, registeredSQLStmts map[string]string, localPkgPath string) error {
var err error

flags := pflag.NewFlagSet("", pflag.ContinueOnError)
Expand All @@ -203,15 +230,15 @@ func commandStmt(commandLine []string, target string, parsedPkg *packages.Packag
return err
}

stmt, err := db.NewStmt(parsedPkg, *entity, kind, config, registeredSQLStmts)
stmt, err := db.NewStmt(localPkgPath, parsedPkgs, *entity, kind, config, registeredSQLStmts)
if err != nil {
return err
}

return file.Append(*entity, target, stmt, false)
}

func commandMethod(commandLine []string, target string, parsedPkg *packages.Package, registeredSQLStmts map[string]string) error {
func commandMethod(commandLine []string, target string, parsedPkgs []*packages.Package, registeredSQLStmts map[string]string, localPkgPath string) error {
var err error

flags := pflag.NewFlagSet("", pflag.ContinueOnError)
Expand All @@ -233,39 +260,44 @@ func commandMethod(commandLine []string, target string, parsedPkg *packages.Pack
return err
}

method, err := db.NewMethod(parsedPkg, *entity, kind, config, registeredSQLStmts)
method, err := db.NewMethod(localPkgPath, parsedPkgs, *entity, kind, config, registeredSQLStmts)
if err != nil {
return err
}

return file.Append(*entity, target, method, *iface)
}

func packageLoad(pkg string) (*packages.Package, error) {
var pkgPath string
if pkg != "" {
importPkg, err := build.Import(pkg, "", build.FindOnly)
if err != nil {
return nil, fmt.Errorf("Invalid import path %q: %w", pkg, err)
}
func packageLoad(pkgs []string) ([]*packages.Package, error) {
pkgPaths := []string{}

pkgPath = importPkg.Dir
} else {
var err error
pkgPath, err = os.Getwd()
if err != nil {
return nil, err
for _, pkg := range pkgs {
if pkg == "" {
var err error
localPath, err := os.Getwd()
if err != nil {
return nil, err
}

pkgPaths = append(pkgPaths, localPath)
} else {
importPkg, err := build.Import(pkg, "", build.FindOnly)
if err != nil {
return nil, fmt.Errorf("Invalid import path %q: %w", pkg, err)
}

pkgPaths = append(pkgPaths, importPkg.Dir)
}
}

parsedPkg, err := packages.Load(&packages.Config{
parsedPkgs, err := packages.Load(&packages.Config{
Mode: packages.LoadTypes | packages.NeedTypesInfo,
}, pkgPath)
}, pkgPaths...)
if err != nil {
return nil, err
}

return parsedPkg[0], nil
return parsedPkgs, nil
}

func parseParams(args []string) (map[string]string, error) {
Expand Down
6 changes: 3 additions & 3 deletions cmd/generate-database/db/lex.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ func activeCriteria(filter []string, ignoredFilter []string) string {

// Return the code for a "dest" function, to be passed as parameter to
// selectObjects in order to scan a single row.
func destFunc(slice string, typ string, fields []*Field) string {
func destFunc(slice string, entity string, importType string, fields []*Field) string {
var builder strings.Builder
writeLine := func(line string) { builder.WriteString(fmt.Sprintf("%s\n", line)) }

writeLine(`func(scan func(dest ...any) error) error {`)

varName := lex.Minuscule(string(typ[0]))
writeLine(fmt.Sprintf("%s := %s{}", varName, typ))
varName := lex.Minuscule(string(entity[0]))
writeLine(fmt.Sprintf("%s := %s{}", varName, importType))

checkErr := func() {
writeLine("if err != nil {\nreturn err\n}")
Expand Down
36 changes: 29 additions & 7 deletions cmd/generate-database/db/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ import (

// Mapping holds information for mapping database tables to a Go structure.
type Mapping struct {
Package string // Package of the Go struct
Name string // Name of the Go struct.
Fields []*Field // Metadata about the Go struct.
Filterable bool // Whether the Go struct has a Filter companion struct for filtering queries.
Filters []*Field // Metadata about the Go struct used for filter fields.
Type TableType // Type of table structure for this Go struct.
Local bool // Whether the entity is in the same package as the generated code.
FilterLocal bool // Whether the entity is in the same package as the generated code.
Package string // Package of the Go struct
Name string // Name of the Go struct.
Fields []*Field // Metadata about the Go struct.
Filterable bool // Whether the Go struct has a Filter companion struct for filtering queries.
Filters []*Field // Metadata about the Go struct used for filter fields.
Type TableType // Type of table structure for this Go struct.
}

// TableType represents the logical type of the table defined by the Go struct.
Expand Down Expand Up @@ -250,6 +252,26 @@ func (m *Mapping) FieldParamsMarshal(fields []*Field) string {
return strings.Join(args, ", ")
}

// ImportType returns the type of the entity for the mapping, prefixing the import package if necessary.
func (m *Mapping) ImportType() string {
name := lex.PascalCase(m.Name)
if m.Local {
return name
}

return m.Package + "." + lex.PascalCase(name)
}

// ImportFilterType returns the Filter type of the entity for the mapping, prefixing the import package if necessary.
func (m *Mapping) ImportFilterType() string {
name := lex.PascalCase(entityFilter(m.Name))
if m.FilterLocal {
return name
}

return m.Package + "." + name
}

// Field holds all information about a field in a Go struct that is relevant
// for database code generation.
type Field struct {
Expand Down Expand Up @@ -434,7 +456,7 @@ func (f *Field) JoinClause(mapping *Mapping, table string) (string, error) {
// to select the ID to insert into this table.
// - If a 'joinon' tag is present, but this table is not among the conditions, then the join will be considered indirect,
// and an empty string will be returned.
func (f *Field) InsertColumn(pkg *types.Package, mapping *Mapping, primaryTable string, defs map[*ast.Ident]types.Object, registeredSQLStmts map[string]string) (string, string, error) {
func (f *Field) InsertColumn(mapping *Mapping, primaryTable string, defs map[*ast.Ident]types.Object, registeredSQLStmts map[string]string) (string, string, error) {
var column string
var value string
var err error
Expand Down
Loading