Skip to content

Commit dd0c834

Browse files
committed
feat: Automatially configure connections using DNS
1 parent fd1a9f1 commit dd0c834

File tree

3 files changed

+182
-4
lines changed

3 files changed

+182
-4
lines changed

dialer.go

+76-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"fmt"
2525
"io"
2626
"net"
27+
"sort"
2728
"strings"
2829
"sync"
2930
"sync/atomic"
@@ -153,6 +154,10 @@ type Dialer struct {
153154

154155
// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
155156
iamTokenSource oauth2.TokenSource
157+
158+
// resolver does SRV record DNS lookups when resolving DNS name dialer
159+
// configuration.
160+
resolver NetResolver
156161
}
157162

158163
var (
@@ -253,6 +258,13 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
253258
if err != nil {
254259
return nil, err
255260
}
261+
var r NetResolver
262+
if cfg.resolver != nil {
263+
r = cfg.resolver
264+
} else {
265+
r = net.DefaultResolver
266+
}
267+
256268
d := &Dialer{
257269
closed: make(chan struct{}),
258270
cache: make(map[instance.ConnName]monitoredCache),
@@ -265,10 +277,25 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
265277
dialerID: uuid.New().String(),
266278
iamTokenSource: cfg.iamLoginTokenSource,
267279
dialFunc: cfg.dialFunc,
280+
resolver: r,
268281
}
269282
return d, nil
270283
}
271284

285+
func (d *Dialer) resolveInstanceName(ctx context.Context, icn string) (instance.ConnName, error) {
286+
cn, err := instance.ParseConnName(icn)
287+
if err != nil {
288+
// The connection name was not project:region:instance
289+
// Attempt to query a SRV record and see if it works instead.
290+
cn, err = d.queryDNS(ctx, icn)
291+
if err != nil {
292+
return instance.ConnName{}, err
293+
}
294+
}
295+
296+
return cn, nil
297+
}
298+
272299
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
273300
// icn argument must be the instance's connection name, which is in the format
274301
// "project-name:region:instance-name".
@@ -288,9 +315,14 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
288315
go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
289316
endDial(err)
290317
}()
291-
cn, err := instance.ParseConnName(icn)
318+
cn, err := d.resolveInstanceName(ctx, icn)
292319
if err != nil {
293-
return nil, err
320+
// The connection name was not project:region:instance
321+
// Attempt to query a SRV record and see if it works instead.
322+
cn, err = d.queryDNS(ctx, icn)
323+
if err != nil {
324+
return nil, err
325+
}
294326
}
295327

296328
cfg := d.defaultDialConfig
@@ -429,7 +461,7 @@ func validClientCert(
429461
// the instance:
430462
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
431463
func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) {
432-
cn, err := instance.ParseConnName(icn)
464+
cn, err := d.resolveInstanceName(ctx, icn)
433465
if err != nil {
434466
return "", err
435467
}
@@ -449,7 +481,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
449481
// Use Warmup to start the refresh process early if you don't know when you'll
450482
// need to call "Dial".
451483
func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error {
452-
cn, err := instance.ParseConnName(icn)
484+
cn, err := d.resolveInstanceName(ctx, icn)
453485
if err != nil {
454486
return err
455487
}
@@ -565,3 +597,43 @@ func (d *Dialer) connectionInfoCache(
565597

566598
return c, nil
567599
}
600+
601+
func (d *Dialer) queryDNS(ctx context.Context, domainName string) (instance.ConnName, error) {
602+
// Attempt to query the SRV records.
603+
// This could return a partial error where both err != nil && len(records) > 0.
604+
_, records, err := d.resolver.LookupSRV(ctx, "", "", domainName)
605+
606+
// Process the records returning the first valid SRV record.
607+
608+
// Sort the record slice so that lowest priority comes first.
609+
sort.Slice(records, func(i, j int) bool {
610+
return records[i].Priority < records[j].Priority
611+
})
612+
var perr error
613+
// Attempt to parse records, returning the first valid record.
614+
for _, record := range records {
615+
// Remove trailing '.' from target value.
616+
target := strings.TrimRight(record.Target, ".")
617+
618+
// Parse the target as a CN
619+
cn, parseErr := instance.ParseConnName(target)
620+
if parseErr != nil {
621+
perr = fmt.Errorf("unable to parse SRV for %s: %v", domainName, parseErr)
622+
continue
623+
}
624+
return cn, nil
625+
}
626+
627+
// If resolve failed and no records were found, return the error.
628+
if err != nil {
629+
return instance.ConnName{}, fmt.Errorf("unable to resolve SRV record for %s: %v", domainName, err)
630+
}
631+
632+
// If all the records failed to parse, return one of the parse errors
633+
if perr != nil {
634+
return instance.ConnName{}, perr
635+
}
636+
637+
// No records were found, return an error.
638+
return instance.ConnName{}, fmt.Errorf("no valid SRV records found for %s", domainName)
639+
}

dialer_test.go

+89
Original file line numberDiff line numberDiff line change
@@ -1016,3 +1016,92 @@ func TestDialerInitializesLazyCache(t *testing.T) {
10161016
t.Fatalf("dialer was initialized with non-lazy type: %T", tt)
10171017
}
10181018
}
1019+
1020+
type fakeResolver struct{}
1021+
1022+
func (r *fakeResolver) LookupSRV(_ context.Context, _, _, name string) (cname string, addrs []*net.SRV, err error) {
1023+
// For TestDialerSuccessfullyDialsDnsSrvRecord
1024+
if name == "db.example.com" {
1025+
return "", []*net.SRV{
1026+
&net.SRV{Target: "my-project:my-region:my-instance."},
1027+
}, nil
1028+
}
1029+
if name == "db2.example.com" {
1030+
return "", []*net.SRV{
1031+
&net.SRV{Target: "my-project:my-region:my-instance"},
1032+
}, nil
1033+
}
1034+
// For TestDialerFailsDnsSrvRecordMalformed
1035+
if name == "malformed.example.com" {
1036+
return "", []*net.SRV{
1037+
&net.SRV{Target: "an-invalid-instance-name"},
1038+
}, nil
1039+
}
1040+
return "", nil, fmt.Errorf("no resolution for %v", name)
1041+
}
1042+
1043+
func TestDialerSuccessfullyDialsDnsSrvRecord(t *testing.T) {
1044+
inst := mock.NewFakeCSQLInstance(
1045+
"my-project", "my-region", "my-instance",
1046+
)
1047+
d := setupDialer(t, setupConfig{
1048+
testInstance: inst,
1049+
reqs: []*mock.Request{
1050+
mock.InstanceGetSuccess(inst, 1),
1051+
mock.CreateEphemeralSuccess(inst, 1),
1052+
},
1053+
dialerOptions: []Option{
1054+
WithTokenSource(mock.EmptyTokenSource{}),
1055+
WithResolver(&fakeResolver{}),
1056+
},
1057+
})
1058+
1059+
// Target has a trailing '.'
1060+
testSuccessfulDial(
1061+
context.Background(), t, d,
1062+
"db.example.com",
1063+
)
1064+
// Target does not have a trailing '.'
1065+
testSuccessfulDial(
1066+
context.Background(), t, d,
1067+
"db2.example.com",
1068+
)
1069+
}
1070+
1071+
func TestDialerFailsDnsSrvRecordMissing(t *testing.T) {
1072+
inst := mock.NewFakeCSQLInstance(
1073+
"my-project", "my-region", "my-instance",
1074+
)
1075+
d := setupDialer(t, setupConfig{
1076+
testInstance: inst,
1077+
reqs: []*mock.Request{},
1078+
dialerOptions: []Option{
1079+
WithTokenSource(mock.EmptyTokenSource{}),
1080+
WithResolver(&fakeResolver{}),
1081+
},
1082+
})
1083+
_, err := d.Dial(context.Background(), "doesnt-exist.example.com")
1084+
wantMsg := "unable to resolve SRV record for doesnt-exist.example.com"
1085+
if !strings.Contains(err.Error(), wantMsg) {
1086+
t.Fatalf("want = %v, got = %v", wantMsg, err)
1087+
}
1088+
}
1089+
1090+
func TestDialerFailsDnsSrvRecordMalformed(t *testing.T) {
1091+
inst := mock.NewFakeCSQLInstance(
1092+
"my-project", "my-region", "my-instance",
1093+
)
1094+
d := setupDialer(t, setupConfig{
1095+
testInstance: inst,
1096+
reqs: []*mock.Request{},
1097+
dialerOptions: []Option{
1098+
WithTokenSource(mock.EmptyTokenSource{}),
1099+
WithResolver(&fakeResolver{}),
1100+
},
1101+
})
1102+
_, err := d.Dial(context.Background(), "malformed.example.com")
1103+
wantMsg := "unable to parse SRV for malformed.example.com"
1104+
if !strings.Contains(err.Error(), wantMsg) {
1105+
t.Fatalf("want = %v, got = %v", wantMsg, err)
1106+
}
1107+
}

options.go

+17
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ import (
3434
// An Option is an option for configuring a Dialer.
3535
type Option func(d *dialerConfig)
3636

37+
// NetResolver groups the methods on net.Resolver that are used by the DNS
38+
// resolver implementation. This allows the default net.Resolver instance to be
39+
// overridden from tests.
40+
type NetResolver interface {
41+
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
42+
}
43+
3744
type dialerConfig struct {
3845
rsaKey *rsa.PrivateKey
3946
sqladminOpts []apiopt.ClientOption
@@ -52,6 +59,7 @@ type dialerConfig struct {
5259
setCredentials bool
5360
setTokenSource bool
5461
setIAMAuthNTokenSource bool
62+
resolver NetResolver
5563
// err tracks any dialer options that may have failed.
5664
err error
5765
}
@@ -234,6 +242,15 @@ func WithIAMAuthN() Option {
234242
}
235243
}
236244

245+
// WithResolver replaces the default DNS resolver with an alternate
246+
// implementation to use when resolving SRV records containing the
247+
// instance name. By default, the dialer will use net.DefaultResolver.
248+
func WithResolver(r NetResolver) Option {
249+
return func(d *dialerConfig) {
250+
d.resolver = r
251+
}
252+
}
253+
237254
type debugLoggerWithoutContext struct {
238255
logger debug.Logger
239256
}

0 commit comments

Comments
 (0)