Skip to content

Commit 561c650

Browse files
committed
Move SQL escape functions to the Postgres package
These are specific to Postgres. Add tests and remove unused functions.
1 parent c7f5e99 commit 561c650

File tree

5 files changed

+44
-68
lines changed

5 files changed

+44
-68
lines changed

internal/pgbouncer/postgres.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ func sqlAuthenticationQuery(sqlFunctionName string) string {
4141
// No replicators.
4242
`NOT pg_authid.rolreplication`,
4343
// Not the PgBouncer role itself.
44-
`pg_authid.rolname <> ` + util.SQLQuoteLiteral(postgresqlUser),
44+
`pg_authid.rolname <> ` + postgres.QuoteLiteral(postgresqlUser),
4545
// Those without a password expiration or an expiration in the future.
4646
`(pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)`,
4747
}, "\n AND ")
4848

4949
return strings.TrimSpace(`
5050
CREATE OR REPLACE FUNCTION ` + sqlFunctionName + `(username TEXT)
51-
RETURNS TABLE(username TEXT, password TEXT) AS ` + util.SQLQuoteLiteral(`
51+
RETURNS TABLE(username TEXT, password TEXT) AS ` + postgres.QuoteLiteral(`
5252
SELECT rolname::TEXT, rolpassword::TEXT
5353
FROM pg_catalog.pg_authid
5454
WHERE pg_authid.rolname = $1

internal/pgbouncer/postgres_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ import (
1919
func TestSQLAuthenticationQuery(t *testing.T) {
2020
assert.Equal(t, sqlAuthenticationQuery("some.fn_name"),
2121
`CREATE OR REPLACE FUNCTION some.fn_name(username TEXT)
22-
RETURNS TABLE(username TEXT, password TEXT) AS '
22+
RETURNS TABLE(username TEXT, password TEXT) AS E'
2323
SELECT rolname::TEXT, rolpassword::TEXT
2424
FROM pg_catalog.pg_authid
2525
WHERE pg_authid.rolname = $1
2626
AND pg_authid.rolcanlogin
2727
AND NOT pg_authid.rolsuper
2828
AND NOT pg_authid.rolreplication
29-
AND pg_authid.rolname <> ''_crunchypgbouncer''
29+
AND pg_authid.rolname <> E''_crunchypgbouncer''
3030
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
3131
LANGUAGE SQL STABLE SECURITY DEFINER;`)
3232
}
@@ -150,14 +150,14 @@ REVOKE ALL PRIVILEGES
150150
GRANT USAGE
151151
ON SCHEMA :"namespace" TO :"username";
152152
CREATE OR REPLACE FUNCTION :"namespace".get_auth(username TEXT)
153-
RETURNS TABLE(username TEXT, password TEXT) AS '
153+
RETURNS TABLE(username TEXT, password TEXT) AS E'
154154
SELECT rolname::TEXT, rolpassword::TEXT
155155
FROM pg_catalog.pg_authid
156156
WHERE pg_authid.rolname = $1
157157
AND pg_authid.rolcanlogin
158158
AND NOT pg_authid.rolsuper
159159
AND NOT pg_authid.rolreplication
160-
AND pg_authid.rolname <> ''_crunchypgbouncer''
160+
AND pg_authid.rolname <> E''_crunchypgbouncer''
161161
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
162162
LANGUAGE SQL STABLE SECURITY DEFINER;
163163
REVOKE ALL PRIVILEGES

internal/postgres/sql.go

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
package postgres
6+
7+
import "strings"
8+
9+
// escapeLiteral is called by QuoteLiteral to add backslashes before special
10+
// characters of the "escape" string syntax. Double quote marks to escape them
11+
// regardless of the "backslash_quote" parameter.
12+
var escapeLiteral = strings.NewReplacer(`'`, `''`, `\`, `\\`).Replace
13+
14+
// QuoteLiteral escapes v so it can be safely used as a literal (or constant)
15+
// in an SQL statement.
16+
func QuoteLiteral(v string) string {
17+
// Use the "escape" syntax to ensure that backslashes behave consistently regardless
18+
// of the "standard_conforming_strings" parameter. Include a space before so
19+
// the "E" cannot change the meaning of an adjacent SQL keyword or identifier.
20+
// - https://www.postgresql.org/docs/current/sql-syntax-lexical.html
21+
return ` E'` + escapeLiteral(v) + `'`
22+
}

internal/postgres/sql_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
package postgres
6+
7+
import (
8+
"testing"
9+
10+
"gotest.tools/v3/assert"
11+
)
12+
13+
func TestQuoteLiteral(t *testing.T) {
14+
assert.Equal(t, QuoteLiteral(``), ` E''`)
15+
assert.Equal(t, QuoteLiteral(`ab"cd\ef'gh`), ` E'ab"cd\\ef''gh'`)
16+
}

internal/util/util.go

-62
This file was deleted.

0 commit comments

Comments
 (0)