Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generate-database: Abstract db connection / db transaction #1721

Merged
merged 3 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/generate-database/db/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package db

// Imports is a list of the package imports every generated source file has.
var Imports = []string{
"context",
"database/sql",
"fmt",
"strings",
}
90 changes: 45 additions & 45 deletions cmd/generate-database/db/method.go

Large diffs are not rendered by default.

20 changes: 16 additions & 4 deletions cmd/generate-database/file/boilerplate/boilerplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
"fmt"
)

type dbtx interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// RegisterStmt register a SQL statement.
//
// Registered statements will be prepared upfront and re-used, to speed up
Expand Down Expand Up @@ -42,13 +49,18 @@ var stmts = map[int]string{} // Statement code to statement SQL text.
var PreparedStmts = map[int]*sql.Stmt{}

// Stmt prepares the in-memory prepared statement for the transaction.
func Stmt(tx *sql.Tx, code int) (*sql.Stmt, error) {
func Stmt(db dbtx, code int) (*sql.Stmt, error) {
stmt, ok := PreparedStmts[code]
if !ok {
return nil, fmt.Errorf("No prepared statement registered with code %d", code)
}

return tx.Stmt(stmt), nil
tx, ok := db.(*sql.Tx)
if ok {
return tx.Stmt(stmt), nil
}

return stmt, nil
}

// StmtString returns the in-memory query string with the given code.
Expand Down Expand Up @@ -138,8 +150,8 @@ func selectObjects(ctx context.Context, stmt *sql.Stmt, rowFunc dest, args ...an

// scan runs a query with inArgs and provides the rowFunc with the scan function for each row.
// It handles closing the rows and errors from the result set.
func scan(ctx context.Context, tx *sql.Tx, sqlStmt string, rowFunc dest, inArgs ...any) error {
rows, err := tx.QueryContext(ctx, sqlStmt, inArgs...)
func scan(ctx context.Context, db dbtx, sqlStmt string, rowFunc dest, inArgs ...any) error {
rows, err := db.QueryContext(ctx, sqlStmt, inArgs...)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,23 @@

package cluster

import (
"context"
"database/sql"
)
import "context"

// CertificateProjectGenerated is an interface of generated methods for CertificateProject.
type CertificateProjectGenerated interface {
// GetCertificateProjects returns all available Projects for the Certificate.
// generator: certificate_project GetMany
GetCertificateProjects(ctx context.Context, tx *sql.Tx, certificateID int) ([]Project, error)
GetCertificateProjects(ctx context.Context, db dbtx, certificateID int) ([]Project, error)

// DeleteCertificateProjects deletes the certificate_project matching the given key parameters.
// generator: certificate_project DeleteMany
DeleteCertificateProjects(ctx context.Context, tx *sql.Tx, certificateID int) error
DeleteCertificateProjects(ctx context.Context, db dbtx, certificateID int) error

// CreateCertificateProjects adds a new certificate_project to the database.
// generator: certificate_project Create
CreateCertificateProjects(ctx context.Context, tx *sql.Tx, objects []CertificateProject) error
CreateCertificateProjects(ctx context.Context, db dbtx, objects []CertificateProject) error

// UpdateCertificateProjects updates the certificate_project matching the given key parameters.
// generator: certificate_project Update
UpdateCertificateProjects(ctx context.Context, tx *sql.Tx, certificateID int, projectNames []string) error
UpdateCertificateProjects(ctx context.Context, db dbtx, certificateID int, projectNames []string) error
}
26 changes: 13 additions & 13 deletions internal/server/db/cluster/certificate_projects.mapper.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 8 additions & 9 deletions internal/server/db/cluster/certificates.interface.mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package cluster

import (
"context"
"database/sql"

"github.com/lxc/incus/v6/internal/server/certificate"
)
Expand All @@ -13,33 +12,33 @@ import (
type CertificateGenerated interface {
// GetCertificates returns all available certificates.
// generator: certificate GetMany
GetCertificates(ctx context.Context, tx *sql.Tx, filters ...CertificateFilter) ([]Certificate, error)
GetCertificates(ctx context.Context, db dbtx, filters ...CertificateFilter) ([]Certificate, error)

// GetCertificate returns the certificate with the given key.
// generator: certificate GetOne
GetCertificate(ctx context.Context, tx *sql.Tx, fingerprint string) (*Certificate, error)
GetCertificate(ctx context.Context, db dbtx, fingerprint string) (*Certificate, error)

// GetCertificateID return the ID of the certificate with the given key.
// generator: certificate ID
GetCertificateID(ctx context.Context, tx *sql.Tx, fingerprint string) (int64, error)
GetCertificateID(ctx context.Context, db dbtx, fingerprint string) (int64, error)

// CertificateExists checks if a certificate with the given key exists.
// generator: certificate Exists
CertificateExists(ctx context.Context, tx *sql.Tx, fingerprint string) (bool, error)
CertificateExists(ctx context.Context, db dbtx, fingerprint string) (bool, error)

// CreateCertificate adds a new certificate to the database.
// generator: certificate Create
CreateCertificate(ctx context.Context, tx *sql.Tx, object Certificate) (int64, error)
CreateCertificate(ctx context.Context, db dbtx, object Certificate) (int64, error)

// DeleteCertificate deletes the certificate matching the given key parameters.
// generator: certificate DeleteOne-by-Fingerprint
DeleteCertificate(ctx context.Context, tx *sql.Tx, fingerprint string) error
DeleteCertificate(ctx context.Context, db dbtx, fingerprint string) error

// DeleteCertificates deletes the certificate matching the given key parameters.
// generator: certificate DeleteMany-by-Name-and-Type
DeleteCertificates(ctx context.Context, tx *sql.Tx, name string, certificateType certificate.Type) error
DeleteCertificates(ctx context.Context, db dbtx, name string, certificateType certificate.Type) error

// UpdateCertificate updates the certificate matching the given key parameters.
// generator: certificate Update
UpdateCertificate(ctx context.Context, tx *sql.Tx, fingerprint string, object Certificate) error
UpdateCertificate(ctx context.Context, db dbtx, fingerprint string, object Certificate) error
}
Loading