Skip to content

Commit 84e62f8

Browse files
committed
chore: Refactor dialer cache concurrency logic. Part of #842.
1 parent 5b2a68b commit 84e62f8

File tree

6 files changed

+616
-92
lines changed

6 files changed

+616
-92
lines changed

dialer.go

+98-67
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,43 @@ type connectionInfoCache interface {
113113
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
114114
// number of connections to the associated instance.
115115
type monitoredCache struct {
116-
openConns *uint64
116+
openConnsCount *uint64
117117

118118
connectionInfoCache
119119
}
120120

121+
func (c *monitoredCache) Close() error {
122+
if c == nil || c.connectionInfoCache == nil {
123+
return nil
124+
}
125+
return c.connectionInfoCache.Close()
126+
}
127+
128+
func (c *monitoredCache) ForceRefresh() {
129+
if c == nil || c.connectionInfoCache == nil {
130+
return
131+
}
132+
c.connectionInfoCache.ForceRefresh()
133+
}
134+
135+
func (c *monitoredCache) UpdateRefresh(b *bool) {
136+
if c == nil || c.connectionInfoCache == nil {
137+
return
138+
}
139+
c.connectionInfoCache.UpdateRefresh(b)
140+
}
141+
func (c *monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.ConnectionInfo, error) {
142+
if c == nil || c.connectionInfoCache == nil {
143+
return cloudsql.ConnectionInfo{}, nil
144+
}
145+
return c.connectionInfoCache.ConnectionInfo(ctx)
146+
}
147+
121148
// A Dialer is used to create connections to Cloud SQL instances.
122149
//
123150
// Use NewDialer to initialize a Dialer.
124151
type Dialer struct {
125-
lock sync.RWMutex
126-
cache map[instance.ConnName]monitoredCache
152+
cache *dialerCache
127153
keyGenerator *keyGenerator
128154
refreshTimeout time.Duration
129155
// closed reports if the dialer has been closed.
@@ -263,7 +289,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263289

264290
d := &Dialer{
265291
closed: make(chan struct{}),
266-
cache: make(map[instance.ConnName]monitoredCache),
292+
cache: newDialerCache(cfg.logger),
267293
lazyRefresh: cfg.lazyRefresh,
268294
keyGenerator: g,
269295
refreshTimeout: cfg.refreshTimeout,
@@ -316,7 +342,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
316342
}
317343
ci, err := c.ConnectionInfo(ctx)
318344
if err != nil {
319-
d.removeCached(ctx, cn, c, err)
345+
d.removeCached(ctx, cn, err)
320346
endInfo(err)
321347
return nil, err
322348
}
@@ -333,7 +359,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
333359
// Block on refreshed connection info
334360
ci, err = c.ConnectionInfo(ctx)
335361
if err != nil {
336-
d.removeCached(ctx, cn, c, err)
362+
d.removeCached(ctx, cn, err)
337363
return nil, err
338364
}
339365
}
@@ -343,7 +369,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
343369
defer func() { connectEnd(err) }()
344370
addr, err := ci.Addr(cfg.ipType)
345371
if err != nil {
346-
d.removeCached(ctx, cn, c, err)
372+
d.removeCached(ctx, cn, err)
347373
return nil, err
348374
}
349375
addr = net.JoinHostPort(addr, serverProxyPort)
@@ -380,33 +406,30 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
380406

381407
latency := time.Since(startTime).Milliseconds()
382408
go func() {
383-
n := atomic.AddUint64(c.openConns, 1)
409+
n := atomic.AddUint64(c.openConnsCount, 1)
384410
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
385411
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
386412
}()
387413

388-
return newInstrumentedConn(tlsConn, func() {
389-
n := atomic.AddUint64(c.openConns, ^uint64(0))
414+
iConn := newInstrumentedConn(tlsConn, func() {
415+
n := atomic.AddUint64(c.openConnsCount, ^uint64(0))
390416
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
391-
}, d.dialerID, cn.String()), nil
417+
}, d.dialerID, cn.String())
418+
419+
return iConn, nil
392420
}
393421

394422
// removeCached stops all background refreshes and deletes the connection
395423
// info cache from the map of caches.
396-
func (d *Dialer) removeCached(
397-
ctx context.Context,
398-
i instance.ConnName, c connectionInfoCache, err error,
399-
) {
424+
func (d *Dialer) removeCached(ctx context.Context, i instance.ConnName, err error) {
425+
mc := d.cache.remove(i)
426+
mc.Close()
400427
d.logger.Debugf(
401428
ctx,
402429
"[%v] Removing connection info from cache: %v",
403430
i.String(),
404431
err,
405432
)
406-
d.lock.Lock()
407-
defer d.lock.Unlock()
408-
c.Close()
409-
delete(d.cache, i)
410433
}
411434

412435
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -442,13 +465,14 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
442465
if err != nil {
443466
return "", err
444467
}
468+
445469
c, err := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN)
446470
if err != nil {
447471
return "", err
448472
}
449473
ci, err := c.ConnectionInfo(ctx)
450474
if err != nil {
451-
d.removeCached(ctx, cn, c, err)
475+
d.removeCached(ctx, cn, err)
452476
return "", err
453477
}
454478
return ci.DBVersion, nil
@@ -472,7 +496,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
472496
}
473497
_, err = c.ConnectionInfo(ctx)
474498
if err != nil {
475-
d.removeCached(ctx, cn, c, err)
499+
d.removeCached(ctx, cn, err)
476500
}
477501
return err
478502
}
@@ -493,6 +517,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493517
type instrumentedConn struct {
494518
net.Conn
495519
closeFunc func()
520+
closed bool
496521
dialerID string
497522
connName string
498523
}
@@ -520,6 +545,7 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
520545
// Close delegates to the underlying net.Conn interface and reports the close
521546
// to the provided closeFunc only when Close returns no error.
522547
func (i *instrumentedConn) Close() error {
548+
i.closed = true
523549
err := i.Conn.Close()
524550
if err != nil {
525551
return err
@@ -538,11 +564,11 @@ func (d *Dialer) Close() error {
538564
default:
539565
}
540566
close(d.closed)
541-
d.lock.Lock()
542-
defer d.lock.Unlock()
543-
for _, i := range d.cache {
544-
i.Close()
545-
}
567+
568+
d.cache.replaceAll(func(_ instance.ConnName, c *monitoredCache) (instance.ConnName, *monitoredCache) {
569+
c.Close() // close the monitoredCache
570+
return instance.ConnName{}, nil // Remove from cache
571+
})
546572
return nil
547573
}
548574

@@ -551,47 +577,52 @@ func (d *Dialer) Close() error {
551577
// modify the existing one, or leave it unchanged as needed.
552578
func (d *Dialer) connectionInfoCache(
553579
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
554-
) (monitoredCache, error) {
555-
d.lock.RLock()
556-
c, ok := d.cache[cn]
557-
d.lock.RUnlock()
558-
if !ok {
559-
d.lock.Lock()
560-
defer d.lock.Unlock()
561-
// Recheck to ensure instance wasn't created or changed between locks
562-
c, ok = d.cache[cn]
563-
if !ok {
564-
var useIAMAuthNDial bool
565-
if useIAMAuthN != nil {
566-
useIAMAuthNDial = *useIAMAuthN
567-
}
568-
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
569-
k, err := d.keyGenerator.rsaKey()
570-
if err != nil {
571-
return monitoredCache{}, err
572-
}
573-
var cache connectionInfoCache
574-
if d.lazyRefresh {
575-
cache = cloudsql.NewLazyRefreshCache(
576-
cn,
577-
d.logger,
578-
d.sqladmin, k,
579-
d.refreshTimeout, d.iamTokenSource,
580-
d.dialerID, useIAMAuthNDial,
581-
)
582-
} else {
583-
cache = cloudsql.NewRefreshAheadCache(
584-
cn,
585-
d.logger,
586-
d.sqladmin, k,
587-
d.refreshTimeout, d.iamTokenSource,
588-
d.dialerID, useIAMAuthNDial,
589-
)
590-
}
591-
var count uint64
592-
c = monitoredCache{openConns: &count, connectionInfoCache: cache}
593-
d.cache[cn] = c
594-
}
580+
) (*monitoredCache, error) {
581+
582+
c, oldC, err := d.cache.getOrAdd(cn, func() (*monitoredCache, error) {
583+
return d.createConnectionInfoCache(ctx, cn, useIAMAuthN)
584+
})
585+
586+
oldC.Close()
587+
c.UpdateRefresh(useIAMAuthN)
588+
589+
return c, err
590+
}
591+
592+
func (d *Dialer) createConnectionInfoCache(
593+
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
594+
) (*monitoredCache, error) {
595+
596+
var useIAMAuthNDial bool
597+
if useIAMAuthN != nil {
598+
useIAMAuthNDial = *useIAMAuthN
599+
}
600+
d.logger.Debugf(ctx, "[%v] Connection info created", cn.String())
601+
k, err := d.keyGenerator.rsaKey()
602+
if err != nil {
603+
return nil, err
604+
}
605+
var cache connectionInfoCache
606+
if d.lazyRefresh {
607+
cache = cloudsql.NewLazyRefreshCache(
608+
cn,
609+
d.logger,
610+
d.sqladmin, k,
611+
d.refreshTimeout, d.iamTokenSource,
612+
d.dialerID, useIAMAuthNDial,
613+
)
614+
} else {
615+
cache = cloudsql.NewRefreshAheadCache(
616+
cn,
617+
d.logger,
618+
d.sqladmin, k,
619+
d.refreshTimeout, d.iamTokenSource,
620+
d.dialerID, useIAMAuthNDial,
621+
)
622+
}
623+
c := &monitoredCache{
624+
openConnsCount: new(uint64),
625+
connectionInfoCache: cache,
595626
}
596627

597628
c.UpdateRefresh(useIAMAuthN)

0 commit comments

Comments
 (0)