Skip to content

Commit 02b8e06

Browse files
authored
🐛 fix logic of parsing multiple columns (i.e. for PRIMARY KEYS, CONSTRAINT) (#193)
* ✨ impl parseAllColumns * 🐛 replacing getAllColumns with parseAllColumns (cherry picked from commit 0c7e33b) * 🐛 (parseAllColumns)fix for []quoted cases
1 parent b29e7fc commit 02b8e06

File tree

3 files changed

+177
-20
lines changed

3 files changed

+177
-20
lines changed

ddlmod.go

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,11 @@ var (
1717
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
1818
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
1919
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
20-
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
2120
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
2221
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
2322
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
2423
)
2524

26-
func getAllColumns(s string) []string {
27-
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
28-
columns := make([]string, 0, len(allMatches))
29-
for _, matches := range allMatches {
30-
if len(matches) > 1 {
31-
columns = append(columns, matches[1])
32-
}
33-
}
34-
return columns
35-
}
36-
3725
type ddl struct {
3826
head string
3927
fields []string
@@ -110,9 +98,10 @@ func parseDDL(strs ...string) (*ddl, error) {
11098
if strings.HasPrefix(fUpper, "CONSTRAINT") {
11199
matches := uniqueRegexp.FindStringSubmatch(f)
112100
if len(matches) > 0 {
113-
if columns := getAllColumns(matches[1]); len(columns) == 1 {
101+
cols, err := parseAllColumns(matches[1])
102+
if err == nil && len(cols) == 1 {
114103
for idx, column := range result.columns {
115-
if column.NameValue.String == columns[0] {
104+
if column.NameValue.String == cols[0] {
116105
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
117106
result.columns[idx] = column
118107
break
@@ -123,12 +112,15 @@ func parseDDL(strs ...string) (*ddl, error) {
123112
continue
124113
}
125114
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
126-
for _, name := range getAllColumns(f) {
127-
for idx, column := range result.columns {
128-
if column.NameValue.String == name {
129-
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
130-
result.columns[idx] = column
131-
break
115+
cols, err := parseAllColumns(f)
116+
if err == nil {
117+
for _, name := range cols {
118+
for idx, column := range result.columns {
119+
if column.NameValue.String == name {
120+
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
121+
result.columns[idx] = column
122+
break
123+
}
132124
}
133125
}
134126
}

ddlmod_parse_all_columns.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package sqlite
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
)
7+
8+
type parseAllColumnsState int
9+
10+
const (
11+
parseAllColumnsState_NONE parseAllColumnsState = iota
12+
parseAllColumnsState_Beginning
13+
parseAllColumnsState_ReadingRawName
14+
parseAllColumnsState_ReadingQuotedName
15+
parseAllColumnsState_EndOfName
16+
parseAllColumnsState_State_End
17+
)
18+
19+
func parseAllColumns(in string) ([]string, error) {
20+
s := []rune(in)
21+
columns := make([]string, 0)
22+
state := parseAllColumnsState_NONE
23+
quote := rune(0)
24+
name := make([]rune, 0)
25+
for i := 0; i < len(s); i++ {
26+
switch state {
27+
case parseAllColumnsState_NONE:
28+
if s[i] == '(' {
29+
state = parseAllColumnsState_Beginning
30+
}
31+
case parseAllColumnsState_Beginning:
32+
if isSpace(s[i]) {
33+
continue
34+
}
35+
if isQuote(s[i]) {
36+
state = parseAllColumnsState_ReadingQuotedName
37+
quote = s[i]
38+
continue
39+
}
40+
if s[i] == '[' {
41+
state = parseAllColumnsState_ReadingQuotedName
42+
quote = ']'
43+
continue
44+
} else if s[i] == ')' {
45+
return columns, fmt.Errorf("unexpected token: %s", string(s[i]))
46+
}
47+
state = parseAllColumnsState_ReadingRawName
48+
name = append(name, s[i])
49+
case parseAllColumnsState_ReadingRawName:
50+
if isSeparator(s[i]) {
51+
state = parseAllColumnsState_Beginning
52+
columns = append(columns, string(name))
53+
name = make([]rune, 0)
54+
continue
55+
}
56+
if s[i] == ')' {
57+
state = parseAllColumnsState_State_End
58+
columns = append(columns, string(name))
59+
}
60+
if isQuote(s[i]) {
61+
return nil, fmt.Errorf("unexpected token: %s", string(s[i]))
62+
}
63+
if isSpace(s[i]) {
64+
state = parseAllColumnsState_EndOfName
65+
columns = append(columns, string(name))
66+
name = make([]rune, 0)
67+
continue
68+
}
69+
name = append(name, s[i])
70+
case parseAllColumnsState_ReadingQuotedName:
71+
if s[i] == quote {
72+
// check if quote character is escaped
73+
if i+1 < len(s) && s[i+1] == quote {
74+
name = append(name, quote)
75+
i++
76+
continue
77+
}
78+
state = parseAllColumnsState_EndOfName
79+
columns = append(columns, string(name))
80+
name = make([]rune, 0)
81+
continue
82+
}
83+
name = append(name, s[i])
84+
case parseAllColumnsState_EndOfName:
85+
if isSpace(s[i]) {
86+
continue
87+
}
88+
if isSeparator(s[i]) {
89+
state = parseAllColumnsState_Beginning
90+
continue
91+
}
92+
if s[i] == ')' {
93+
state = parseAllColumnsState_State_End
94+
continue
95+
}
96+
return nil, fmt.Errorf("unexpected token: %s", string(s[i]))
97+
case parseAllColumnsState_State_End:
98+
break
99+
}
100+
}
101+
if state != parseAllColumnsState_State_End {
102+
return nil, errors.New("unexpected end")
103+
}
104+
return columns, nil
105+
}
106+
107+
func isSpace(r rune) bool {
108+
return r == ' ' || r == '\t'
109+
}
110+
111+
func isQuote(r rune) bool {
112+
return r == '`' || r == '"' || r == '\''
113+
}
114+
115+
func isSeparator(r rune) bool {
116+
return r == ','
117+
}

ddlmod_parse_all_columns_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package sqlite
2+
3+
import "testing"
4+
5+
func TestParseAllColumns(t *testing.T) {
6+
tc := []struct {
7+
name string
8+
input string
9+
expected []string
10+
}{
11+
{
12+
name: "Simple case",
13+
input: "PRIMARY KEY (column1, column2)",
14+
expected: []string{"column1", "column2"},
15+
},
16+
{
17+
name: "Quoted column name",
18+
input: "PRIMARY KEY (`column,xxx`, \"column 2\", \"column)3\", 'column''4', \"column\"\"5\")",
19+
expected: []string{"column,xxx", "column 2", "column)3", "column'4", "column\"5"},
20+
},
21+
{
22+
name: "Japanese column name",
23+
input: "PRIMARY KEY (カラム1, `カラム2`)",
24+
expected: []string{"カラム1", "カラム2"},
25+
},
26+
{
27+
name: "Column name quoted with []",
28+
input: "PRIMARY KEY ([column1], [column2])",
29+
expected: []string{"column1", "column2"},
30+
},
31+
}
32+
for _, tt := range tc {
33+
t.Run(tt.name, func(t *testing.T) {
34+
cols, err := parseAllColumns(tt.input)
35+
if err != nil {
36+
t.Errorf("Failed to parse columns: %s", err)
37+
}
38+
if len(cols) != len(tt.expected) {
39+
t.Errorf("Expected %d columns, got %d", len(tt.expected), len(cols))
40+
}
41+
for i, col := range cols {
42+
if col != tt.expected[i] {
43+
t.Errorf("Expected %s, got %s", tt.expected[i], col)
44+
}
45+
}
46+
})
47+
}
48+
}

0 commit comments

Comments
 (0)