diff --git a/session_insert_test.go b/session_insert_test.go index 88879ef65..d040c9e93 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -908,6 +908,15 @@ func TestInsertWhere(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "trest3", j3.Name) assert.EqualValues(t, 3, j3.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]interface{}{ + "repo_id": 1, + "name": "10';delete * from insert_where; --", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) } type NightlyRate struct { diff --git a/statement_args.go b/statement_args.go index 4ce336f48..23496443f 100644 --- a/statement_args.go +++ b/statement_args.go @@ -6,17 +6,60 @@ package xorm import ( "fmt" + "reflect" + "strings" + "time" "xorm.io/builder" "xorm.io/core" ) +func quoteNeeded(a interface{}) bool { + switch a.(type) { + case int, int8, int16, int32, int64: + return false + case uint, uint8, uint16, uint32, uint64: + return false + case float32, float64: + return false + case bool: + return false + case string: + return true + case time.Time, *time.Time: + return true + case builder.Builder, *builder.Builder: + return false + } + + t := reflect.TypeOf(a) + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return false + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return false + case reflect.Float32, reflect.Float64: + return false + case reflect.Bool: + return false + case reflect.String: + return true + } + + return true +} + +func convertArg(arg interface{}) string { + if quoteNeeded(arg) { + argv := fmt.Sprintf("%v", arg) + return "'" + strings.Replace(argv, "'", "''", -1) + "'" + } + + return fmt.Sprintf("%v", arg) +} + func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { - case string: - if _, err := w.WriteString("'" + argv + "'"); err != nil { - return err - } case bool: if statement.Engine.dialect.DBType() == core.MSSQL { if argv { @@ -50,7 +93,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return err } default: - if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { + if _, err := w.WriteString(convertArg(arg)); err != nil { return err } }