@@ -24,6 +24,7 @@ import (
24
24
"fmt"
25
25
"io"
26
26
"net"
27
+ "sort"
27
28
"strings"
28
29
"sync"
29
30
"sync/atomic"
@@ -153,6 +154,10 @@ type Dialer struct {
153
154
154
155
// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
155
156
iamTokenSource oauth2.TokenSource
157
+
158
+ // resolver does SRV record DNS lookups when resolving DNS name dialer
159
+ // configuration.
160
+ resolver NetResolver
156
161
}
157
162
158
163
var (
@@ -253,6 +258,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
253
258
if err != nil {
254
259
return nil , err
255
260
}
261
+ var r NetResolver = net .DefaultResolver
262
+ if cfg .resolver != nil {
263
+ r = cfg .resolver
264
+ }
265
+
256
266
d := & Dialer {
257
267
closed : make (chan struct {}),
258
268
cache : make (map [instance.ConnName ]monitoredCache ),
@@ -265,10 +275,25 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
265
275
dialerID : uuid .New ().String (),
266
276
iamTokenSource : cfg .iamLoginTokenSource ,
267
277
dialFunc : cfg .dialFunc ,
278
+ resolver : r ,
268
279
}
269
280
return d , nil
270
281
}
271
282
283
+ func (d * Dialer ) resolveInstanceName (ctx context.Context , icn string ) (instance.ConnName , error ) {
284
+ cn , err := instance .ParseConnName (icn )
285
+ if err != nil {
286
+ // The connection name was not project:region:instance
287
+ // Attempt to query a SRV record and see if it works instead.
288
+ cn , err = d .queryDNS (ctx , icn )
289
+ if err != nil {
290
+ return instance.ConnName {}, err
291
+ }
292
+ }
293
+
294
+ return cn , nil
295
+ }
296
+
272
297
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
273
298
// icn argument must be the instance's connection name, which is in the format
274
299
// "project-name:region:instance-name".
@@ -288,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
288
313
go trace .RecordDialError (context .Background (), icn , d .dialerID , err )
289
314
endDial (err )
290
315
}()
291
- cn , err := instance . ParseConnName ( icn )
316
+ cn , err := d . resolveInstanceName ( ctx , icn )
292
317
if err != nil {
293
318
return nil , err
294
319
}
@@ -429,7 +454,7 @@ func validClientCert(
429
454
// the instance:
430
455
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
431
456
func (d * Dialer ) EngineVersion (ctx context.Context , icn string ) (string , error ) {
432
- cn , err := instance . ParseConnName ( icn )
457
+ cn , err := d . resolveInstanceName ( ctx , icn )
433
458
if err != nil {
434
459
return "" , err
435
460
}
@@ -449,7 +474,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
449
474
// Use Warmup to start the refresh process early if you don't know when you'll
450
475
// need to call "Dial".
451
476
func (d * Dialer ) Warmup (ctx context.Context , icn string , opts ... DialOption ) error {
452
- cn , err := instance . ParseConnName ( icn )
477
+ cn , err := d . resolveInstanceName ( ctx , icn )
453
478
if err != nil {
454
479
return err
455
480
}
@@ -565,3 +590,59 @@ func (d *Dialer) connectionInfoCache(
565
590
566
591
return c , nil
567
592
}
593
+
594
+ // queryDNS attempts to resolve a SRV record for the domain name.
595
+ // The DNS SRV record's target field is used as instance name.
596
+ //
597
+ // This handles several conditions where the DNS records may be missing or
598
+ // invalid:
599
+ // - The domain name resolves to 0 DNS records - return an error
600
+ // - Some DNS records to not contain a well-formed instance name - return the
601
+ // first well-formed instance name. If none found return an error.
602
+ // - The domain name resolves to 2 or more DNS record - return first valid
603
+ // record when sorted by priority: lowest value first, then by target:
604
+ // alphabetically.
605
+ func (d * Dialer ) queryDNS (ctx context.Context , domainName string ) (instance.ConnName , error ) {
606
+ // Attempt to query the SRV records.
607
+ // This could return a partial error where both err != nil && len(records) > 0.
608
+ _ , records , err := d .resolver .LookupSRV (ctx , "" , "" , domainName )
609
+
610
+ // Process the records returning the first valid SRV record.
611
+
612
+ // Sort the record slice so that lowest priority comes first, then
613
+ // alphabetically by instance name
614
+ sort .Slice (records , func (i , j int ) bool {
615
+ if records [i ].Priority == records [j ].Priority {
616
+ return records [i ].Target < records [j ].Target
617
+ }
618
+ return records [i ].Priority < records [j ].Priority
619
+ })
620
+
621
+ var perr error
622
+ // Attempt to parse records, returning the first valid record.
623
+ for _ , record := range records {
624
+ // Remove trailing '.' from target value.
625
+ target := strings .TrimRight (record .Target , "." )
626
+
627
+ // Parse the target as a CN
628
+ cn , parseErr := instance .ParseConnName (target )
629
+ if parseErr != nil {
630
+ perr = fmt .Errorf ("unable to parse SRV for %q: %v" , domainName , parseErr )
631
+ continue
632
+ }
633
+ return cn , nil
634
+ }
635
+
636
+ // If resolve failed and no records were found, return the error.
637
+ if err != nil {
638
+ return instance.ConnName {}, fmt .Errorf ("unable to resolve SRV record for %q: %v" , domainName , err )
639
+ }
640
+
641
+ // If all the records failed to parse, return one of the parse errors
642
+ if perr != nil {
643
+ return instance.ConnName {}, perr
644
+ }
645
+
646
+ // No records were found, return an error.
647
+ return instance.ConnName {}, fmt .Errorf ("no valid SRV records found for %q" , domainName )
648
+ }
0 commit comments