Skip to content

Commit 3f322af

Browse files
authored
Merge pull request #3047 from actiontech/fix_generate_rollback_panic
Improves rollback SQL generation
2 parents 139b2a3 + 39ac7bc commit 3f322af

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

sqle/driver/mysql/rollback.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/actiontech/sqle/sqle/errors"
1111

1212
"github.com/pingcap/parser/ast"
13+
"github.com/pingcap/parser/format"
1314
_model "github.com/pingcap/parser/model"
1415
)
1516

@@ -415,7 +416,7 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
415416
for n, name := range columnsName {
416417
_, isPk := pkColumnsName[name]
417418
if isPk {
418-
where = append(where, fmt.Sprintf("%s = '%s'", name, util.ExprFormat(value[n])))
419+
where = append(where, fmt.Sprintf("%s = '%s'", name, restore(value[n])))
419420
}
420421
}
421422
if len(where) != len(pkColumnsName) {
@@ -437,7 +438,7 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
437438
name := setExpr.Column.Name.String()
438439
_, isPk := pkColumnsName[name]
439440
if isPk {
440-
where = append(where, fmt.Sprintf("%s = '%s'", name, util.ExprFormat(setExpr.Expr)))
441+
where = append(where, fmt.Sprintf("%s = '%s'", name, restore(setExpr.Expr)))
441442
}
442443
}
443444
if len(where) != len(pkColumnsName) {
@@ -449,6 +450,18 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
449450
return rollbackSql, "", nil
450451
}
451452

453+
// 还原抽象语法树节点至SQL
454+
func restore(node ast.Node) (sql string) {
455+
var buf strings.Builder
456+
rc := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf)
457+
458+
if err := node.Restore(rc); err != nil {
459+
return
460+
}
461+
sql = buf.String()
462+
return
463+
}
464+
452465
// generateDeleteRollbackSql generate insert SQL for delete.
453466
func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) {
454467
// not support multi-table syntax
@@ -603,7 +616,7 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
603616
colChanged = true
604617
if isPk {
605618
isPkChanged = true
606-
pkValue = util.ExprFormat(l.Expr)
619+
pkValue = restore(l.Expr)
607620
}
608621
}
609622
}
@@ -682,12 +695,12 @@ func (i *MysqlDriverImpl) generateGetRecordsSql(expr string, tableName *ast.Tabl
682695
recordSql = fmt.Sprintf("%s AS %s", recordSql, tableAlias)
683696
}
684697
if where != nil {
685-
recordSql = fmt.Sprintf("%s WHERE %s", recordSql, util.ExprFormat(where))
698+
recordSql = fmt.Sprintf("%s WHERE %s", recordSql, restore(where))
686699
}
687700
if order != nil {
688701
recordSql = fmt.Sprintf("%s ORDER BY", recordSql)
689702
for _, item := range order.Items {
690-
recordSql = fmt.Sprintf("%s %s", recordSql, util.ExprFormat(item.Expr))
703+
recordSql = fmt.Sprintf("%s %s", recordSql, restore(item.Expr))
691704
if item.Desc {
692705
recordSql = fmt.Sprintf("%s DESC", recordSql)
693706
}

0 commit comments

Comments
 (0)