@@ -24,6 +24,7 @@ import (
24
24
"fmt"
25
25
"io"
26
26
"net"
27
+ "sort"
27
28
"strings"
28
29
"sync"
29
30
"sync/atomic"
@@ -153,6 +154,10 @@ type Dialer struct {
153
154
154
155
// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
155
156
iamTokenSource oauth2.TokenSource
157
+
158
+ // resolver does SRV record DNS lookups when resolving DNS name dialer
159
+ // configuration.
160
+ resolver NetResolver
156
161
}
157
162
158
163
var (
@@ -253,6 +258,13 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
253
258
if err != nil {
254
259
return nil , err
255
260
}
261
+ var r NetResolver
262
+ if cfg .resolver != nil {
263
+ r = cfg .resolver
264
+ } else {
265
+ r = net .DefaultResolver
266
+ }
267
+
256
268
d := & Dialer {
257
269
closed : make (chan struct {}),
258
270
cache : make (map [instance.ConnName ]monitoredCache ),
@@ -265,10 +277,25 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
265
277
dialerID : uuid .New ().String (),
266
278
iamTokenSource : cfg .iamLoginTokenSource ,
267
279
dialFunc : cfg .dialFunc ,
280
+ resolver : r ,
268
281
}
269
282
return d , nil
270
283
}
271
284
285
+ func (d * Dialer ) resolveInstanceName (ctx context.Context , icn string ) (instance.ConnName , error ) {
286
+ cn , err := instance .ParseConnName (icn )
287
+ if err != nil {
288
+ // The connection name was not project:region:instance
289
+ // Attempt to query a SRV record and see if it works instead.
290
+ cn , err = d .queryDNS (ctx , icn )
291
+ if err != nil {
292
+ return instance.ConnName {}, err
293
+ }
294
+ }
295
+
296
+ return cn , nil
297
+ }
298
+
272
299
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
273
300
// icn argument must be the instance's connection name, which is in the format
274
301
// "project-name:region:instance-name".
@@ -288,9 +315,14 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
288
315
go trace .RecordDialError (context .Background (), icn , d .dialerID , err )
289
316
endDial (err )
290
317
}()
291
- cn , err := instance . ParseConnName ( icn )
318
+ cn , err := d . resolveInstanceName ( ctx , icn )
292
319
if err != nil {
293
- return nil , err
320
+ // The connection name was not project:region:instance
321
+ // Attempt to query a SRV record and see if it works instead.
322
+ cn , err = d .queryDNS (ctx , icn )
323
+ if err != nil {
324
+ return nil , err
325
+ }
294
326
}
295
327
296
328
cfg := d .defaultDialConfig
@@ -429,7 +461,7 @@ func validClientCert(
429
461
// the instance:
430
462
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
431
463
func (d * Dialer ) EngineVersion (ctx context.Context , icn string ) (string , error ) {
432
- cn , err := instance . ParseConnName ( icn )
464
+ cn , err := d . resolveInstanceName ( ctx , icn )
433
465
if err != nil {
434
466
return "" , err
435
467
}
@@ -449,7 +481,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
449
481
// Use Warmup to start the refresh process early if you don't know when you'll
450
482
// need to call "Dial".
451
483
func (d * Dialer ) Warmup (ctx context.Context , icn string , opts ... DialOption ) error {
452
- cn , err := instance . ParseConnName ( icn )
484
+ cn , err := d . resolveInstanceName ( ctx , icn )
453
485
if err != nil {
454
486
return err
455
487
}
@@ -565,3 +597,43 @@ func (d *Dialer) connectionInfoCache(
565
597
566
598
return c , nil
567
599
}
600
+
601
+ func (d * Dialer ) queryDNS (ctx context.Context , domainName string ) (instance.ConnName , error ) {
602
+ // Attempt to query the SRV records.
603
+ // This could return a partial error where both err != nil && len(records) > 0.
604
+ _ , records , err := d .resolver .LookupSRV (ctx , "" , "" , domainName )
605
+
606
+ // Process the records returning the first valid SRV record.
607
+
608
+ // Sort the record slice so that lowest priority comes first.
609
+ sort .Slice (records , func (i , j int ) bool {
610
+ return records [i ].Priority < records [j ].Priority
611
+ })
612
+ var perr error
613
+ // Attempt to parse records, returning the first valid record.
614
+ for _ , record := range records {
615
+ // Remove trailing '.' from target value.
616
+ target := strings .TrimRight (record .Target , "." )
617
+
618
+ // Parse the target as a CN
619
+ cn , parseErr := instance .ParseConnName (target )
620
+ if parseErr != nil {
621
+ perr = fmt .Errorf ("unable to parse SRV for %s: %v" , domainName , parseErr )
622
+ continue
623
+ }
624
+ return cn , nil
625
+ }
626
+
627
+ // If resolve failed and no records were found, return the error.
628
+ if err != nil {
629
+ return instance.ConnName {}, fmt .Errorf ("unable to resolve SRV record for %s: %v" , domainName , err )
630
+ }
631
+
632
+ // If all the records failed to parse, return one of the parse errors
633
+ if perr != nil {
634
+ return instance.ConnName {}, perr
635
+ }
636
+
637
+ // No records were found, return an error.
638
+ return instance.ConnName {}, fmt .Errorf ("no valid SRV records found for %s" , domainName )
639
+ }
0 commit comments