Skip to content

Commit f6f0a47

Browse files
author
Daniel Rice
authored
Merge pull request #76 from danrice-square/fix_sql_injection
Prevent a SQL injection
2 parents afa27bf + 033350b commit f6f0a47

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

ast.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"bytes"
1313
"fmt"
1414
"io"
15+
"strings"
1516
)
1617

1718
// Instructions for creating new types: If a type needs to satisfy an
@@ -594,7 +595,7 @@ type SimpleTableExpr interface {
594595
func (*TableName) simpleTableExpr() {}
595596
func (*Subquery) simpleTableExpr() {}
596597

597-
// TableName represents a table name.
598+
// TableName represents a table name.
598599
type TableName struct {
599600
Name, Qualifier string
600601
}
@@ -1113,8 +1114,10 @@ type ColName struct {
11131114
}
11141115

11151116
var (
1116-
astBackquote = []byte("`")
1117-
astPeriod = []byte(".")
1117+
astBackquoteStr = "`"
1118+
astDoubleBackquoteStr = "``"
1119+
astBackquote = []byte(astBackquoteStr)
1120+
astPeriod = []byte(".")
11181121
)
11191122

11201123
func (node *ColName) Serialize(w Writer) error {
@@ -1129,12 +1132,13 @@ func (node *ColName) Serialize(w Writer) error {
11291132
return quoteName(w, node.Name)
11301133
}
11311134

1132-
// note: quoteName does not escape s. quoteName is indirectly
1135+
// note: quoteName escapes any backquote (`) characters in s. quoteName is indirectly
11331136
// called by builder.go, which checks that column/table names exist.
11341137
func quoteName(w io.Writer, s string) error {
11351138
if _, err := w.Write(astBackquote); err != nil {
11361139
return err
11371140
}
1141+
s = strings.ReplaceAll(s, astBackquoteStr, astDoubleBackquoteStr)
11381142
if _, err := io.WriteString(w, s); err != nil {
11391143
return err
11401144
}

table.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,15 @@ func (t *Table) aliasOrName() string {
355355
return t.Name
356356
}
357357

358+
// Enclose a name in backquotes and escape any internal backquotes.
359+
func quoteNameStr(s string) string {
360+
return astBackquoteStr + strings.ReplaceAll(s, astBackquoteStr, astDoubleBackquoteStr) + astBackquoteStr
361+
}
362+
358363
// loadColumns loads a table's columns from a database. MySQL
359364
// specific.
360365
func (t *Table) loadColumns(db *sql.DB) error {
361-
rows, err := db.Query("SHOW FULL COLUMNS FROM " + t.Name)
366+
rows, err := db.Query("SHOW FULL COLUMNS FROM " + quoteNameStr(t.Name))
362367
if err != nil {
363368
return err
364369
}
@@ -450,7 +455,7 @@ func (t *Table) columnCount(name string) int {
450455
// loadKeys loads a table's keys (indexes) from a database. MySQL
451456
// specific.
452457
func (t *Table) loadKeys(db *sql.DB) error {
453-
rows, err := db.Query("SHOW INDEX FROM " + t.Name)
458+
rows, err := db.Query("SHOW INDEX FROM " + quoteNameStr(t.Name))
454459
if err != nil {
455460
return err
456461
}

table_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,37 @@ func TestLoadTable(t *testing.T) {
3636
fmt.Printf("%s\n", table)
3737
}
3838

39+
func TestLoadTableNameInjection(t *testing.T) {
40+
db := makeTestDB(t, objectsDDL)
41+
defer db.Close()
42+
43+
// Ensure the table name is quoted to avoid possible SQL injection.
44+
table, err := LoadTable(db.DB, "objects WHERE false")
45+
if table != nil {
46+
t.Fatalf("Expected nil table returned from injection attempt, got %v", table)
47+
}
48+
expectedError := "Error 1146: Table 'squalor_test.objects where false' doesn't exist"
49+
if err == nil {
50+
t.Fatalf("Expected error %q from injection attempt, got nil", expectedError)
51+
}
52+
if err.Error() != expectedError {
53+
t.Fatalf("Expected error %q from injection attempt, got %q", expectedError, err.Error())
54+
}
55+
56+
// Ensure the table name is quoted to avoid possible SQL injection.
57+
table, err = LoadTable(db.DB, "foo`;bar")
58+
if table != nil {
59+
t.Fatalf("Expected nil table returned from injection attempt, got %v", table)
60+
}
61+
expectedError = "Error 1146: Table 'squalor_test.foo`;bar' doesn't exist"
62+
if err == nil {
63+
t.Fatalf("Expected error %q from injection attempt, got nil", expectedError)
64+
}
65+
if err.Error() != expectedError {
66+
t.Fatalf("Expected error %q from injection attempt, got %q", expectedError, err.Error())
67+
}
68+
}
69+
3970
func TestGetKey(t *testing.T) {
4071
table := mustLoadTable(t, "objects")
4172

0 commit comments

Comments
 (0)