Skip to content

Commit 87d1d3b

Browse files
committed
feat: Automatially configure connections using DNS
1 parent fd1a9f1 commit 87d1d3b

File tree

4 files changed

+249
-3
lines changed

4 files changed

+249
-3
lines changed

README.md

+57
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,63 @@ func connect() {
234234
// ... etc
235235
}
236236
```
237+
### Using DNS to identify an instance
238+
239+
The connector can be configured to use DNS to look up an instance. This would
240+
allow you to configure your application to connect to a database instance, and
241+
centrally configure which instance in your DNS zone.
242+
243+
#### Configure your DNS Records
244+
245+
Add a DNS SRV record for the Cloud SQL instance to a **private** DNS server
246+
or a private Google Cloud DNS Zone used by your application.
247+
248+
**Note:** You are strongly discouraged from adding DNS records for your
249+
Cloud SQL instances to a public DNS server. This would allow anyone on the
250+
internet to discover the Cloud SQL instance name.
251+
252+
For example: suppose you wanted to use the domain name
253+
`prod-db.mycompany.example.com` to connect to your database instance
254+
`my-project:region:my-instance`.
255+
256+
- Record type: `SRV`
257+
- Name: `prod-db.mycompany.example.com` – This is the domain name used by the application
258+
- Target: `my-project:region:my-instance` – This is the instance name
259+
- Port: `3307` – always use port 3307
260+
- Priority: `0` – always use priority 0
261+
- Weight: `1` - always use weight 1
262+
263+
#### Configure the connector
264+
265+
Configure the connector as described above, replacing the conenctor ID with
266+
the DNS name.
267+
268+
Adapting the MySQL + database/sql example above:
269+
270+
```go
271+
import (
272+
"database/sql"
273+
274+
"cloud.google.com/go/cloudsqlconn"
275+
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
276+
)
277+
278+
func connect() {
279+
cleanup, err := mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithCredentialsFile("key.json"))
280+
if err != nil {
281+
// ... handle error
282+
}
283+
// call cleanup when you're done with the database connection
284+
defer cleanup()
285+
286+
db, err := sql.Open(
287+
"cloudsql-mysql",
288+
"myuser:mypass@cloudsql-mysql(prod-db.mycompany.example.com)/mydb",
289+
)
290+
// ... etc
291+
}
292+
```
293+
237294

238295
### Using Options
239296

dialer.go

+84-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"fmt"
2525
"io"
2626
"net"
27+
"sort"
2728
"strings"
2829
"sync"
2930
"sync/atomic"
@@ -153,6 +154,10 @@ type Dialer struct {
153154

154155
// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
155156
iamTokenSource oauth2.TokenSource
157+
158+
// resolver does SRV record DNS lookups when resolving DNS name dialer
159+
// configuration.
160+
resolver NetResolver
156161
}
157162

158163
var (
@@ -253,6 +258,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
253258
if err != nil {
254259
return nil, err
255260
}
261+
var r NetResolver = net.DefaultResolver
262+
if cfg.resolver != nil {
263+
r = cfg.resolver
264+
}
265+
256266
d := &Dialer{
257267
closed: make(chan struct{}),
258268
cache: make(map[instance.ConnName]monitoredCache),
@@ -265,10 +275,25 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
265275
dialerID: uuid.New().String(),
266276
iamTokenSource: cfg.iamLoginTokenSource,
267277
dialFunc: cfg.dialFunc,
278+
resolver: r,
268279
}
269280
return d, nil
270281
}
271282

283+
func (d *Dialer) resolveInstanceName(ctx context.Context, icn string) (instance.ConnName, error) {
284+
cn, err := instance.ParseConnName(icn)
285+
if err != nil {
286+
// The connection name was not project:region:instance
287+
// Attempt to query a SRV record and see if it works instead.
288+
cn, err = d.queryDNS(ctx, icn)
289+
if err != nil {
290+
return instance.ConnName{}, err
291+
}
292+
}
293+
294+
return cn, nil
295+
}
296+
272297
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
273298
// icn argument must be the instance's connection name, which is in the format
274299
// "project-name:region:instance-name".
@@ -288,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
288313
go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
289314
endDial(err)
290315
}()
291-
cn, err := instance.ParseConnName(icn)
316+
cn, err := d.resolveInstanceName(ctx, icn)
292317
if err != nil {
293318
return nil, err
294319
}
@@ -429,7 +454,7 @@ func validClientCert(
429454
// the instance:
430455
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
431456
func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) {
432-
cn, err := instance.ParseConnName(icn)
457+
cn, err := d.resolveInstanceName(ctx, icn)
433458
if err != nil {
434459
return "", err
435460
}
@@ -449,7 +474,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
449474
// Use Warmup to start the refresh process early if you don't know when you'll
450475
// need to call "Dial".
451476
func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error {
452-
cn, err := instance.ParseConnName(icn)
477+
cn, err := d.resolveInstanceName(ctx, icn)
453478
if err != nil {
454479
return err
455480
}
@@ -565,3 +590,59 @@ func (d *Dialer) connectionInfoCache(
565590

566591
return c, nil
567592
}
593+
594+
// queryDNS attempts to resolve a SRV record for the domain name.
595+
// The DNS SRV record's target field is used as instance name.
596+
//
597+
// This handles several conditions where the DNS records may be missing or
598+
// invalid:
599+
// - The domain name resolves to 0 DNS records - return an error
600+
// - Some DNS records to not contain a well-formed instance name - return the
601+
// first well-formed instance name. If none found return an error.
602+
// - The domain name resolves to 2 or more DNS record - return first valid
603+
// record when sorted by priority: lowest value first, then by target:
604+
// alphabetically.
605+
func (d *Dialer) queryDNS(ctx context.Context, domainName string) (instance.ConnName, error) {
606+
// Attempt to query the SRV records.
607+
// This could return a partial error where both err != nil && len(records) > 0.
608+
_, records, err := d.resolver.LookupSRV(ctx, "", "", domainName)
609+
610+
// Process the records returning the first valid SRV record.
611+
612+
// Sort the record slice so that lowest priority comes first, then
613+
// alphabetically by instance name
614+
sort.Slice(records, func(i, j int) bool {
615+
if records[i].Priority == records[j].Priority {
616+
return records[i].Target < records[j].Target
617+
}
618+
return records[i].Priority < records[j].Priority
619+
})
620+
621+
var perr error
622+
// Attempt to parse records, returning the first valid record.
623+
for _, record := range records {
624+
// Remove trailing '.' from target value.
625+
target := strings.TrimRight(record.Target, ".")
626+
627+
// Parse the target as a CN
628+
cn, parseErr := instance.ParseConnName(target)
629+
if parseErr != nil {
630+
perr = fmt.Errorf("unable to parse SRV for %q: %v", domainName, parseErr)
631+
continue
632+
}
633+
return cn, nil
634+
}
635+
636+
// If resolve failed and no records were found, return the error.
637+
if err != nil {
638+
return instance.ConnName{}, fmt.Errorf("unable to resolve SRV record for %q: %v", domainName, err)
639+
}
640+
641+
// If all the records failed to parse, return one of the parse errors
642+
if perr != nil {
643+
return instance.ConnName{}, perr
644+
}
645+
646+
// No records were found, return an error.
647+
return instance.ConnName{}, fmt.Errorf("no valid SRV records found for %q", domainName)
648+
}

dialer_test.go

+89
Original file line numberDiff line numberDiff line change
@@ -1016,3 +1016,92 @@ func TestDialerInitializesLazyCache(t *testing.T) {
10161016
t.Fatalf("dialer was initialized with non-lazy type: %T", tt)
10171017
}
10181018
}
1019+
1020+
type fakeResolver struct{}
1021+
1022+
func (r *fakeResolver) LookupSRV(_ context.Context, _, _, name string) (cname string, addrs []*net.SRV, err error) {
1023+
// For TestDialerSuccessfullyDialsDnsSrvRecord
1024+
if name == "db.example.com" {
1025+
return "", []*net.SRV{
1026+
&net.SRV{Target: "my-project:my-region:my-instance."},
1027+
}, nil
1028+
}
1029+
if name == "db2.example.com" {
1030+
return "", []*net.SRV{
1031+
&net.SRV{Target: "my-project:my-region:my-instance"},
1032+
}, nil
1033+
}
1034+
// For TestDialerFailsDnsSrvRecordMalformed
1035+
if name == "malformed.example.com" {
1036+
return "", []*net.SRV{
1037+
&net.SRV{Target: "an-invalid-instance-name"},
1038+
}, nil
1039+
}
1040+
return "", nil, fmt.Errorf("no resolution for %v", name)
1041+
}
1042+
1043+
func TestDialerSuccessfullyDialsDnsSrvRecord(t *testing.T) {
1044+
inst := mock.NewFakeCSQLInstance(
1045+
"my-project", "my-region", "my-instance",
1046+
)
1047+
d := setupDialer(t, setupConfig{
1048+
testInstance: inst,
1049+
reqs: []*mock.Request{
1050+
mock.InstanceGetSuccess(inst, 1),
1051+
mock.CreateEphemeralSuccess(inst, 1),
1052+
},
1053+
dialerOptions: []Option{
1054+
WithTokenSource(mock.EmptyTokenSource{}),
1055+
WithResolver(&fakeResolver{}),
1056+
},
1057+
})
1058+
1059+
// Target has a trailing '.'
1060+
testSuccessfulDial(
1061+
context.Background(), t, d,
1062+
"db.example.com",
1063+
)
1064+
// Target does not have a trailing '.'
1065+
testSuccessfulDial(
1066+
context.Background(), t, d,
1067+
"db2.example.com",
1068+
)
1069+
}
1070+
1071+
func TestDialerFailsDnsSrvRecordMissing(t *testing.T) {
1072+
inst := mock.NewFakeCSQLInstance(
1073+
"my-project", "my-region", "my-instance",
1074+
)
1075+
d := setupDialer(t, setupConfig{
1076+
testInstance: inst,
1077+
reqs: []*mock.Request{},
1078+
dialerOptions: []Option{
1079+
WithTokenSource(mock.EmptyTokenSource{}),
1080+
WithResolver(&fakeResolver{}),
1081+
},
1082+
})
1083+
_, err := d.Dial(context.Background(), "doesnt-exist.example.com")
1084+
wantMsg := "unable to resolve SRV record for doesnt-exist.example.com"
1085+
if !strings.Contains(err.Error(), wantMsg) {
1086+
t.Fatalf("want = %v, got = %v", wantMsg, err)
1087+
}
1088+
}
1089+
1090+
func TestDialerFailsDnsSrvRecordMalformed(t *testing.T) {
1091+
inst := mock.NewFakeCSQLInstance(
1092+
"my-project", "my-region", "my-instance",
1093+
)
1094+
d := setupDialer(t, setupConfig{
1095+
testInstance: inst,
1096+
reqs: []*mock.Request{},
1097+
dialerOptions: []Option{
1098+
WithTokenSource(mock.EmptyTokenSource{}),
1099+
WithResolver(&fakeResolver{}),
1100+
},
1101+
})
1102+
_, err := d.Dial(context.Background(), "malformed.example.com")
1103+
wantMsg := "unable to parse SRV for malformed.example.com"
1104+
if !strings.Contains(err.Error(), wantMsg) {
1105+
t.Fatalf("want = %v, got = %v", wantMsg, err)
1106+
}
1107+
}

options.go

+19
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ type dialerConfig struct {
5252
setCredentials bool
5353
setTokenSource bool
5454
setIAMAuthNTokenSource bool
55+
resolver NetResolver
5556
// err tracks any dialer options that may have failed.
5657
err error
5758
}
@@ -234,6 +235,24 @@ func WithIAMAuthN() Option {
234235
}
235236
}
236237

238+
// NetResolver groups the methods on net.Resolver that are used by the DNS
239+
// resolver implementation. This allows an application to replace the default
240+
// net.DefaultResolver with a custom implementation. For example: the
241+
// application may need to connect to a specific DNS server using a specially
242+
// configured instance of net.Resolver.
243+
type NetResolver interface {
244+
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
245+
}
246+
247+
// WithResolver replaces the default DNS resolver with an alternate
248+
// implementation to use when resolving SRV records containing the
249+
// instance name. By default, the dialer will use net.DefaultResolver.
250+
func WithResolver(r NetResolver) Option {
251+
return func(d *dialerConfig) {
252+
d.resolver = r
253+
}
254+
}
255+
237256
type debugLoggerWithoutContext struct {
238257
logger debug.Logger
239258
}

0 commit comments

Comments
 (0)