@@ -76,7 +76,7 @@ type keyGenerator struct {
76
76
// - generate an RSA key lazily when it's requested, or
77
77
// - (default) immediately generate an RSA key as part of the initializer.
78
78
func newKeyGenerator (
79
- k * rsa.PrivateKey , lazy bool , genFunc func () (* rsa.PrivateKey , error ),
79
+ k * rsa.PrivateKey , lazy bool , genFunc func () (* rsa.PrivateKey , error ),
80
80
) (* keyGenerator , error ) {
81
81
g := & keyGenerator {genFunc : genFunc }
82
82
switch {
@@ -107,23 +107,16 @@ type connectionInfoCache interface {
107
107
ConnectionInfo (context.Context ) (cloudsql.ConnectionInfo , error )
108
108
UpdateRefresh (* bool )
109
109
ForceRefresh ()
110
+ UseIAMAuthN () bool
110
111
io.Closer
111
112
}
112
113
113
- // monitoredCache is a wrapper around a connectionInfoCache that tracks the
114
- // number of connections to the associated instance.
115
- type monitoredCache struct {
116
- openConns * uint64
117
-
118
- connectionInfoCache
119
- }
120
-
121
114
// A Dialer is used to create connections to Cloud SQL instances.
122
115
//
123
116
// Use NewDialer to initialize a Dialer.
124
117
type Dialer struct {
125
118
lock sync.RWMutex
126
- cache map [instance.ConnName ]monitoredCache
119
+ cache map [instance.ConnName ]* monitoredCache
127
120
keyGenerator * keyGenerator
128
121
refreshTimeout time.Duration
129
122
// closed reports if the dialer has been closed.
@@ -155,7 +148,8 @@ type Dialer struct {
155
148
iamTokenSource oauth2.TokenSource
156
149
157
150
// resolver converts instance names into DNS names.
158
- resolver instance.ConnectionNameResolver
151
+ resolver instance.ConnectionNameResolver
152
+ failoverPeriod time.Duration
159
153
}
160
154
161
155
var (
@@ -179,6 +173,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
179
173
logger : nullLogger {},
180
174
useragents : []string {userAgent },
181
175
serviceUniverse : "googleapis.com" ,
176
+ failoverPeriod : cloudsql .FailoverPeriod ,
182
177
}
183
178
for _ , opt := range opts {
184
179
opt (cfg )
@@ -192,6 +187,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
192
187
if cfg .setIAMAuthNTokenSource && ! cfg .useIAMAuthN {
193
188
return nil , errUseTokenSource
194
189
}
190
+
195
191
// Add this to the end to make sure it's not overridden
196
192
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithUserAgent (strings .Join (cfg .useragents , " " )))
197
193
@@ -219,7 +215,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
219
215
if cfg .setUniverseDomain && cfg .setAdminAPIEndpoint {
220
216
return nil , errors .New (
221
217
"can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
222
- "use WithAdminAPIEndpoint (it already contains the universe domain)" ,
218
+ "use WithAdminAPIEndpoint (it already contains the universe domain)" ,
223
219
)
224
220
}
225
221
@@ -263,7 +259,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263
259
264
260
d := & Dialer {
265
261
closed : make (chan struct {}),
266
- cache : make (map [instance.ConnName ]monitoredCache ),
262
+ cache : make (map [instance.ConnName ]* monitoredCache ),
267
263
lazyRefresh : cfg .lazyRefresh ,
268
264
keyGenerator : g ,
269
265
refreshTimeout : cfg .refreshTimeout ,
@@ -274,7 +270,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
274
270
iamTokenSource : cfg .iamLoginTokenSource ,
275
271
dialFunc : cfg .dialFunc ,
276
272
resolver : r ,
273
+ failoverPeriod : cfg .failoverPeriod ,
277
274
}
275
+
278
276
return d , nil
279
277
}
280
278
@@ -380,22 +378,31 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
380
378
381
379
latency := time .Since (startTime ).Milliseconds ()
382
380
go func () {
383
- n := atomic .AddUint64 (c .openConns , 1 )
381
+ n := atomic .AddUint64 (c .openConnsCount , 1 )
384
382
trace .RecordOpenConnections (ctx , int64 (n ), d .dialerID , cn .String ())
385
383
trace .RecordDialLatency (ctx , icn , d .dialerID , latency )
386
384
}()
387
385
388
- return newInstrumentedConn (tlsConn , func () {
389
- n := atomic .AddUint64 (c .openConns , ^ uint64 (0 ))
386
+ iConn := newInstrumentedConn (tlsConn , func () {
387
+ n := atomic .AddUint64 (c .openConnsCount , ^ uint64 (0 ))
390
388
trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
391
- }, d .dialerID , cn .String ()), nil
389
+ }, d .dialerID , cn .String ())
390
+
391
+ // If this connection was opened using a Domain Name, then store it for later
392
+ // in case it needs to be forcibly closed.
393
+ if cn .DomainName () != "" {
394
+ c .mu .Lock ()
395
+ c .openConns = append (c .openConns , iConn )
396
+ c .mu .Unlock ()
397
+ }
398
+ return iConn , nil
392
399
}
393
400
394
401
// removeCached stops all background refreshes and deletes the connection
395
402
// info cache from the map of caches.
396
403
func (d * Dialer ) removeCached (
397
- ctx context.Context ,
398
- i instance.ConnName , c connectionInfoCache , err error ,
404
+ ctx context.Context ,
405
+ i instance.ConnName , c connectionInfoCache , err error ,
399
406
) {
400
407
d .logger .Debugf (
401
408
ctx ,
@@ -413,8 +420,8 @@ func (d *Dialer) removeCached(
413
420
// the cache is unexpired. The time comparisons strip the monotonic clock value
414
421
// to ensure an accurate result, even after laptop sleep.
415
422
func validClientCert (
416
- ctx context.Context , cn instance.ConnName ,
417
- l debug.ContextLogger , expiration time.Time ,
423
+ ctx context.Context , cn instance.ConnName ,
424
+ l debug.ContextLogger , expiration time.Time ,
418
425
) bool {
419
426
// Use UTC() to strip monotonic clock value to guard against inaccurate
420
427
// comparisons, especially after laptop sleep.
@@ -448,7 +455,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
448
455
}
449
456
ci , err := c .ConnectionInfo (ctx )
450
457
if err != nil {
451
- d .removeCached (ctx , cn , c , err )
458
+ d .removeCached (ctx , cn , c . connectionInfoCache , err )
452
459
return "" , err
453
460
}
454
461
return ci .DBVersion , nil
@@ -472,7 +479,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
472
479
}
473
480
_ , err = c .ConnectionInfo (ctx )
474
481
if err != nil {
475
- d .removeCached (ctx , cn , c , err )
482
+ d .removeCached (ctx , cn , c . connectionInfoCache , err )
476
483
}
477
484
return err
478
485
}
@@ -493,6 +500,8 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493
500
type instrumentedConn struct {
494
501
net.Conn
495
502
closeFunc func ()
503
+ mu sync.RWMutex
504
+ closed bool
496
505
dialerID string
497
506
connName string
498
507
}
@@ -517,9 +526,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
517
526
return bytesWritten , err
518
527
}
519
528
529
+ // isClosed returns true if this connection is closing or is already closed.
530
+ func (i * instrumentedConn ) isClosed () bool {
531
+ i .mu .RLock ()
532
+ defer i .mu .RUnlock ()
533
+ return i .closed
534
+ }
535
+
520
536
// Close delegates to the underlying net.Conn interface and reports the close
521
537
// to the provided closeFunc only when Close returns no error.
522
538
func (i * instrumentedConn ) Close () error {
539
+ i .mu .Lock ()
540
+ defer i .mu .Unlock ()
541
+ i .closed = true
523
542
err := i .Conn .Close ()
524
543
if err != nil {
525
544
return err
@@ -550,51 +569,105 @@ func (d *Dialer) Close() error {
550
569
// connection info Cache in a threadsafe way. It will create a new cache,
551
570
// modify the existing one, or leave it unchanged as needed.
552
571
func (d * Dialer ) connectionInfoCache (
553
- ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
554
- ) (monitoredCache , error ) {
572
+ ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
573
+ ) (* monitoredCache , error ) {
555
574
d .lock .RLock ()
556
575
c , ok := d .cache [cn ]
557
576
d .lock .RUnlock ()
558
- if ! ok {
559
- d .lock .Lock ()
560
- defer d .lock .Unlock ()
561
- // Recheck to ensure instance wasn't created or changed between locks
562
- c , ok = d .cache [cn ]
563
- if ! ok {
564
- var useIAMAuthNDial bool
565
- if useIAMAuthN != nil {
566
- useIAMAuthNDial = * useIAMAuthN
567
- }
568
- d .logger .Debugf (ctx , "[%v] Connection info added to cache" , cn .String ())
569
- k , err := d .keyGenerator .rsaKey ()
570
- if err != nil {
571
- return monitoredCache {}, err
572
- }
573
- var cache connectionInfoCache
574
- if d .lazyRefresh {
575
- cache = cloudsql .NewLazyRefreshCache (
576
- cn ,
577
- d .logger ,
578
- d .sqladmin , k ,
579
- d .refreshTimeout , d .iamTokenSource ,
580
- d .dialerID , useIAMAuthNDial ,
581
- )
582
- } else {
583
- cache = cloudsql .NewRefreshAheadCache (
584
- cn ,
585
- d .logger ,
586
- d .sqladmin , k ,
587
- d .refreshTimeout , d .iamTokenSource ,
588
- d .dialerID , useIAMAuthNDial ,
589
- )
590
- }
591
- var count uint64
592
- c = monitoredCache {openConns : & count , connectionInfoCache : cache }
593
- d .cache [cn ] = c
594
- }
577
+
578
+ // Recheck the domain name; this may close the cache.
579
+ if ok {
580
+ c .checkDomainName (ctx )
581
+ }
582
+
583
+ if ok && ! c .isClosed () {
584
+ c .UpdateRefresh (useIAMAuthN )
585
+ return c , nil
586
+ }
587
+
588
+ d .lock .Lock ()
589
+ defer d .lock .Unlock ()
590
+
591
+ // Recheck to ensure instance wasn't created or changed between locks
592
+ c , ok = d .cache [cn ]
593
+
594
+ // c exists and is not closed
595
+ if ok && ! c .isClosed () {
596
+ c .UpdateRefresh (useIAMAuthN )
597
+ return c , nil
598
+ }
599
+
600
+ // c exists and is closed, remove it from the cache
601
+ if ok {
602
+ // remove it.
603
+ _ = c .Close ()
604
+ delete (d .cache , cn )
595
605
}
596
606
597
- c .UpdateRefresh (useIAMAuthN )
607
+ // c does not exist, check for matching domain and close it
608
+ oldCn , old , ok := d .findByDn (cn )
609
+ if ok {
610
+ _ = old .Close ()
611
+ delete (d .cache , oldCn )
612
+ }
598
613
614
+ // Create a new instance of monitoredCache
615
+ var useIAMAuthNDial bool
616
+ if useIAMAuthN != nil {
617
+ useIAMAuthNDial = * useIAMAuthN
618
+ }
619
+ d .logger .Debugf (ctx , "[%v] Connection info added to cache" , cn .String ())
620
+ k , err := d .keyGenerator .rsaKey ()
621
+ if err != nil {
622
+ return nil , err
623
+ }
624
+ var cache connectionInfoCache
625
+ if d .lazyRefresh {
626
+ cache = cloudsql .NewLazyRefreshCache (
627
+ cn ,
628
+ d .logger ,
629
+ d .sqladmin , k ,
630
+ d .refreshTimeout , d .iamTokenSource ,
631
+ d .dialerID , useIAMAuthNDial ,
632
+ )
633
+ } else {
634
+ cache = cloudsql .NewRefreshAheadCache (
635
+ cn ,
636
+ d .logger ,
637
+ d .sqladmin , k ,
638
+ d .refreshTimeout , d .iamTokenSource ,
639
+ d .dialerID , useIAMAuthNDial ,
640
+ )
641
+ }
642
+ c = newMonitoredCache (ctx , cache , cn , d .failoverPeriod , d .resolver , d .logger )
643
+ d .cache [cn ] = c
644
+
599
645
return c , nil
600
646
}
647
+
648
+ // findByDn looks up a cache entry that has the same domain name as cn but is
649
+ // stored under a different connection name. It does not modify the cache.
650
+ //
651
+ // cn - the connection name whose domain name is matched
652
+ //
653
+ // returns:
654
+ //
655
+ // instance.ConnName - the key of the existing entry with the same domain name
656
+ // monitoredCache - the existing entry, or nil when there is no match
657
+ // bool ok - true when a matching entry was found
658
+ //
659
+ // This method does not manage locks.
660
+ func (d * Dialer ) findByDn (cn instance.ConnName ) (instance.ConnName , * monitoredCache , bool ) {
661
+
662
+ // Look for an entry with the same domain name but a different connection
663
+ // name. The caller closes and removes that entry so it can be replaced.
664
+ if cn .HasDomainName () {
665
+ for oldCn , oc := range d .cache {
666
+ if oldCn .DomainName () == cn .DomainName () && oldCn != cn {
667
+ return oldCn , oc , true
668
+ }
669
+ }
670
+ }
671
+
672
+ return instance.ConnName {}, nil , false
673
+ }
0 commit comments