@@ -22,8 +22,8 @@ import (
22
22
"github.com/hashicorp/go-multierror"
23
23
"github.com/jackc/pgconn"
24
24
"github.com/jackc/pgerrcode"
25
+ "github.com/jackc/pgx/v4"
25
26
_ "github.com/jackc/pgx/v4/stdlib"
26
- "github.com/lib/pq"
27
27
)
28
28
29
29
const (
@@ -69,27 +69,26 @@ type Config struct {
69
69
70
70
type Postgres struct {
71
71
// Locking and unlocking need to use the same connection
72
- conn * sql.Conn
73
- db * sql.DB
72
+ conn * pgx.Conn
74
73
isLocked atomic.Bool
75
74
76
75
// Open and WithInstance need to guarantee that config is never nil
77
76
config * Config
78
77
}
79
78
80
- func WithInstance (instance * sql. DB , config * Config ) (database.Driver , error ) {
79
+ func WithInstance (instance * pgx. Conn , config * Config ) (database.Driver , error ) {
81
80
if config == nil {
82
81
return nil , ErrNilConfig
83
82
}
84
83
85
- if err := instance .Ping (); err != nil {
84
+ if err := instance .Ping (context . Background () ); err != nil {
86
85
return nil , err
87
86
}
88
87
89
88
if config .DatabaseName == "" {
90
89
query := `SELECT CURRENT_DATABASE()`
91
90
var databaseName string
92
- if err := instance .QueryRow (query ).Scan (& databaseName ); err != nil {
91
+ if err := instance .QueryRow (context . Background (), query ).Scan (& databaseName ); err != nil {
93
92
return nil , & database.Error {OrigErr : err , Query : []byte (query )}
94
93
}
95
94
@@ -103,7 +102,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103
102
if config .SchemaName == "" {
104
103
query := `SELECT CURRENT_SCHEMA()`
105
104
var schemaName string
106
- if err := instance .QueryRow (query ).Scan (& schemaName ); err != nil {
105
+ if err := instance .QueryRow (context . Background (), query ).Scan (& schemaName ); err != nil {
107
106
return nil , & database.Error {OrigErr : err , Query : []byte (query )}
108
107
}
109
108
@@ -139,15 +138,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
139
138
}
140
139
}
141
140
142
- conn , err := instance .Conn (context .Background ())
143
-
144
- if err != nil {
145
- return nil , err
146
- }
147
-
148
141
px := & Postgres {
149
- conn : conn ,
150
- db : instance ,
142
+ conn : instance ,
151
143
config : config ,
152
144
}
153
145
@@ -173,7 +165,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
173
165
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
174
166
purl .Scheme = "postgres"
175
167
176
- db , err := sql . Open ( "pgx/v4" , migrate .FilterCustomQuery (purl ).String ())
168
+ db , err := pgx . Connect ( context . Background () , migrate .FilterCustomQuery (purl ).String ())
177
169
if err != nil {
178
170
return nil , err
179
171
}
@@ -240,10 +232,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
240
232
}
241
233
242
234
func (p * Postgres ) Close () error {
243
- connErr := p .conn .Close ()
244
- dbErr := p .db .Close ()
245
- if connErr != nil || dbErr != nil {
246
- return fmt .Errorf ("conn: %v, db: %v" , connErr , dbErr )
235
+ connErr := p .conn .Close (context .Background ())
236
+ if connErr != nil {
237
+ return fmt .Errorf ("conn: %w" , connErr )
247
238
}
248
239
return nil
249
240
}
@@ -283,19 +274,19 @@ func (p *Postgres) applyAdvisoryLock() error {
283
274
284
275
// This will wait indefinitely until the lock can be acquired.
285
276
query := `SELECT pg_advisory_lock($1)`
286
- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
277
+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
287
278
return & database.Error {OrigErr : err , Err : "try lock failed" , Query : []byte (query )}
288
279
}
289
280
return nil
290
281
}
291
282
292
283
func (p * Postgres ) applyTableLock () error {
293
- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
284
+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
294
285
if err != nil {
295
286
return & database.Error {OrigErr : err , Err : "transaction start failed" }
296
287
}
297
288
defer func () {
298
- errRollback := tx .Rollback ()
289
+ errRollback := tx .Rollback (context . Background () )
299
290
if errRollback != nil {
300
291
err = multierror .Append (err , errRollback )
301
292
}
@@ -306,30 +297,25 @@ func (p *Postgres) applyTableLock() error {
306
297
return err
307
298
}
308
299
309
- query := "SELECT * FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
310
- rows , err := tx .Query (query , aid )
300
+ query := "SELECT * FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
301
+ rows , err := tx .Query (context . Background (), query , aid )
311
302
if err != nil {
312
303
return database.Error {OrigErr : err , Err : "failed to fetch migration lock" , Query : []byte (query )}
313
304
}
314
-
315
- defer func () {
316
- if errClose := rows .Close (); errClose != nil {
317
- err = multierror .Append (err , errClose )
318
- }
319
- }()
305
+ defer rows .Close ()
320
306
321
307
// If row exists at all, lock is present
322
308
locked := rows .Next ()
323
309
if locked {
324
310
return database .ErrLocked
325
311
}
326
312
327
- query = "INSERT INTO " + pq . QuoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
328
- if _ , err := tx .Exec (query , aid ); err != nil {
313
+ query = "INSERT INTO " + quoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
314
+ if _ , err := tx .Exec (context . Background (), query , aid ); err != nil {
329
315
return database.Error {OrigErr : err , Err : "failed to set migration lock" , Query : []byte (query )}
330
316
}
331
317
332
- return tx .Commit ()
318
+ return tx .Commit (context . Background () )
333
319
}
334
320
335
321
func (p * Postgres ) releaseAdvisoryLock () error {
@@ -339,7 +325,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
339
325
}
340
326
341
327
query := `SELECT pg_advisory_unlock($1)`
342
- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
328
+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
343
329
return & database.Error {OrigErr : err , Query : []byte (query )}
344
330
}
345
331
@@ -352,8 +338,8 @@ func (p *Postgres) releaseTableLock() error {
352
338
return err
353
339
}
354
340
355
- query := "DELETE FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
356
- if _ , err := p .db .Exec (query , aid ); err != nil {
341
+ query := "DELETE FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
342
+ if _ , err := p .conn .Exec (context . Background (), query , aid ); err != nil {
357
343
return database.Error {OrigErr : err , Err : "failed to release migration lock" , Query : []byte (query )}
358
344
}
359
345
@@ -391,7 +377,7 @@ func (p *Postgres) runStatement(statement []byte) error {
391
377
if strings .TrimSpace (query ) == "" {
392
378
return nil
393
379
}
394
- if _ , err := p .conn .ExecContext (ctx , query ); err != nil {
380
+ if _ , err := p .conn .Exec (ctx , query ); err != nil {
395
381
396
382
if pgErr , ok := err .(* pgconn.PgError ); ok {
397
383
var line uint
@@ -448,14 +434,14 @@ func runesLastIndex(input []rune, target rune) int {
448
434
}
449
435
450
436
func (p * Postgres ) SetVersion (version int , dirty bool ) error {
451
- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
437
+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
452
438
if err != nil {
453
439
return & database.Error {OrigErr : err , Err : "transaction start failed" }
454
440
}
455
441
456
442
query := `TRUNCATE ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName )
457
- if _ , err := tx .Exec (query ); err != nil {
458
- if errRollback := tx .Rollback (); errRollback != nil {
443
+ if _ , err := tx .Exec (context . Background (), query ); err != nil {
444
+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
459
445
err = multierror .Append (err , errRollback )
460
446
}
461
447
return & database.Error {OrigErr : err , Query : []byte (query )}
@@ -466,15 +452,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
466
452
// See: https://github.com/golang-migrate/migrate/issues/330
467
453
if version >= 0 || (version == database .NilVersion && dirty ) {
468
454
query = `INSERT INTO ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version, dirty) VALUES ($1, $2)`
469
- if _ , err := tx .Exec (query , version , dirty ); err != nil {
470
- if errRollback := tx .Rollback (); errRollback != nil {
455
+ if _ , err := tx .Exec (context . Background (), query , version , dirty ); err != nil {
456
+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
471
457
err = multierror .Append (err , errRollback )
472
458
}
473
459
return & database.Error {OrigErr : err , Query : []byte (query )}
474
460
}
475
461
}
476
462
477
- if err := tx .Commit (); err != nil {
463
+ if err := tx .Commit (context . Background () ); err != nil {
478
464
return & database.Error {OrigErr : err , Err : "transaction commit failed" }
479
465
}
480
466
@@ -483,7 +469,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
483
469
484
470
func (p * Postgres ) Version () (version int , dirty bool , err error ) {
485
471
query := `SELECT version, dirty FROM ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` LIMIT 1`
486
- err = p .conn .QueryRowContext (context .Background (), query ).Scan (& version , & dirty )
472
+ err = p .conn .QueryRow (context .Background (), query ).Scan (& version , & dirty )
487
473
switch {
488
474
case err == sql .ErrNoRows :
489
475
return database .NilVersion , false , nil
@@ -504,15 +490,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
504
490
func (p * Postgres ) Drop () (err error ) {
505
491
// select all tables in current schema
506
492
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
507
- tables , err := p .conn .QueryContext (context .Background (), query )
493
+ tables , err := p .conn .Query (context .Background (), query )
508
494
if err != nil {
509
495
return & database.Error {OrigErr : err , Query : []byte (query )}
510
496
}
511
- defer func () {
512
- if errClose := tables .Close (); errClose != nil {
513
- err = multierror .Append (err , errClose )
514
- }
515
- }()
497
+ defer tables .Close ()
516
498
517
499
// delete one table after another
518
500
tableNames := make ([]string , 0 )
@@ -539,7 +521,7 @@ func (p *Postgres) Drop() (err error) {
539
521
// delete one by one ...
540
522
for _ , t := range tableNames {
541
523
query = `DROP TABLE IF EXISTS ` + quoteIdentifier (t ) + ` CASCADE`
542
- if _ , err := p .conn .ExecContext (context .Background (), query ); err != nil {
524
+ if _ , err := p .conn .Exec (context .Background (), query ); err != nil {
543
525
return & database.Error {OrigErr : err , Query : []byte (query )}
544
526
}
545
527
}
@@ -571,7 +553,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
571
553
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
572
554
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
573
555
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
574
- row := p .conn .QueryRowContext (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
556
+ row := p .conn .QueryRow (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
575
557
576
558
var count int
577
559
err = row .Scan (& count )
@@ -584,7 +566,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
584
566
}
585
567
586
568
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version bigint not null primary key, dirty boolean not null)`
587
- if _ , err = p .conn .ExecContext (context .Background (), query ); err != nil {
569
+ if _ , err = p .conn .Exec (context .Background (), query ); err != nil {
588
570
return & database.Error {OrigErr : err , Query : []byte (query )}
589
571
}
590
572
@@ -598,15 +580,15 @@ func (p *Postgres) ensureLockTable() error {
598
580
599
581
var count int
600
582
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
601
- if err := p .db .QueryRow (query , p .config .LockTable ).Scan (& count ); err != nil {
583
+ if err := p .conn .QueryRow (context . Background (), query , p .config .LockTable ).Scan (& count ); err != nil {
602
584
return & database.Error {OrigErr : err , Query : []byte (query )}
603
585
}
604
586
if count == 1 {
605
587
return nil
606
588
}
607
589
608
- query = `CREATE TABLE ` + pq . QuoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
609
- if _ , err := p .db .Exec (query ); err != nil {
590
+ query = `CREATE TABLE ` + quoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
591
+ if _ , err := p .conn .Exec (context . Background (), query ); err != nil {
610
592
return & database.Error {OrigErr : err , Query : []byte (query )}
611
593
}
612
594
@@ -615,9 +597,5 @@ func (p *Postgres) ensureLockTable() error {
615
597
616
598
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
617
599
func quoteIdentifier (name string ) string {
618
- end := strings .IndexRune (name , 0 )
619
- if end > - 1 {
620
- name = name [:end ]
621
- }
622
- return `"` + strings .Replace (name , `"` , `""` , - 1 ) + `"`
600
+ return pgx .Identifier ([]string {name }).Sanitize ()
623
601
}
0 commit comments