@@ -118,12 +118,38 @@ type monitoredCache struct {
118
118
connectionInfoCache
119
119
}
120
120
121
+ func (c monitoredCache ) Close () error {
122
+ if c .connectionInfoCache == nil {
123
+ return nil
124
+ }
125
+ return c .connectionInfoCache .Close ()
126
+ }
127
+
128
+ func (c monitoredCache ) ForceRefresh () {
129
+ if c .connectionInfoCache == nil {
130
+ return
131
+ }
132
+ c .connectionInfoCache .ForceRefresh ()
133
+ }
134
+
135
+ func (c monitoredCache ) UpdateRefresh (b * bool ) {
136
+ if c .connectionInfoCache == nil {
137
+ return
138
+ }
139
+ c .connectionInfoCache .UpdateRefresh (b )
140
+ }
141
+ func (c monitoredCache ) ConnectionInfo (ctx context.Context ) (cloudsql.ConnectionInfo , error ) {
142
+ if c .connectionInfoCache == nil {
143
+ return cloudsql.ConnectionInfo {}, nil
144
+ }
145
+ return c .connectionInfoCache .ConnectionInfo (ctx )
146
+ }
147
+
121
148
// A Dialer is used to create connections to Cloud SQL instances.
122
149
//
123
150
// Use NewDialer to initialize a Dialer.
124
151
type Dialer struct {
125
- lock sync.RWMutex
126
- cache map [instance.ConnName ]monitoredCache
152
+ cache * DialerCache
127
153
keyGenerator * keyGenerator
128
154
refreshTimeout time.Duration
129
155
// closed reports if the dialer has been closed.
@@ -205,7 +231,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
205
231
}
206
232
ud , err := c .GetUniverseDomain ()
207
233
if err != nil {
208
- return nil , fmt .Errorf ("failed to get universe domain: %v" , err )
234
+ return nil , fmt .Errorf ("failed to getOrAdd universe domain: %v" , err )
209
235
}
210
236
cfg .credentialsUniverse = ud
211
237
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithTokenSource (c .TokenSource ))
@@ -263,7 +289,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263
289
264
290
d := & Dialer {
265
291
closed : make (chan struct {}),
266
- cache : make ( map [instance. ConnName ] monitoredCache ),
292
+ cache : newDialerCache ( cfg . logger ),
267
293
lazyRefresh : cfg .lazyRefresh ,
268
294
keyGenerator : g ,
269
295
refreshTimeout : cfg .refreshTimeout ,
@@ -385,10 +411,12 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
385
411
trace .RecordDialLatency (ctx , icn , d .dialerID , latency )
386
412
}()
387
413
388
- return newInstrumentedConn (tlsConn , func () {
414
+ iConn := newInstrumentedConn (tlsConn , func () {
389
415
n := atomic .AddUint64 (c .openConns , ^ uint64 (0 ))
390
416
trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
391
- }, d .dialerID , cn .String ()), nil
417
+ }, d .dialerID , cn .String ())
418
+
419
+ return iConn , nil
392
420
}
393
421
394
422
// removeCached stops all background refreshes and deletes the connection
@@ -397,16 +425,14 @@ func (d *Dialer) removeCached(
397
425
ctx context.Context ,
398
426
i instance.ConnName , c connectionInfoCache , err error ,
399
427
) {
428
+ mc := d .cache .remove (i )
429
+ mc .Close ()
400
430
d .logger .Debugf (
401
431
ctx ,
402
432
"[%v] Removing connection info from cache: %v" ,
403
433
i .String (),
404
434
err ,
405
435
)
406
- d .lock .Lock ()
407
- defer d .lock .Unlock ()
408
- c .Close ()
409
- delete (d .cache , i )
410
436
}
411
437
412
438
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -442,6 +468,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
442
468
if err != nil {
443
469
return "" , err
444
470
}
471
+
445
472
c , err := d .connectionInfoCache (ctx , cn , & d .defaultDialConfig .useIAMAuthN )
446
473
if err != nil {
447
474
return "" , err
@@ -493,6 +520,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493
520
type instrumentedConn struct {
494
521
net.Conn
495
522
closeFunc func ()
523
+ closed bool
496
524
dialerID string
497
525
connName string
498
526
}
@@ -520,6 +548,7 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
520
548
// Close delegates to the underlying net.Conn interface and reports the close
521
549
// to the provided closeFunc only when Close returns no error.
522
550
func (i * instrumentedConn ) Close () error {
551
+ i .closed = true
523
552
err := i .Conn .Close ()
524
553
if err != nil {
525
554
return err
@@ -538,11 +567,11 @@ func (d *Dialer) Close() error {
538
567
default :
539
568
}
540
569
close (d .closed )
541
- d . lock . Lock ()
542
- defer d . lock . Unlock ()
543
- for _ , i := range d . cache {
544
- i . Close ()
545
- }
570
+
571
+ d . cache . replaceAll ( func ( cn instance. ConnName , c monitoredCache ) (instance. ConnName , monitoredCache ) {
572
+ c . Close () // close the monitoredCache
573
+ return instance. ConnName {}, monitoredCache {} // Remove from cache
574
+ })
546
575
return nil
547
576
}
548
577
@@ -552,47 +581,50 @@ func (d *Dialer) Close() error {
552
581
func (d * Dialer ) connectionInfoCache (
553
582
ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
554
583
) (monitoredCache , error ) {
555
- d .lock .RLock ()
556
- c , ok := d .cache [cn ]
557
- d .lock .RUnlock ()
558
- if ! ok {
559
- d .lock .Lock ()
560
- defer d .lock .Unlock ()
561
- // Recheck to ensure instance wasn't created or changed between locks
562
- c , ok = d .cache [cn ]
563
- if ! ok {
564
- var useIAMAuthNDial bool
565
- if useIAMAuthN != nil {
566
- useIAMAuthNDial = * useIAMAuthN
567
- }
568
- d .logger .Debugf (ctx , "[%v] Connection info added to cache" , cn .String ())
569
- k , err := d .keyGenerator .rsaKey ()
570
- if err != nil {
571
- return monitoredCache {}, err
572
- }
573
- var cache connectionInfoCache
574
- if d .lazyRefresh {
575
- cache = cloudsql .NewLazyRefreshCache (
576
- cn ,
577
- d .logger ,
578
- d .sqladmin , k ,
579
- d .refreshTimeout , d .iamTokenSource ,
580
- d .dialerID , useIAMAuthNDial ,
581
- )
582
- } else {
583
- cache = cloudsql .NewRefreshAheadCache (
584
- cn ,
585
- d .logger ,
586
- d .sqladmin , k ,
587
- d .refreshTimeout , d .iamTokenSource ,
588
- d .dialerID , useIAMAuthNDial ,
589
- )
590
- }
591
- var count uint64
592
- c = monitoredCache {openConns : & count , connectionInfoCache : cache }
593
- d .cache [cn ] = c
594
- }
584
+
585
+ c , oldC , err := d .cache .getOrAdd (cn , func () (monitoredCache , error ) {
586
+ return d .createConnectionInfoCache (ctx , cn , useIAMAuthN )
587
+ })
588
+
589
+ oldC .Close ()
590
+ c .UpdateRefresh (useIAMAuthN )
591
+
592
+ return c , err
593
+ }
594
+
595
+ func (d * Dialer ) createConnectionInfoCache (
596
+ ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
597
+ ) (monitoredCache , error ) {
598
+
599
+ var useIAMAuthNDial bool
600
+ if useIAMAuthN != nil {
601
+ useIAMAuthNDial = * useIAMAuthN
602
+ }
603
+ d .logger .Debugf (ctx , "[%v] Connection info created" , cn .String ())
604
+ k , err := d .keyGenerator .rsaKey ()
605
+ if err != nil {
606
+ return monitoredCache {}, err
607
+ }
608
+ var cache connectionInfoCache
609
+ if d .lazyRefresh {
610
+ cache = cloudsql .NewLazyRefreshCache (
611
+ cn ,
612
+ d .logger ,
613
+ d .sqladmin , k ,
614
+ d .refreshTimeout , d .iamTokenSource ,
615
+ d .dialerID , useIAMAuthNDial ,
616
+ )
617
+ } else {
618
+ cache = cloudsql .NewRefreshAheadCache (
619
+ cn ,
620
+ d .logger ,
621
+ d .sqladmin , k ,
622
+ d .refreshTimeout , d .iamTokenSource ,
623
+ d .dialerID , useIAMAuthNDial ,
624
+ )
595
625
}
626
+ var count uint64
627
+ c := monitoredCache {openConns : & count , connectionInfoCache : cache }
596
628
597
629
c .UpdateRefresh (useIAMAuthN )
598
630
0 commit comments