Skip to content

Commit 9e765e5

Browse files
committed
pgx: don't use database/sql interface
it'd be nice to have contexts for many of these methods, but that'd be a much wider change
1 parent c378583 commit 9e765e5

File tree

4 files changed

+115
-169
lines changed

4 files changed

+115
-169
lines changed

database/pgx/pgx.go

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ import (
2222
"github.com/hashicorp/go-multierror"
2323
"github.com/jackc/pgconn"
2424
"github.com/jackc/pgerrcode"
25+
"github.com/jackc/pgx/v4"
2526
_ "github.com/jackc/pgx/v4/stdlib"
26-
"github.com/lib/pq"
2727
)
2828

2929
const (
@@ -69,27 +69,26 @@ type Config struct {
6969

7070
type Postgres struct {
7171
// Locking and unlocking need to use the same connection
72-
conn *sql.Conn
73-
db *sql.DB
72+
conn *pgx.Conn
7473
isLocked atomic.Bool
7574

7675
// Open and WithInstance need to guarantee that config is never nil
7776
config *Config
7877
}
7978

80-
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
79+
func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) {
8180
if config == nil {
8281
return nil, ErrNilConfig
8382
}
8483

85-
if err := instance.Ping(); err != nil {
84+
if err := instance.Ping(context.Background()); err != nil {
8685
return nil, err
8786
}
8887

8988
if config.DatabaseName == "" {
9089
query := `SELECT CURRENT_DATABASE()`
9190
var databaseName string
92-
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
91+
if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil {
9392
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
9493
}
9594

@@ -103,7 +102,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103102
if config.SchemaName == "" {
104103
query := `SELECT CURRENT_SCHEMA()`
105104
var schemaName string
106-
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
105+
if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil {
107106
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
108107
}
109108

@@ -139,15 +138,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
139138
}
140139
}
141140

142-
conn, err := instance.Conn(context.Background())
143-
144-
if err != nil {
145-
return nil, err
146-
}
147-
148141
px := &Postgres{
149-
conn: conn,
150-
db: instance,
142+
conn: instance,
151143
config: config,
152144
}
153145

@@ -173,7 +165,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
173165
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
174166
purl.Scheme = "postgres"
175167

176-
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
168+
db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String())
177169
if err != nil {
178170
return nil, err
179171
}
@@ -240,10 +232,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
240232
}
241233

242234
func (p *Postgres) Close() error {
243-
connErr := p.conn.Close()
244-
dbErr := p.db.Close()
245-
if connErr != nil || dbErr != nil {
246-
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
235+
connErr := p.conn.Close(context.Background())
236+
if connErr != nil {
237+
return fmt.Errorf("conn: %w", connErr)
247238
}
248239
return nil
249240
}
@@ -283,19 +274,19 @@ func (p *Postgres) applyAdvisoryLock() error {
283274

284275
// This will wait indefinitely until the lock can be acquired.
285276
query := `SELECT pg_advisory_lock($1)`
286-
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
277+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
287278
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
288279
}
289280
return nil
290281
}
291282

292283
func (p *Postgres) applyTableLock() error {
293-
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
284+
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
294285
if err != nil {
295286
return &database.Error{OrigErr: err, Err: "transaction start failed"}
296287
}
297288
defer func() {
298-
errRollback := tx.Rollback()
289+
errRollback := tx.Rollback(context.Background())
299290
if errRollback != nil {
300291
err = multierror.Append(err, errRollback)
301292
}
@@ -306,30 +297,25 @@ func (p *Postgres) applyTableLock() error {
306297
return err
307298
}
308299

309-
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
310-
rows, err := tx.Query(query, aid)
300+
query := "SELECT * FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
301+
rows, err := tx.Query(context.Background(), query, aid)
311302
if err != nil {
312303
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
313304
}
314-
315-
defer func() {
316-
if errClose := rows.Close(); errClose != nil {
317-
err = multierror.Append(err, errClose)
318-
}
319-
}()
305+
defer rows.Close()
320306

321307
// If row exists at all, lock is present
322308
locked := rows.Next()
323309
if locked {
324310
return database.ErrLocked
325311
}
326312

327-
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
328-
if _, err := tx.Exec(query, aid); err != nil {
313+
query = "INSERT INTO " + quoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
314+
if _, err := tx.Exec(context.Background(), query, aid); err != nil {
329315
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
330316
}
331317

332-
return tx.Commit()
318+
return tx.Commit(context.Background())
333319
}
334320

335321
func (p *Postgres) releaseAdvisoryLock() error {
@@ -339,7 +325,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
339325
}
340326

341327
query := `SELECT pg_advisory_unlock($1)`
342-
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
328+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
343329
return &database.Error{OrigErr: err, Query: []byte(query)}
344330
}
345331

@@ -352,8 +338,8 @@ func (p *Postgres) releaseTableLock() error {
352338
return err
353339
}
354340

355-
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
356-
if _, err := p.db.Exec(query, aid); err != nil {
341+
query := "DELETE FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
342+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
357343
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
358344
}
359345

@@ -391,7 +377,7 @@ func (p *Postgres) runStatement(statement []byte) error {
391377
if strings.TrimSpace(query) == "" {
392378
return nil
393379
}
394-
if _, err := p.conn.ExecContext(ctx, query); err != nil {
380+
if _, err := p.conn.Exec(ctx, query); err != nil {
395381

396382
if pgErr, ok := err.(*pgconn.PgError); ok {
397383
var line uint
@@ -448,14 +434,14 @@ func runesLastIndex(input []rune, target rune) int {
448434
}
449435

450436
func (p *Postgres) SetVersion(version int, dirty bool) error {
451-
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
437+
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
452438
if err != nil {
453439
return &database.Error{OrigErr: err, Err: "transaction start failed"}
454440
}
455441

456442
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
457-
if _, err := tx.Exec(query); err != nil {
458-
if errRollback := tx.Rollback(); errRollback != nil {
443+
if _, err := tx.Exec(context.Background(), query); err != nil {
444+
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
459445
err = multierror.Append(err, errRollback)
460446
}
461447
return &database.Error{OrigErr: err, Query: []byte(query)}
@@ -466,15 +452,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
466452
// See: https://github.com/golang-migrate/migrate/issues/330
467453
if version >= 0 || (version == database.NilVersion && dirty) {
468454
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
469-
if _, err := tx.Exec(query, version, dirty); err != nil {
470-
if errRollback := tx.Rollback(); errRollback != nil {
455+
if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil {
456+
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
471457
err = multierror.Append(err, errRollback)
472458
}
473459
return &database.Error{OrigErr: err, Query: []byte(query)}
474460
}
475461
}
476462

477-
if err := tx.Commit(); err != nil {
463+
if err := tx.Commit(context.Background()); err != nil {
478464
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
479465
}
480466

@@ -483,7 +469,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
483469

484470
func (p *Postgres) Version() (version int, dirty bool, err error) {
485471
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
486-
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
472+
err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty)
487473
switch {
488474
case err == sql.ErrNoRows:
489475
return database.NilVersion, false, nil
@@ -504,15 +490,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
504490
func (p *Postgres) Drop() (err error) {
505491
// select all tables in current schema
506492
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
507-
tables, err := p.conn.QueryContext(context.Background(), query)
493+
tables, err := p.conn.Query(context.Background(), query)
508494
if err != nil {
509495
return &database.Error{OrigErr: err, Query: []byte(query)}
510496
}
511-
defer func() {
512-
if errClose := tables.Close(); errClose != nil {
513-
err = multierror.Append(err, errClose)
514-
}
515-
}()
497+
defer tables.Close()
516498

517499
// delete one table after another
518500
tableNames := make([]string, 0)
@@ -539,7 +521,7 @@ func (p *Postgres) Drop() (err error) {
539521
// delete one by one ...
540522
for _, t := range tableNames {
541523
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
542-
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
524+
if _, err := p.conn.Exec(context.Background(), query); err != nil {
543525
return &database.Error{OrigErr: err, Query: []byte(query)}
544526
}
545527
}
@@ -571,7 +553,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
571553
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
572554
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
573555
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
574-
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
556+
row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
575557

576558
var count int
577559
err = row.Scan(&count)
@@ -584,7 +566,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
584566
}
585567

586568
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
587-
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
569+
if _, err = p.conn.Exec(context.Background(), query); err != nil {
588570
return &database.Error{OrigErr: err, Query: []byte(query)}
589571
}
590572

@@ -598,15 +580,15 @@ func (p *Postgres) ensureLockTable() error {
598580

599581
var count int
600582
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
601-
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
583+
if err := p.conn.QueryRow(context.Background(), query, p.config.LockTable).Scan(&count); err != nil {
602584
return &database.Error{OrigErr: err, Query: []byte(query)}
603585
}
604586
if count == 1 {
605587
return nil
606588
}
607589

608-
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
609-
if _, err := p.db.Exec(query); err != nil {
590+
query = `CREATE TABLE ` + quoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
591+
if _, err := p.conn.Exec(context.Background(), query); err != nil {
610592
return &database.Error{OrigErr: err, Query: []byte(query)}
611593
}
612594

@@ -615,9 +597,5 @@ func (p *Postgres) ensureLockTable() error {
615597

616598
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
617599
func quoteIdentifier(name string) string {
618-
end := strings.IndexRune(name, 0)
619-
if end > -1 {
620-
name = name[:end]
621-
}
622-
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
600+
return pgx.Identifier([]string{name}).Sanitize()
623601
}

0 commit comments

Comments
 (0)