Skip to content

Commit 9e90938

Browse files
authored
Merge pull request #1745 from breml/generate-database-non-tx
generate-database: Handle non tx DB connections
2 parents a099032 + adc6787 commit 9e90938

15 files changed

+268
-104
lines changed

cmd/generate-database/db/constants.go

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ var Imports = []string{
88
"database/sql",
99
"fmt",
1010
"strings",
11+
"github.com/mattn/go-sqlite3",
1112
}

cmd/generate-database/db/method.go

+66-24
Original file line numberDiff line numberDiff line change
@@ -717,18 +717,8 @@ func (m *Method) create(buf *file.Buffer, replace bool) error {
717717
}
718718

719719
kind := "create"
720-
if mapping.Type != AssociationTable {
721-
if replace {
722-
kind = "create_or_replace"
723-
} else {
724-
buf.L("// Check if a %s with the same key exists.", m.entity)
725-
buf.L("exists, err := %sExists(ctx, db, %s)", lex.PascalCase(m.entity), strings.Join(nkParams, ", "))
726-
m.ifErrNotNil(buf, true, "-1", "fmt.Errorf(\"Failed to check for duplicates: %w\", err)")
727-
buf.L("if exists {")
728-
buf.L(` return -1, ErrConflict`)
729-
buf.L("}")
730-
buf.N()
731-
}
720+
if mapping.Type != AssociationTable && replace {
721+
kind = "create_or_replace"
732722
}
733723

734724
if mapping.Type == AssociationTable {
@@ -762,15 +752,29 @@ func (m *Method) create(buf *file.Buffer, replace bool) error {
762752

763753
if mapping.Type == AssociationTable {
764754
m.ifErrNotNil(buf, true, fmt.Sprintf(`fmt.Errorf("Failed to get \"%s\" prepared statement: %%w", err)`, stmtCodeVar(m.entity, kind)))
765-
buf.L("// Execute the statement. ")
766-
buf.L("_, err = stmt.Exec(args...)")
755+
buf.L(`// Execute the statement.`)
756+
buf.L(`_, err = stmt.Exec(args...)`)
757+
buf.L(`var sqliteErr sqlite3.Error`)
758+
buf.L(`if errors.As(err, &sqliteErr) {`)
759+
buf.L(` if sqliteErr.Code == sqlite3.ErrConstraint {`)
760+
buf.L(` return ErrConflict`)
761+
buf.L(` }`)
762+
buf.L(`}`)
763+
buf.N()
767764
m.ifErrNotNil(buf, true, fmt.Sprintf(`fmt.Errorf("Failed to create \"%s\" entry: %%w", err)`, entityTable(m.entity, m.config["table"])))
768765
} else {
769766
m.ifErrNotNil(buf, true, "-1", fmt.Sprintf(`fmt.Errorf("Failed to get \"%s\" prepared statement: %%w", err)`, stmtCodeVar(m.entity, kind)))
770-
buf.L("// Execute the statement. ")
771-
buf.L("result, err := stmt.Exec(args...)")
767+
buf.L(`// Execute the statement.`)
768+
buf.L(`result, err := stmt.Exec(args...)`)
769+
buf.L(`var sqliteErr sqlite3.Error`)
770+
buf.L(`if errors.As(err, &sqliteErr) {`)
771+
buf.L(` if sqliteErr.Code == sqlite3.ErrConstraint {`)
772+
buf.L(` return -1, ErrConflict`)
773+
buf.L(` }`)
774+
buf.L(`}`)
775+
buf.N()
772776
m.ifErrNotNil(buf, true, "-1", fmt.Sprintf(`fmt.Errorf("Failed to create \"%s\" entry: %%w", err)`, entityTable(m.entity, m.config["table"])))
773-
buf.L("id, err := result.LastInsertId()")
777+
buf.L(`id, err := result.LastInsertId()`)
774778
m.ifErrNotNil(buf, true, "-1", fmt.Sprintf(`fmt.Errorf("Failed to fetch \"%s\" entry ID: %%w", err)`, entityTable(m.entity, m.config["table"])))
775779
}
776780
}
@@ -1387,15 +1391,18 @@ func (m *Method) signature(buf *file.Buffer, isInterface bool) error {
13871391
}
13881392
}
13891393

1390-
return m.begin(buf, comment, args, rets, isInterface)
1391-
}
1394+
m.begin(buf, mapping, comment, args, rets, isInterface)
13921395

1393-
func (m *Method) begin(buf *file.Buffer, comment string, args string, rets string, isInterface bool) error {
1394-
mapping, err := Parse(m.pkg, lex.PascalCase(m.entity), m.kind)
1395-
if err != nil {
1396-
return fmt.Errorf("Parse entity struct: %w", err)
1396+
if isInterface {
1397+
return nil
13971398
}
13981399

1400+
m.sqlTxCheck(buf, mapping)
1401+
1402+
return nil
1403+
}
1404+
1405+
func (m *Method) begin(buf *file.Buffer, mapping *Mapping, comment string, args string, rets string, isInterface bool) {
13991406
name := ""
14001407
entity := lex.PascalCase(m.entity)
14011408

@@ -1471,8 +1478,43 @@ func (m *Method) begin(buf *file.Buffer, comment string, args string, rets strin
14711478
buf.L("}()")
14721479
buf.N()
14731480
}
1481+
}
1482+
1483+
func (m *Method) sqlTxCheck(buf *file.Buffer, mapping *Mapping) {
1484+
txCheck := false
1485+
rets := []string{}
1486+
1487+
switch operation(m.kind) {
1488+
case "GetMany":
1489+
if mapping.Type != EntityTable || len(mapping.RefFields()) > 0 {
1490+
rets = []string{"nil"}
1491+
txCheck = true
1492+
}
1493+
1494+
case "Create":
1495+
if mapping.Type == AssociationTable ||
1496+
mapping.Type == ReferenceTable ||
1497+
len(mapping.RefFields()) > 0 ||
1498+
m.ref != "" {
1499+
txCheck = true
1500+
}
1501+
1502+
case "Update":
1503+
if mapping.Type != EntityTable {
1504+
txCheck = true
1505+
}
1506+
}
14741507

1475-
return nil
1508+
if !txCheck {
1509+
return
1510+
}
1511+
1512+
rets = append(rets, `fmt.Errorf("Committable DB connection (transaction) required")`)
1513+
buf.L(`_, ok := db.(interface{ Commit() error })`)
1514+
buf.L(`if !ok {`)
1515+
buf.L(` return %s`, strings.Join(rets, ", "))
1516+
buf.L(`}`)
1517+
buf.N()
14761518
}
14771519

14781520
func (m *Method) ifErrNotNil(buf *file.Buffer, newLine bool, rets ...string) {

internal/server/db/cluster/certificate_projects.mapper.go

+25
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/server/db/cluster/certificates.mapper.go

+8-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/server/db/cluster/cluster_groups.mapper.go

+14-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/server/db/cluster/config.mapper.go

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/server/db/cluster/devices.mapper.go

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)