@@ -717,18 +717,8 @@ func (m *Method) create(buf *file.Buffer, replace bool) error {
717
717
}
718
718
719
719
kind := "create"
720
- if mapping .Type != AssociationTable {
721
- if replace {
722
- kind = "create_or_replace"
723
- } else {
724
- buf .L ("// Check if a %s with the same key exists." , m .entity )
725
- buf .L ("exists, err := %sExists(ctx, db, %s)" , lex .PascalCase (m .entity ), strings .Join (nkParams , ", " ))
726
- m .ifErrNotNil (buf , true , "-1" , "fmt.Errorf(\" Failed to check for duplicates: %w\" , err)" )
727
- buf .L ("if exists {" )
728
- buf .L (` return -1, ErrConflict` )
729
- buf .L ("}" )
730
- buf .N ()
731
- }
720
+ if mapping .Type != AssociationTable && replace {
721
+ kind = "create_or_replace"
732
722
}
733
723
734
724
if mapping .Type == AssociationTable {
@@ -762,15 +752,29 @@ func (m *Method) create(buf *file.Buffer, replace bool) error {
762
752
763
753
if mapping .Type == AssociationTable {
764
754
m .ifErrNotNil (buf , true , fmt .Sprintf (`fmt.Errorf("Failed to get \"%s\" prepared statement: %%w", err)` , stmtCodeVar (m .entity , kind )))
765
- buf .L ("// Execute the statement. " )
766
- buf .L ("_, err = stmt.Exec(args...)" )
755
+ buf .L (`// Execute the statement.` )
756
+ buf .L (`_, err = stmt.Exec(args...)` )
757
+ buf .L (`var sqliteErr sqlite3.Error` )
758
+ buf .L (`if errors.As(err, &sqliteErr) {` )
759
+ buf .L (` if sqliteErr.Code == sqlite3.ErrConstraint {` )
760
+ buf .L (` return ErrConflict` )
761
+ buf .L (` }` )
762
+ buf .L (`}` )
763
+ buf .N ()
767
764
m .ifErrNotNil (buf , true , fmt .Sprintf (`fmt.Errorf("Failed to create \"%s\" entry: %%w", err)` , entityTable (m .entity , m .config ["table" ])))
768
765
} else {
769
766
m .ifErrNotNil (buf , true , "-1" , fmt .Sprintf (`fmt.Errorf("Failed to get \"%s\" prepared statement: %%w", err)` , stmtCodeVar (m .entity , kind )))
770
- buf .L ("// Execute the statement. " )
771
- buf .L ("result, err := stmt.Exec(args...)" )
767
+ buf .L (`// Execute the statement.` )
768
+ buf .L (`result, err := stmt.Exec(args...)` )
769
+ buf .L (`var sqliteErr sqlite3.Error` )
770
+ buf .L (`if errors.As(err, &sqliteErr) {` )
771
+ buf .L (` if sqliteErr.Code == sqlite3.ErrConstraint {` )
772
+ buf .L (` return -1, ErrConflict` )
773
+ buf .L (` }` )
774
+ buf .L (`}` )
775
+ buf .N ()
772
776
m .ifErrNotNil (buf , true , "-1" , fmt .Sprintf (`fmt.Errorf("Failed to create \"%s\" entry: %%w", err)` , entityTable (m .entity , m .config ["table" ])))
773
- buf .L (" id, err := result.LastInsertId()" )
777
+ buf .L (` id, err := result.LastInsertId()` )
774
778
m .ifErrNotNil (buf , true , "-1" , fmt .Sprintf (`fmt.Errorf("Failed to fetch \"%s\" entry ID: %%w", err)` , entityTable (m .entity , m .config ["table" ])))
775
779
}
776
780
}
@@ -1387,15 +1391,18 @@ func (m *Method) signature(buf *file.Buffer, isInterface bool) error {
1387
1391
}
1388
1392
}
1389
1393
1390
- return m .begin (buf , comment , args , rets , isInterface )
1391
- }
1394
+ m .begin (buf , mapping , comment , args , rets , isInterface )
1392
1395
1393
- func (m * Method ) begin (buf * file.Buffer , comment string , args string , rets string , isInterface bool ) error {
1394
- mapping , err := Parse (m .pkg , lex .PascalCase (m .entity ), m .kind )
1395
- if err != nil {
1396
- return fmt .Errorf ("Parse entity struct: %w" , err )
1396
+ if isInterface {
1397
+ return nil
1397
1398
}
1398
1399
1400
+ m .sqlTxCheck (buf , mapping )
1401
+
1402
+ return nil
1403
+ }
1404
+
1405
+ func (m * Method ) begin (buf * file.Buffer , mapping * Mapping , comment string , args string , rets string , isInterface bool ) {
1399
1406
name := ""
1400
1407
entity := lex .PascalCase (m .entity )
1401
1408
@@ -1471,8 +1478,43 @@ func (m *Method) begin(buf *file.Buffer, comment string, args string, rets strin
1471
1478
buf .L ("}()" )
1472
1479
buf .N ()
1473
1480
}
1481
+ }
1482
+
1483
+ func (m * Method ) sqlTxCheck (buf * file.Buffer , mapping * Mapping ) {
1484
+ txCheck := false
1485
+ rets := []string {}
1486
+
1487
+ switch operation (m .kind ) {
1488
+ case "GetMany" :
1489
+ if mapping .Type != EntityTable || len (mapping .RefFields ()) > 0 {
1490
+ rets = []string {"nil" }
1491
+ txCheck = true
1492
+ }
1493
+
1494
+ case "Create" :
1495
+ if mapping .Type == AssociationTable ||
1496
+ mapping .Type == ReferenceTable ||
1497
+ len (mapping .RefFields ()) > 0 ||
1498
+ m .ref != "" {
1499
+ txCheck = true
1500
+ }
1501
+
1502
+ case "Update" :
1503
+ if mapping .Type != EntityTable {
1504
+ txCheck = true
1505
+ }
1506
+ }
1474
1507
1475
- return nil
1508
+ if ! txCheck {
1509
+ return
1510
+ }
1511
+
1512
+ rets = append (rets , `fmt.Errorf("Committable DB connection (transaction) required")` )
1513
+ buf .L (`_, ok := db.(interface{ Commit() error })` )
1514
+ buf .L (`if !ok {` )
1515
+ buf .L (` return %s` , strings .Join (rets , ", " ))
1516
+ buf .L (`}` )
1517
+ buf .N ()
1476
1518
}
1477
1519
1478
1520
func (m * Method ) ifErrNotNil (buf * file.Buffer , newLine bool , rets ... string ) {
0 commit comments