Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit 71947cf

Browse files
authored
Fix bug on insert where (#1436)
* fix bug on insert where * fix bug * fix lint
1 parent 691f6e7 commit 71947cf

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

session_insert_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,15 @@ func TestInsertWhere(t *testing.T) {
908908
assert.True(t, has)
909909
assert.EqualValues(t, "trest3", j3.Name)
910910
assert.EqualValues(t, 3, j3.Index)
911+
912+
inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
913+
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
914+
Insert(map[string]interface{}{
915+
"repo_id": 1,
916+
"name": "10';delete * from insert_where; --",
917+
})
918+
assert.NoError(t, err)
919+
assert.EqualValues(t, 1, inserted)
911920
}
912921

913922
type NightlyRate struct {

statement_args.go

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,60 @@ package xorm
66

77
import (
88
"fmt"
9+
"reflect"
10+
"strings"
11+
"time"
912

1013
"xorm.io/builder"
1114
"xorm.io/core"
1215
)
1316

17+
func quoteNeeded(a interface{}) bool {
18+
switch a.(type) {
19+
case int, int8, int16, int32, int64:
20+
return false
21+
case uint, uint8, uint16, uint32, uint64:
22+
return false
23+
case float32, float64:
24+
return false
25+
case bool:
26+
return false
27+
case string:
28+
return true
29+
case time.Time, *time.Time:
30+
return true
31+
case builder.Builder, *builder.Builder:
32+
return false
33+
}
34+
35+
t := reflect.TypeOf(a)
36+
switch t.Kind() {
37+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
38+
return false
39+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
40+
return false
41+
case reflect.Float32, reflect.Float64:
42+
return false
43+
case reflect.Bool:
44+
return false
45+
case reflect.String:
46+
return true
47+
}
48+
49+
return true
50+
}
51+
52+
func convertArg(arg interface{}) string {
53+
if quoteNeeded(arg) {
54+
argv := fmt.Sprintf("%v", arg)
55+
return "'" + strings.Replace(argv, "'", "''", -1) + "'"
56+
}
57+
58+
return fmt.Sprintf("%v", arg)
59+
}
60+
1461
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
1562
switch argv := arg.(type) {
16-
case string:
17-
if _, err := w.WriteString("'" + argv + "'"); err != nil {
18-
return err
19-
}
2063
case bool:
2164
if statement.Engine.dialect.DBType() == core.MSSQL {
2265
if argv {
@@ -50,7 +93,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
5093
return err
5194
}
5295
default:
53-
if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil {
96+
if _, err := w.WriteString(convertArg(arg)); err != nil {
5497
return err
5598
}
5699
}

0 commit comments

Comments
 (0)