@@ -113,17 +113,43 @@ type connectionInfoCache interface {
113
113
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
114
114
// number of connections to the associated instance.
115
115
type monitoredCache struct {
116
- openConns * uint64
116
+ openConnsCount * uint64
117
117
118
118
connectionInfoCache
119
119
}
120
120
121
+ func (c * monitoredCache ) Close () error {
122
+ if c == nil || c .connectionInfoCache == nil {
123
+ return nil
124
+ }
125
+ return c .connectionInfoCache .Close ()
126
+ }
127
+
128
+ func (c * monitoredCache ) ForceRefresh () {
129
+ if c == nil || c .connectionInfoCache == nil {
130
+ return
131
+ }
132
+ c .connectionInfoCache .ForceRefresh ()
133
+ }
134
+
135
+ func (c * monitoredCache ) UpdateRefresh (b * bool ) {
136
+ if c == nil || c .connectionInfoCache == nil {
137
+ return
138
+ }
139
+ c .connectionInfoCache .UpdateRefresh (b )
140
+ }
141
+ func (c * monitoredCache ) ConnectionInfo (ctx context.Context ) (cloudsql.ConnectionInfo , error ) {
142
+ if c == nil || c .connectionInfoCache == nil {
143
+ return cloudsql.ConnectionInfo {}, nil
144
+ }
145
+ return c .connectionInfoCache .ConnectionInfo (ctx )
146
+ }
147
+
121
148
// A Dialer is used to create connections to Cloud SQL instances.
122
149
//
123
150
// Use NewDialer to initialize a Dialer.
124
151
type Dialer struct {
125
- lock sync.RWMutex
126
- cache map [instance.ConnName ]monitoredCache
152
+ cache * dialerCache
127
153
keyGenerator * keyGenerator
128
154
refreshTimeout time.Duration
129
155
// closed reports if the dialer has been closed.
@@ -263,7 +289,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263
289
264
290
d := & Dialer {
265
291
closed : make (chan struct {}),
266
- cache : make ( map [instance. ConnName ] monitoredCache ),
292
+ cache : newDialerCache ( cfg . logger ),
267
293
lazyRefresh : cfg .lazyRefresh ,
268
294
keyGenerator : g ,
269
295
refreshTimeout : cfg .refreshTimeout ,
@@ -316,7 +342,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
316
342
}
317
343
ci , err := c .ConnectionInfo (ctx )
318
344
if err != nil {
319
- d .removeCached (ctx , cn , c , err )
345
+ d .removeCached (ctx , cn , err )
320
346
endInfo (err )
321
347
return nil , err
322
348
}
@@ -333,7 +359,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
333
359
// Block on refreshed connection info
334
360
ci , err = c .ConnectionInfo (ctx )
335
361
if err != nil {
336
- d .removeCached (ctx , cn , c , err )
362
+ d .removeCached (ctx , cn , err )
337
363
return nil , err
338
364
}
339
365
}
@@ -343,7 +369,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
343
369
defer func () { connectEnd (err ) }()
344
370
addr , err := ci .Addr (cfg .ipType )
345
371
if err != nil {
346
- d .removeCached (ctx , cn , c , err )
372
+ d .removeCached (ctx , cn , err )
347
373
return nil , err
348
374
}
349
375
addr = net .JoinHostPort (addr , serverProxyPort )
@@ -380,33 +406,30 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
380
406
381
407
latency := time .Since (startTime ).Milliseconds ()
382
408
go func () {
383
- n := atomic .AddUint64 (c .openConns , 1 )
409
+ n := atomic .AddUint64 (c .openConnsCount , 1 )
384
410
trace .RecordOpenConnections (ctx , int64 (n ), d .dialerID , cn .String ())
385
411
trace .RecordDialLatency (ctx , icn , d .dialerID , latency )
386
412
}()
387
413
388
- return newInstrumentedConn (tlsConn , func () {
389
- n := atomic .AddUint64 (c .openConns , ^ uint64 (0 ))
414
+ iConn := newInstrumentedConn (tlsConn , func () {
415
+ n := atomic .AddUint64 (c .openConnsCount , ^ uint64 (0 ))
390
416
trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
391
- }, d .dialerID , cn .String ()), nil
417
+ }, d .dialerID , cn .String ())
418
+
419
+ return iConn , nil
392
420
}
393
421
394
422
// removeCached stops all background refreshes and deletes the connection
395
423
// info cache from the map of caches.
396
- func (d * Dialer ) removeCached (
397
- ctx context.Context ,
398
- i instance.ConnName , c connectionInfoCache , err error ,
399
- ) {
424
+ func (d * Dialer ) removeCached (ctx context.Context , i instance.ConnName , err error ) {
425
+ mc := d .cache .remove (i )
426
+ mc .Close ()
400
427
d .logger .Debugf (
401
428
ctx ,
402
429
"[%v] Removing connection info from cache: %v" ,
403
430
i .String (),
404
431
err ,
405
432
)
406
- d .lock .Lock ()
407
- defer d .lock .Unlock ()
408
- c .Close ()
409
- delete (d .cache , i )
410
433
}
411
434
412
435
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -442,13 +465,14 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
442
465
if err != nil {
443
466
return "" , err
444
467
}
468
+
445
469
c , err := d .connectionInfoCache (ctx , cn , & d .defaultDialConfig .useIAMAuthN )
446
470
if err != nil {
447
471
return "" , err
448
472
}
449
473
ci , err := c .ConnectionInfo (ctx )
450
474
if err != nil {
451
- d .removeCached (ctx , cn , c , err )
475
+ d .removeCached (ctx , cn , err )
452
476
return "" , err
453
477
}
454
478
return ci .DBVersion , nil
@@ -472,7 +496,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
472
496
}
473
497
_ , err = c .ConnectionInfo (ctx )
474
498
if err != nil {
475
- d .removeCached (ctx , cn , c , err )
499
+ d .removeCached (ctx , cn , err )
476
500
}
477
501
return err
478
502
}
@@ -493,6 +517,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493
517
type instrumentedConn struct {
494
518
net.Conn
495
519
closeFunc func ()
520
+ closed bool
496
521
dialerID string
497
522
connName string
498
523
}
@@ -520,6 +545,7 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
520
545
// Close delegates to the underlying net.Conn interface and reports the close
521
546
// to the provided closeFunc only when Close returns no error.
522
547
func (i * instrumentedConn ) Close () error {
548
+ i .closed = true
523
549
err := i .Conn .Close ()
524
550
if err != nil {
525
551
return err
@@ -538,11 +564,11 @@ func (d *Dialer) Close() error {
538
564
default :
539
565
}
540
566
close (d .closed )
541
- d . lock . Lock ()
542
- defer d . lock . Unlock ()
543
- for _ , i := range d . cache {
544
- i . Close ()
545
- }
567
+
568
+ d . cache . replaceAll ( func ( _ instance. ConnName , c * monitoredCache ) (instance. ConnName , * monitoredCache ) {
569
+ c . Close () // close the monitoredCache
570
+ return instance. ConnName {}, nil // Remove from cache
571
+ })
546
572
return nil
547
573
}
548
574
@@ -551,47 +577,52 @@ func (d *Dialer) Close() error {
551
577
// modify the existing one, or leave it unchanged as needed.
552
578
func (d * Dialer ) connectionInfoCache (
553
579
ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
554
- ) (monitoredCache , error ) {
555
- d .lock .RLock ()
556
- c , ok := d .cache [cn ]
557
- d .lock .RUnlock ()
558
- if ! ok {
559
- d .lock .Lock ()
560
- defer d .lock .Unlock ()
561
- // Recheck to ensure instance wasn't created or changed between locks
562
- c , ok = d .cache [cn ]
563
- if ! ok {
564
- var useIAMAuthNDial bool
565
- if useIAMAuthN != nil {
566
- useIAMAuthNDial = * useIAMAuthN
567
- }
568
- d .logger .Debugf (ctx , "[%v] Connection info added to cache" , cn .String ())
569
- k , err := d .keyGenerator .rsaKey ()
570
- if err != nil {
571
- return monitoredCache {}, err
572
- }
573
- var cache connectionInfoCache
574
- if d .lazyRefresh {
575
- cache = cloudsql .NewLazyRefreshCache (
576
- cn ,
577
- d .logger ,
578
- d .sqladmin , k ,
579
- d .refreshTimeout , d .iamTokenSource ,
580
- d .dialerID , useIAMAuthNDial ,
581
- )
582
- } else {
583
- cache = cloudsql .NewRefreshAheadCache (
584
- cn ,
585
- d .logger ,
586
- d .sqladmin , k ,
587
- d .refreshTimeout , d .iamTokenSource ,
588
- d .dialerID , useIAMAuthNDial ,
589
- )
590
- }
591
- var count uint64
592
- c = monitoredCache {openConns : & count , connectionInfoCache : cache }
593
- d .cache [cn ] = c
594
- }
580
+ ) (* monitoredCache , error ) {
581
+
582
+ c , oldC , err := d .cache .getOrAdd (cn , func () (* monitoredCache , error ) {
583
+ return d .createConnectionInfoCache (ctx , cn , useIAMAuthN )
584
+ })
585
+
586
+ oldC .Close ()
587
+ c .UpdateRefresh (useIAMAuthN )
588
+
589
+ return c , err
590
+ }
591
+
592
+ func (d * Dialer ) createConnectionInfoCache (
593
+ ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
594
+ ) (* monitoredCache , error ) {
595
+
596
+ var useIAMAuthNDial bool
597
+ if useIAMAuthN != nil {
598
+ useIAMAuthNDial = * useIAMAuthN
599
+ }
600
+ d .logger .Debugf (ctx , "[%v] Connection info created" , cn .String ())
601
+ k , err := d .keyGenerator .rsaKey ()
602
+ if err != nil {
603
+ return nil , err
604
+ }
605
+ var cache connectionInfoCache
606
+ if d .lazyRefresh {
607
+ cache = cloudsql .NewLazyRefreshCache (
608
+ cn ,
609
+ d .logger ,
610
+ d .sqladmin , k ,
611
+ d .refreshTimeout , d .iamTokenSource ,
612
+ d .dialerID , useIAMAuthNDial ,
613
+ )
614
+ } else {
615
+ cache = cloudsql .NewRefreshAheadCache (
616
+ cn ,
617
+ d .logger ,
618
+ d .sqladmin , k ,
619
+ d .refreshTimeout , d .iamTokenSource ,
620
+ d .dialerID , useIAMAuthNDial ,
621
+ )
622
+ }
623
+ c := & monitoredCache {
624
+ openConnsCount : new (uint64 ),
625
+ connectionInfoCache : cache ,
595
626
}
596
627
597
628
c .UpdateRefresh (useIAMAuthN )
0 commit comments