@@ -10,6 +10,7 @@ import (
10
10
"github.com/actiontech/sqle/sqle/errors"
11
11
12
12
"github.com/pingcap/parser/ast"
13
+ "github.com/pingcap/parser/format"
13
14
_model "github.com/pingcap/parser/model"
14
15
)
15
16
@@ -415,7 +416,7 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
415
416
for n , name := range columnsName {
416
417
_ , isPk := pkColumnsName [name ]
417
418
if isPk {
418
- where = append (where , fmt .Sprintf ("%s = '%s'" , name , util . ExprFormat (value [n ])))
419
+ where = append (where , fmt .Sprintf ("%s = '%s'" , name , restore (value [n ])))
419
420
}
420
421
}
421
422
if len (where ) != len (pkColumnsName ) {
@@ -437,7 +438,7 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
437
438
name := setExpr .Column .Name .String ()
438
439
_ , isPk := pkColumnsName [name ]
439
440
if isPk {
440
- where = append (where , fmt .Sprintf ("%s = '%s'" , name , util . ExprFormat (setExpr .Expr )))
441
+ where = append (where , fmt .Sprintf ("%s = '%s'" , name , restore (setExpr .Expr )))
441
442
}
442
443
}
443
444
if len (where ) != len (pkColumnsName ) {
@@ -449,6 +450,18 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
449
450
return rollbackSql , "" , nil
450
451
}
451
452
453
+ // 还原抽象语法树节点至SQL
454
+ func restore (node ast.Node ) (sql string ) {
455
+ var buf strings.Builder
456
+ rc := format .NewRestoreCtx (format .DefaultRestoreFlags , & buf )
457
+
458
+ if err := node .Restore (rc ); err != nil {
459
+ return
460
+ }
461
+ sql = buf .String ()
462
+ return
463
+ }
464
+
452
465
// generateDeleteRollbackSql generate insert SQL for delete.
453
466
func (i * MysqlDriverImpl ) generateDeleteRollbackSql (stmt * ast.DeleteStmt ) (string , string , error ) {
454
467
// not support multi-table syntax
@@ -603,7 +616,7 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
603
616
colChanged = true
604
617
if isPk {
605
618
isPkChanged = true
606
- pkValue = util . ExprFormat (l .Expr )
619
+ pkValue = restore (l .Expr )
607
620
}
608
621
}
609
622
}
@@ -682,12 +695,12 @@ func (i *MysqlDriverImpl) generateGetRecordsSql(expr string, tableName *ast.Tabl
682
695
recordSql = fmt .Sprintf ("%s AS %s" , recordSql , tableAlias )
683
696
}
684
697
if where != nil {
685
- recordSql = fmt .Sprintf ("%s WHERE %s" , recordSql , util . ExprFormat (where ))
698
+ recordSql = fmt .Sprintf ("%s WHERE %s" , recordSql , restore (where ))
686
699
}
687
700
if order != nil {
688
701
recordSql = fmt .Sprintf ("%s ORDER BY" , recordSql )
689
702
for _ , item := range order .Items {
690
- recordSql = fmt .Sprintf ("%s %s" , recordSql , util . ExprFormat (item .Expr ))
703
+ recordSql = fmt .Sprintf ("%s %s" , recordSql , restore (item .Expr ))
691
704
if item .Desc {
692
705
recordSql = fmt .Sprintf ("%s DESC" , recordSql )
693
706
}
0 commit comments