Skip to content

Commit ef13370

Browse files
committed
support postgres data query
1 parent 04c6d83 commit ef13370

File tree

6 files changed

+155
-63
lines changed

6 files changed

+155
-63
lines changed

e2e/compose.yaml

+26-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ services:
99
condition: service_healthy
1010
greptimedb:
1111
condition: service_healthy
12+
tdengine:
13+
condition: service_healthy
14+
postgres:
15+
condition: service_healthy
1216
extension:
1317
condition: service_started
1418
volumes:
@@ -18,6 +22,8 @@ services:
1822
links:
1923
- mysql
2024
- greptimedb
25+
- tdengine
26+
- postgres
2127
extension:
2228
build:
2329
context: ..
@@ -38,28 +44,43 @@ services:
3844
test: ["CMD", "bash", "-c", "cat < /dev/null > /dev/tcp/127.0.0.1/3306"]
3945
interval: 3s
4046
timeout: 180s
41-
retries: 30
47+
retries: 60
4248
# ports:
4349
# - "3306:3306"
4450

4551
tdengine:
4652
image: ghcr.io/linuxsuren/tdengine/tdengine:3.3.3.0
47-
environment:
48-
TAOS_ROOT_PASSWORD: "root"
53+
healthcheck:
54+
test: ["CMD", "bash", "-c", "cat < /dev/null > /dev/tcp/tdengine/6041"]
55+
interval: 3s
56+
timeout: 180s
57+
retries: 30
4958
# ports:
5059
# - "6030:6030" # REST API port
5160
# - "6031:6031" # client port
5261
# - "6041:6041" # cluster port
5362

5463
greptimedb:
5564
image: ghcr.io/linuxsuren/greptime/greptimedb:v0.12.0
56-
command: standalone start
65+
command: standalone start --mysql-addr=0.0.0.0:4002
5766
healthcheck:
58-
test: ["CMD", "curl", "-f", "http://localhost:4000/health"]
67+
test: ["CMD", "bash", "-c", "cat < /dev/null > /dev/tcp/greptimedb/4002"]
5968
interval: 10s
6069
timeout: 5s
6170
retries: 3
6271
# ports:
6372
# - "4002:4002"
73+
74+
postgres:
75+
image: ghcr.io/linuxsuren/library/postgres:16.0
76+
environment:
77+
POSTGRES_USER: root
78+
POSTGRES_PASSWORD: root
79+
POSTGRES_DB: atest
80+
healthcheck:
81+
test: ["CMD", "bash", "-c", "cat < /dev/null > /dev/tcp/127.0.0.1/5432"]
82+
interval: 3s
83+
timeout: 30s
84+
retries: 10
6485
volumes:
6586
cache:

e2e/entrypoint.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ echo "start to run testing: $cmd"
1313
kind=orm target=mysql:3306 driver=mysql $cmd
1414

1515
kind=orm target=mysql driver=mysql atest run -p testing-data-query.yaml
16-
kind=orm target=greptimedb:4002 driver=mysql atest run -p testing-data-query.yaml
16+
kind=orm target=greptimedb:4002 driver=greptime dbname=public atest run -p testing-data-query.yaml
17+
kind=orm target=tdengine:6041 driver=tdengine password=taosdata dbname=information_schema atest run -p testing-data-query.yaml
18+
kind=orm target=postgres driver=postgres atest run -p testing-data-query.yaml
1719

1820
cat /root/.config/atest/stores.yaml

e2e/testing-data-query.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ items:
5555
X-Store-Name: "{{.param.store}}"
5656
body: |
5757
{
58-
"sql": "show tables",
58+
"sql": "",
5959
"key": ""
6060
}

pkg/data_query.go

+100-43
Original file line numberDiff line numberDiff line change
@@ -21,70 +21,41 @@ import (
2121
"fmt"
2222
"github.com/linuxsuren/api-testing/pkg/server"
2323
"gorm.io/gorm"
24+
"log"
2425
"reflect"
2526
"sort"
27+
"strings"
2628
"time"
2729
)
2830

2931
func (s *dbserver) Query(ctx context.Context, query *server.DataQuery) (result *server.DataQueryResult, err error) {
3032
var db *gorm.DB
31-
if db, err = s.getClientWithDatabase(ctx, query.Key); err != nil {
33+
var dbQuery DataQuery
34+
if dbQuery, err = s.getClientWithDatabase(ctx, query.Key); err != nil {
3235
return
3336
}
3437

38+
db = dbQuery.GetClient()
39+
3540
result = &server.DataQueryResult{
3641
Data: []*server.Pair{},
3742
Items: make([]*server.Pairs, 0),
3843
Meta: &server.DataMeta{},
3944
}
4045

4146
// query database and tables
42-
var databaseResult *server.DataQueryResult
43-
if databaseResult, err = sqlQuery(ctx, queryDatabaseSql, db); err == nil {
44-
for _, table := range databaseResult.Items {
45-
for _, item := range table.GetData() {
46-
if item.Key == "Database" || item.Key == "name" {
47-
var found bool
48-
for _, name := range result.Meta.Databases {
49-
if name == item.Value {
50-
found = true
51-
}
52-
}
53-
if !found {
54-
result.Meta.Databases = append(result.Meta.Databases, item.Value)
55-
}
56-
}
57-
}
58-
}
59-
sort.Strings(result.Meta.Databases)
47+
if result.Meta.Databases, err = dbQuery.GetDatabases(ctx); err != nil {
48+
log.Printf("failed to query databases: %v\n", err)
6049
}
6150

62-
var row *sql.Row
63-
if row = db.Raw("SELECT DATABASE() as name").Row(); row != nil {
64-
_ = row.Scan(&result.Meta.CurrentDatabase)
65-
} else {
66-
result.Meta.CurrentDatabase = query.Key
51+
if result.Meta.CurrentDatabase = query.Key; query.Key == "" {
52+
if result.Meta.CurrentDatabase, err = dbQuery.GetCurrentDatabase(); err != nil {
53+
log.Printf("failed to query current database: %v\n", err)
54+
}
6755
}
6856

69-
queryTableSql := "show tables"
70-
var tableResult *server.DataQueryResult
71-
if tableResult, err = sqlQuery(ctx, queryTableSql, db); err == nil {
72-
for _, table := range tableResult.Items {
73-
for _, item := range table.GetData() {
74-
if item.Key == fmt.Sprintf("Tables_in_%s", result.Meta.CurrentDatabase) || item.Key == "table_name" || item.Key == "Tables" {
75-
var found bool
76-
for _, name := range result.Meta.Tables {
77-
if name == item.Value {
78-
found = true
79-
}
80-
}
81-
if !found {
82-
result.Meta.Tables = append(result.Meta.Tables, item.Value)
83-
}
84-
}
85-
}
86-
}
87-
sort.Strings(result.Meta.Tables)
57+
if result.Meta.Tables, err = dbQuery.GetTables(ctx, result.Meta.CurrentDatabase); err != nil {
58+
log.Printf("failed to query tables: %v\n", err)
8859
}
8960

9061
// query data
@@ -187,3 +158,89 @@ func sqlQuery(ctx context.Context, sql string, db *gorm.DB) (result *server.Data
187158
}
188159

189160
const queryDatabaseSql = "show databases"
161+
162+
type DataQuery interface {
163+
GetDatabases(context.Context) (databases []string, err error)
164+
GetTables(ctx context.Context, currentDatabase string) (tables []string, err error)
165+
GetCurrentDatabase() (string, error)
166+
GetClient() *gorm.DB
167+
}
168+
169+
type commonDataQuery struct {
170+
showDatabases, showTables, currentDatabase string
171+
db *gorm.DB
172+
}
173+
174+
var _ DataQuery = &commonDataQuery{}
175+
176+
func NewCommonDataQuery(showDatabases, showTables, currentDatabase string, db *gorm.DB) DataQuery {
177+
return &commonDataQuery{
178+
showDatabases: showDatabases,
179+
showTables: showTables,
180+
currentDatabase: currentDatabase,
181+
db: db,
182+
}
183+
}
184+
185+
func (q *commonDataQuery) GetDatabases(ctx context.Context) (databases []string, err error) {
186+
var databaseResult *server.DataQueryResult
187+
if databaseResult, err = sqlQuery(ctx, q.showDatabases, q.db); err == nil {
188+
for _, table := range databaseResult.Items {
189+
for _, item := range table.GetData() {
190+
if item.Key == "Database" || item.Key == "name" {
191+
var found bool
192+
for _, name := range databases {
193+
if name == item.Value {
194+
found = true
195+
}
196+
}
197+
if !found {
198+
databases = append(databases, item.Value)
199+
}
200+
}
201+
}
202+
}
203+
sort.Strings(databases)
204+
}
205+
return
206+
}
207+
208+
func (q *commonDataQuery) GetTables(ctx context.Context, currentDatabase string) (tables []string, err error) {
209+
showTables := q.showTables
210+
if strings.Contains(showTables, "%s") {
211+
showTables = fmt.Sprintf(showTables, currentDatabase)
212+
}
213+
214+
var tableResult *server.DataQueryResult
215+
if tableResult, err = sqlQuery(ctx, showTables, q.db); err == nil {
216+
for _, table := range tableResult.Items {
217+
for _, item := range table.GetData() {
218+
if item.Key == fmt.Sprintf("Tables_in_%s", currentDatabase) || item.Key == "table_name" ||
219+
item.Key == "Tables" || item.Key == "tablename" {
220+
var found bool
221+
for _, name := range tables {
222+
if name == item.Value {
223+
found = true
224+
}
225+
}
226+
if !found {
227+
tables = append(tables, item.Value)
228+
}
229+
}
230+
}
231+
}
232+
sort.Strings(tables)
233+
}
234+
return
235+
}
236+
func (q *commonDataQuery) GetCurrentDatabase() (current string, err error) {
237+
var row *sql.Row
238+
if row = q.db.Raw(q.currentDatabase).Row(); row != nil {
239+
err = row.Scan(&current)
240+
}
241+
return
242+
}
243+
244+
func (q *commonDataQuery) GetClient() *gorm.DB {
245+
return q.db
246+
}

pkg/server.go

+23-10
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func createDB(user, password, address, database, driver string) (db *gorm.DB, er
9999
var dbCache = make(map[string]*gorm.DB)
100100
var dbNameCache = make(map[string]string)
101101

102-
func (s *dbserver) getClientWithDatabase(ctx context.Context, dbName string) (db *gorm.DB, err error) {
102+
func (s *dbserver) getClientWithDatabase(ctx context.Context, dbName string) (dbQuery DataQuery, err error) {
103103
store := remote.GetStoreFromContext(ctx)
104104
if store == nil {
105105
err = errors.New("no connect to database")
@@ -118,19 +118,32 @@ func (s *dbserver) getClientWithDatabase(ctx context.Context, dbName string) (db
118118
log.Printf("get client from driver[%s] in database [%s]", driver, database)
119119

120120
var ok bool
121-
if db, ok = dbCache[store.Name]; ok && db != nil && dbNameCache[store.Name] == database {
122-
return
121+
var db *gorm.DB
122+
if db, ok = dbCache[store.Name]; (ok && db != nil && dbNameCache[store.Name] != database) || !ok {
123+
if db, err = createDB(store.Username, store.Password, store.URL, database, driver); err == nil {
124+
dbCache[store.Name] = db
125+
dbNameCache[store.Name] = database
126+
} else {
127+
return
128+
}
123129
}
124130

125-
if db, err = createDB(store.Username, store.Password, store.URL, database, driver); err == nil {
126-
dbCache[store.Name] = db
127-
dbNameCache[store.Name] = database
131+
switch driver {
132+
case "postgres":
133+
dbQuery = NewCommonDataQuery("select table_catalog as name from information_schema.tables",
134+
`SELECT table_name FROM information_schema.tables WHERE table_catalog = '%s' and table_schema != 'pg_catalog' and table_schema != 'information_schema'`, "SELECT current_database() as name", db)
135+
default:
136+
dbQuery = NewCommonDataQuery("show databases", "show tables", "SELECT DATABASE() as name", db)
128137
}
129138
}
130139
return
131140
}
141+
132142
func (s *dbserver) getClient(ctx context.Context) (db *gorm.DB, err error) {
133-
db, err = s.getClientWithDatabase(ctx, "")
143+
var dbQuery DataQuery
144+
if dbQuery, err = s.getClientWithDatabase(ctx, ""); err == nil {
145+
db = dbQuery.GetClient()
146+
}
134147
return
135148
}
136149

@@ -478,9 +491,9 @@ func (s *dbserver) Verify(ctx context.Context, in *server.Empty) (reply *server.
478491
}
479492

480493
var vErr error
481-
var db *gorm.DB
482-
if db, err = s.getClient(ctx); err == nil {
483-
_, vErr = db.ConnPool.QueryContext(ctx, queryDatabaseSql)
494+
var dbQuery DataQuery
495+
if dbQuery, err = s.getClientWithDatabase(ctx, ""); err == nil {
496+
_, vErr = dbQuery.GetDatabases(ctx)
484497
}
485498

486499
reply.Ready = vErr == nil

pkg/server_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ func TestNewRemoteServer(t *testing.T) {
8585
})
8686

8787
t.Run("Verify", func(t *testing.T) {
88-
reply, err := remoteServer.Verify(defaultCtx, nil)
89-
assert.NoError(t, err)
90-
assert.False(t, reply.Ready)
88+
_, err := remoteServer.Verify(defaultCtx, nil)
89+
assert.Error(t, err)
9190
})
9291

9392
t.Run("CreateTestCaseHistory", func(t *testing.T) {

0 commit comments

Comments
 (0)