@@ -122,12 +122,41 @@ type monitoredCache struct {
122
122
connectionInfoCache
123
123
}
124
124
125
+ func (c * monitoredCache ) cleanupClosed () {
126
+ // Remove closed sockets from cache.openSockets
127
+ c .lock .Lock ()
128
+ defer c .lock .Unlock ()
129
+
130
+ var newOpenSockets []* instrumentedConn
131
+ for _ , s := range c .openSockets {
132
+ if ! s .closed {
133
+ newOpenSockets = append (newOpenSockets , s )
134
+ }
135
+ }
136
+ c .openSockets = newOpenSockets
137
+
138
+ }
139
+ func (c * monitoredCache ) closeMonitored () {
140
+ c .lock .Lock ()
141
+ defer c .lock .Unlock ()
142
+ if c .openConns != nil {
143
+ for _ , socket := range c .openSockets {
144
+ if ! socket .closed {
145
+ socket .Close ()
146
+ }
147
+ }
148
+ atomic .StoreUint64 (c .openConns , 0 )
149
+ }
150
+ if c .connectionInfoCache != nil {
151
+ c .connectionInfoCache .Close ()
152
+ }
153
+ }
154
+
125
155
// A Dialer is used to create connections to Cloud SQL instances.
126
156
//
127
157
// Use NewDialer to initialize a Dialer.
128
158
type Dialer struct {
129
- lock sync.RWMutex
130
- cache map [instance.ConnName ]monitoredCache
159
+ cache * DialerCache
131
160
keyGenerator * keyGenerator
132
161
refreshTimeout time.Duration
133
162
// closed reports if the dialer has been closed.
@@ -273,7 +302,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
273
302
274
303
d := & Dialer {
275
304
closed : make (chan struct {}),
276
- cache : make ( map [instance. ConnName ] monitoredCache ),
305
+ cache : newDialerCache ( cfg . logger ),
277
306
lazyRefresh : cfg .lazyRefresh ,
278
307
keyGenerator : g ,
279
308
refreshTimeout : cfg .refreshTimeout ,
@@ -431,16 +460,14 @@ func (d *Dialer) removeCached(
431
460
ctx context.Context ,
432
461
i instance.ConnName , c connectionInfoCache , err error ,
433
462
) {
463
+ d .cache .remove (i )
434
464
d .logger .Debugf (
435
465
ctx ,
436
466
"[%v] Removing connection info from cache: %v" ,
437
467
i .String (),
438
468
err ,
439
469
)
440
- d .lock .Lock ()
441
- defer d .lock .Unlock ()
442
470
c .Close ()
443
- delete (d .cache , i )
444
471
}
445
472
446
473
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -476,6 +503,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
476
503
if err != nil {
477
504
return "" , err
478
505
}
506
+ // Create a connectionInfoCache without adding it to the cache.
479
507
c , err := d .connectionInfoCache (ctx , cn , & d .defaultDialConfig .useIAMAuthN )
480
508
if err != nil {
481
509
return "" , err
@@ -574,94 +602,60 @@ func (d *Dialer) Close() error {
574
602
default :
575
603
}
576
604
close (d .closed )
577
- d .lock .Lock ()
578
- defer d .lock .Unlock ()
605
+
606
+ d .cache .replaceAll (func (cn instance.ConnName , c monitoredCache ) (instance.ConnName , monitoredCache ) {
607
+ c .closeMonitored () // close the monitoredCache
608
+ return instance.ConnName {}, monitoredCache {} // Remove from cache
609
+ })
610
+
579
611
if d .domainNameTicker != nil {
580
612
d .domainNameTicker .Stop ()
581
613
}
582
- for _ , i := range d .cache {
583
- i .Close ()
584
- }
585
614
return nil
586
615
}
587
616
588
617
func (d * Dialer ) pollDomainNames (ctx context.Context ) {
589
- type cacheEntry struct {
590
- cn instance.ConnName
591
- cache monitoredCache
592
- }
593
-
594
- // List all the cache entries created with a domain name
595
- d .lock .RLock ()
596
- cacheEntries := make ([]cacheEntry , 0 , len (d .cache ))
597
- for cn , cache := range d .cache {
598
-
599
- // Ignore cache entries that were not opened by domain name.
618
+ // Check if domain changed.
619
+ d .cache .replaceAll (func (cn instance.ConnName , c monitoredCache ) (instance.ConnName , monitoredCache ) {
620
+ // No domain set, do nothing.
600
621
if cn .DomainName () == "" {
601
- continue
602
- }
603
-
604
- cacheEntries = append (cacheEntries , cacheEntry {cn : cn , cache : cache })
605
- }
606
- d .lock .RUnlock ()
607
-
608
- for _ , entry := range cacheEntries {
609
- // Remove closed sockets from cache.openSockets
610
- entry .cache .lock .Lock ()
611
- var newOpenSockets []* instrumentedConn
612
- for _ , s := range entry .cache .openSockets {
613
- if ! s .closed {
614
- newOpenSockets = append (newOpenSockets , s )
615
- }
622
+ return cn , c // no change
616
623
}
617
- entry .cache .openSockets = newOpenSockets
618
- entry .cache .lock .Unlock ()
619
624
620
625
// Resolve the domain name.
621
- newCn , err := d .resolver .Resolve (ctx , entry . cn .DomainName ())
626
+ newCn , err := d .resolver .Resolve (ctx , cn .DomainName ())
622
627
623
- // the domain name no longer resolves to a valid instance
628
+ // The domain name no longer resolves, remove from cache.
624
629
if err != nil {
625
- d .logger .Debugf (ctx , "[failover] unable to resolve DNS for instance %s: %v" , entry .cn .DomainName (), err )
630
+ d .logger .Debugf (ctx , "[failover] unable to resolve DNS for instance %s: %v" , cn .DomainName (), err )
631
+ c .closeMonitored () // Close the instance
632
+ return cn , monitoredCache {} // remove from cache
626
633
}
627
634
628
- // The domain name points to a different instance.
629
- if newCn != entry . cn {
630
- d .logger .Debugf (ctx , "domain name %s changed from old instance %s to new instance %s" ,
631
- entry . cn .DomainName (), entry . cn .String (), newCn .String ())
635
+ // The domain name points to a different instance, replace .
636
+ if newCn != cn {
637
+ d .logger .Debugf (ctx , "[failover] domain name %s changed from old instance %s to new instance %s" ,
638
+ cn .DomainName (), cn .String (), newCn .String ())
632
639
633
- d .closeDomainNameChanged (ctx , entry .cn , entry .cache ,
634
- fmt .Errorf ("domain name %s changed from old instance %s to new instance %s" ,
635
- entry .cn .DomainName (), entry .cn .String (), newCn .String ()))
636
- // preload the new cache entry
637
- b := entry .cache .UseIAMAuthN ()
638
- d .connectionInfoCache (ctx , newCn , & b )
639
- }
640
- }
641
-
642
- }
640
+ // Close the old instance
641
+ go c .closeMonitored ()
643
642
644
- func (d * Dialer ) closeDomainNameChanged (ctx context.Context , cn instance.ConnName , cache monitoredCache , err error ) {
645
- d .removeCached (ctx , cn , cache , err )
646
- if atomic .LoadUint64 (cache .openConns ) > 0 {
647
- for _ , socket := range cache .openSockets {
648
- if ! socket .closed {
649
- socket .Close ()
643
+ // preload the new cache entry
644
+ b := c .UseIAMAuthN ()
645
+ newCache , err := d .createConnectionInfoCache (ctx , newCn , & b )
646
+ if err != nil {
647
+ return cn , monitoredCache {} // create instance failed, remove from cache
650
648
}
649
+
650
+ // replace with newCn with updated domain and new connectionInfoCache
651
+ return newCn , newCache
651
652
}
652
- }
653
- }
654
653
655
- func (d * Dialer ) findByDomainName (dn string ) (instance.ConnName , monitoredCache , bool ) {
654
+ // Nothing changed, clean up any closed sockets from the list of sockets.
655
+ c .cleanupClosed ()
656
+ return cn , c
657
+ })
656
658
657
- d .lock .RLock ()
658
- defer d .lock .RUnlock ()
659
- for cn , cache := range d .cache {
660
- if cn .DomainName () == dn {
661
- return cn , cache , true
662
- }
663
- }
664
- return instance.ConnName {}, monitoredCache {}, false
665
659
}
666
660
667
661
// connectionInfoCache is a helper function for returning the appropriate
@@ -670,60 +664,49 @@ func (d *Dialer) findByDomainName(dn string) (instance.ConnName, monitoredCache,
670
664
func (d * Dialer ) connectionInfoCache (
671
665
ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
672
666
) (monitoredCache , error ) {
673
- d .lock .RLock ()
674
- c , ok := d .cache [cn ]
675
- d .lock .RUnlock ()
676
-
677
- if ! ok {
678
- // Check if the domain name was previously associated with a different
679
- // instance, and if so, close that cache.
680
- if cn .DomainName () != "" {
681
- oldCn , c , ok := d .findByDomainName (cn .DomainName ())
682
- if ok {
683
- d .closeDomainNameChanged (ctx , oldCn , c , fmt .Errorf (
684
- "domain name %s changed from old instance %s to new instance %s" ,
685
- cn .DomainName (), oldCn .String (), cn .String ()))
686
- }
687
- }
688
667
689
- // Create a new connectionInfoCache
690
- d .lock .Lock ()
691
- defer d .lock .Unlock ()
692
- // Recheck to ensure instance wasn't created or changed between locks
693
- c , ok = d .cache [cn ]
694
- if ! ok {
695
- var useIAMAuthNDial bool
696
- if useIAMAuthN != nil {
697
- useIAMAuthNDial = * useIAMAuthN
698
- }
699
- d .logger .Debugf (ctx , "[%v] Connection info added to cache" , cn .String ())
700
- k , err := d .keyGenerator .rsaKey ()
701
- if err != nil {
702
- return monitoredCache {}, err
703
- }
704
- var cache connectionInfoCache
705
- if d .lazyRefresh {
706
- cache = cloudsql .NewLazyRefreshCache (
707
- cn ,
708
- d .logger ,
709
- d .sqladmin , k ,
710
- d .refreshTimeout , d .iamTokenSource ,
711
- d .dialerID , useIAMAuthNDial ,
712
- )
713
- } else {
714
- cache = cloudsql .NewRefreshAheadCache (
715
- cn ,
716
- d .logger ,
717
- d .sqladmin , k ,
718
- d .refreshTimeout , d .iamTokenSource ,
719
- d .dialerID , useIAMAuthNDial ,
720
- )
721
- }
722
- var count uint64
723
- c = monitoredCache {openConns : & count , connectionInfoCache : cache }
724
- d .cache [cn ] = c
725
- }
668
+ c , oldC , err := d .cache .get (cn , func () (monitoredCache , error ) {
669
+ return d .createConnectionInfoCache (ctx , cn , useIAMAuthN )
670
+ })
671
+
672
+ oldC .closeMonitored ()
673
+
674
+ return c , err
675
+ }
676
+
677
+ func (d * Dialer ) createConnectionInfoCache (
678
+ ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
679
+ ) (monitoredCache , error ) {
680
+
681
+ var useIAMAuthNDial bool
682
+ if useIAMAuthN != nil {
683
+ useIAMAuthNDial = * useIAMAuthN
684
+ }
685
+ d .logger .Debugf (ctx , "[%v] Connection info created" , cn .String ())
686
+ k , err := d .keyGenerator .rsaKey ()
687
+ if err != nil {
688
+ return monitoredCache {}, err
689
+ }
690
+ var cache connectionInfoCache
691
+ if d .lazyRefresh {
692
+ cache = cloudsql .NewLazyRefreshCache (
693
+ cn ,
694
+ d .logger ,
695
+ d .sqladmin , k ,
696
+ d .refreshTimeout , d .iamTokenSource ,
697
+ d .dialerID , useIAMAuthNDial ,
698
+ )
699
+ } else {
700
+ cache = cloudsql .NewRefreshAheadCache (
701
+ cn ,
702
+ d .logger ,
703
+ d .sqladmin , k ,
704
+ d .refreshTimeout , d .iamTokenSource ,
705
+ d .dialerID , useIAMAuthNDial ,
706
+ )
726
707
}
708
+ var count uint64
709
+ c := monitoredCache {openConns : & count , connectionInfoCache : cache }
727
710
728
711
c .UpdateRefresh (useIAMAuthN )
729
712
0 commit comments