Skip to content

Commit 6b399a5

Browse files
committed
add inner sql support to remove dialect
1 parent 457a6c0 commit 6b399a5

File tree

3 files changed

+91
-17
lines changed

3 files changed

+91
-17
lines changed

pkg/data_query.go

+15-10
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ func (s *dbserver) Query(ctx context.Context, query *server.DataQuery) (result *
6464
return
6565
}
6666

67+
query.Sql = dbQuery.GetInnerSQL().ToNativeSQL(query.Sql)
6768
result.Meta.Labels = dbQuery.GetLabels(ctx, query.Sql)
6869

6970
var dataResult *server.DataQueryResult
@@ -170,27 +171,27 @@ type DataQuery interface {
170171
GetCurrentDatabase() (string, error)
171172
GetLabels(context.Context, string) map[string]string
172173
GetClient() *gorm.DB
174+
GetInnerSQL() InnerSQL
173175
}
174176

175177
type commonDataQuery struct {
176-
showDatabases, showTables, currentDatabase string
177-
db *gorm.DB
178+
showTables string
179+
db *gorm.DB
180+
innerSQL InnerSQL
178181
}
179182

180183
var _ DataQuery = &commonDataQuery{}
181184

182-
func NewCommonDataQuery(showDatabases, showTables, currentDatabase string, db *gorm.DB) DataQuery {
185+
func NewCommonDataQuery(innerSQL InnerSQL, db *gorm.DB) DataQuery {
183186
return &commonDataQuery{
184-
showDatabases: showDatabases,
185-
showTables: showTables,
186-
currentDatabase: currentDatabase,
187-
db: db,
187+
innerSQL: innerSQL,
188+
db: db,
188189
}
189190
}
190191

191192
func (q *commonDataQuery) GetDatabases(ctx context.Context) (databases []string, err error) {
192193
var databaseResult *server.DataQueryResult
193-
if databaseResult, err = sqlQuery(ctx, q.showDatabases, q.db); err == nil {
194+
if databaseResult, err = sqlQuery(ctx, q.GetInnerSQL().ToNativeSQL(innerShowDatabases), q.db); err == nil {
194195
for _, table := range databaseResult.Items {
195196
for _, item := range table.GetData() {
196197
if item.Key == "Database" || item.Key == "name" {
@@ -212,7 +213,7 @@ func (q *commonDataQuery) GetDatabases(ctx context.Context) (databases []string,
212213
}
213214

214215
func (q *commonDataQuery) GetTables(ctx context.Context, currentDatabase string) (tables []string, err error) {
215-
showTables := q.showTables
216+
showTables := q.GetInnerSQL().ToNativeSQL(innerShowTables)
216217
if strings.Contains(showTables, "%s") {
217218
showTables = fmt.Sprintf(showTables, currentDatabase)
218219
}
@@ -242,7 +243,7 @@ func (q *commonDataQuery) GetTables(ctx context.Context, currentDatabase string)
242243

243244
func (q *commonDataQuery) GetCurrentDatabase() (current string, err error) {
244245
var row *sql.Row
245-
if row = q.db.Raw(q.currentDatabase).Row(); row != nil {
246+
if row = q.db.Raw(q.GetInnerSQL().ToNativeSQL(innerCurrentDB)).Row(); row != nil {
246247
err = row.Scan(&current)
247248
}
248249
return
@@ -266,3 +267,7 @@ func (q *commonDataQuery) GetLabels(ctx context.Context, sql string) (metadata m
266267
func (q *commonDataQuery) GetClient() *gorm.DB {
267268
return q.db
268269
}
270+
271+
func (q *commonDataQuery) GetInnerSQL() InnerSQL {
272+
return q.innerSQL
273+
}

pkg/inner_sql.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
Copyright 2025 API Testing Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package pkg
18+
19+
import "strings"
20+
21+
type InnerSQL interface {
22+
ToNativeSQL(query string) string
23+
}
24+
25+
const (
26+
innerSelectTable_ = "@selectTable_"
27+
innerShowDatabases = "@showDatabases"
28+
innerShowTables = "@showTables"
29+
innerCurrentDB = "@currentDB"
30+
)
31+
32+
func GetInnerSQL(dialect string) InnerSQL {
33+
switch dialect {
34+
case "postgres":
35+
return &postgresDialect{}
36+
default:
37+
return &mysqlDialect{}
38+
}
39+
}
40+
41+
type mysqlDialect struct {
42+
}
43+
44+
func (m *mysqlDialect) ToNativeSQL(query string) (sql string) {
45+
if strings.HasPrefix(query, innerSelectTable_) {
46+
sql = "SELECT * FROM " + strings.ReplaceAll(query, innerSelectTable_, "")
47+
} else if query == innerShowDatabases {
48+
sql = "SHOW DATABASES"
49+
} else if query == innerShowTables {
50+
sql = "SHOW TABLES"
51+
} else if query == innerCurrentDB {
52+
sql = "SELECT DATABASE() as name"
53+
} else {
54+
sql = query
55+
}
56+
return
57+
}
58+
59+
type postgresDialect struct {
60+
}
61+
62+
func (p *postgresDialect) ToNativeSQL(query string) (sql string) {
63+
if strings.HasPrefix(query, innerSelectTable_) {
64+
sql = `SELECT * FROM "` + strings.ReplaceAll(query, innerSelectTable_, "") + `"`
65+
} else if query == innerShowDatabases {
66+
sql = "SELECT table_catalog as name FROM information_schema.tables"
67+
} else if query == innerShowTables {
68+
sql = `SELECT table_name FROM information_schema.tables WHERE table_catalog = '%s' and table_schema != 'pg_catalog' and table_schema != 'information_schema'`
69+
} else if query == innerCurrentDB {
70+
sql = "SELECT current_database() as name"
71+
} else {
72+
sql = query
73+
}
74+
return
75+
}

pkg/server.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,7 @@ func (s *dbserver) getClientWithDatabase(ctx context.Context, dbName string) (db
128128
}
129129
}
130130

131-
switch driver {
132-
case "postgres":
133-
dbQuery = NewCommonDataQuery("select table_catalog as name from information_schema.tables",
134-
`SELECT table_name FROM information_schema.tables WHERE table_catalog = '%s' and table_schema != 'pg_catalog' and table_schema != 'information_schema'`, "SELECT current_database() as name", db)
135-
default:
136-
dbQuery = NewCommonDataQuery("show databases", "show tables", "SELECT DATABASE() as name", db)
137-
}
131+
dbQuery = NewCommonDataQuery(GetInnerSQL(driver), db)
138132
}
139133
return
140134
}

0 commit comments

Comments
 (0)