Skip to content

Commit aab26be

Browse files
committed
chore: Refactor dialer cache concurrency logic. Part of #842.
1 parent 5b2a68b commit aab26be

7 files changed

+587
-81
lines changed

dialer.go

+87-55
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,38 @@ type monitoredCache struct {
118118
connectionInfoCache
119119
}
120120

121+
func (c monitoredCache) Close() error {
122+
if c.connectionInfoCache == nil {
123+
return nil
124+
}
125+
return c.connectionInfoCache.Close()
126+
}
127+
128+
func (c monitoredCache) ForceRefresh() {
129+
if c.connectionInfoCache == nil {
130+
return
131+
}
132+
c.connectionInfoCache.ForceRefresh()
133+
}
134+
135+
func (c monitoredCache) UpdateRefresh(b *bool) {
136+
if c.connectionInfoCache == nil {
137+
return
138+
}
139+
c.connectionInfoCache.UpdateRefresh(b)
140+
}
141+
func (c monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.ConnectionInfo, error) {
142+
if c.connectionInfoCache == nil {
143+
return cloudsql.ConnectionInfo{}, nil
144+
}
145+
return c.connectionInfoCache.ConnectionInfo(ctx)
146+
}
147+
121148
// A Dialer is used to create connections to Cloud SQL instances.
122149
//
123150
// Use NewDialer to initialize a Dialer.
124151
type Dialer struct {
125-
lock sync.RWMutex
126-
cache map[instance.ConnName]monitoredCache
152+
cache *DialerCache
127153
keyGenerator *keyGenerator
128154
refreshTimeout time.Duration
129155
// closed reports if the dialer has been closed.
@@ -205,7 +231,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
205231
}
206232
ud, err := c.GetUniverseDomain()
207233
if err != nil {
208-
return nil, fmt.Errorf("failed to get universe domain: %v", err)
234+
return nil, fmt.Errorf("failed to getOrAdd universe domain: %v", err)
209235
}
210236
cfg.credentialsUniverse = ud
211237
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
@@ -263,7 +289,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263289

264290
d := &Dialer{
265291
closed: make(chan struct{}),
266-
cache: make(map[instance.ConnName]monitoredCache),
292+
cache: newDialerCache(cfg.logger),
267293
lazyRefresh: cfg.lazyRefresh,
268294
keyGenerator: g,
269295
refreshTimeout: cfg.refreshTimeout,
@@ -385,10 +411,12 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
385411
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
386412
}()
387413

388-
return newInstrumentedConn(tlsConn, func() {
414+
iConn := newInstrumentedConn(tlsConn, func() {
389415
n := atomic.AddUint64(c.openConns, ^uint64(0))
390416
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
391-
}, d.dialerID, cn.String()), nil
417+
}, d.dialerID, cn.String())
418+
419+
return iConn, nil
392420
}
393421

394422
// removeCached stops all background refreshes and deletes the connection
@@ -397,16 +425,14 @@ func (d *Dialer) removeCached(
397425
ctx context.Context,
398426
i instance.ConnName, c connectionInfoCache, err error,
399427
) {
428+
mc := d.cache.remove(i)
429+
mc.Close()
400430
d.logger.Debugf(
401431
ctx,
402432
"[%v] Removing connection info from cache: %v",
403433
i.String(),
404434
err,
405435
)
406-
d.lock.Lock()
407-
defer d.lock.Unlock()
408-
c.Close()
409-
delete(d.cache, i)
410436
}
411437

412438
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -442,6 +468,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
442468
if err != nil {
443469
return "", err
444470
}
471+
445472
c, err := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN)
446473
if err != nil {
447474
return "", err
@@ -493,6 +520,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493520
type instrumentedConn struct {
494521
net.Conn
495522
closeFunc func()
523+
closed bool
496524
dialerID string
497525
connName string
498526
}
@@ -520,6 +548,7 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
520548
// Close delegates to the underlying net.Conn interface and reports the close
521549
// to the provided closeFunc only when Close returns no error.
522550
func (i *instrumentedConn) Close() error {
551+
i.closed = true
523552
err := i.Conn.Close()
524553
if err != nil {
525554
return err
@@ -538,11 +567,11 @@ func (d *Dialer) Close() error {
538567
default:
539568
}
540569
close(d.closed)
541-
d.lock.Lock()
542-
defer d.lock.Unlock()
543-
for _, i := range d.cache {
544-
i.Close()
545-
}
570+
571+
d.cache.replaceAll(func(cn instance.ConnName, c monitoredCache) (instance.ConnName, monitoredCache) {
572+
c.Close() // close the monitoredCache
573+
return instance.ConnName{}, monitoredCache{} // Remove from cache
574+
})
546575
return nil
547576
}
548577

@@ -552,47 +581,50 @@ func (d *Dialer) Close() error {
552581
func (d *Dialer) connectionInfoCache(
553582
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
554583
) (monitoredCache, error) {
555-
d.lock.RLock()
556-
c, ok := d.cache[cn]
557-
d.lock.RUnlock()
558-
if !ok {
559-
d.lock.Lock()
560-
defer d.lock.Unlock()
561-
// Recheck to ensure instance wasn't created or changed between locks
562-
c, ok = d.cache[cn]
563-
if !ok {
564-
var useIAMAuthNDial bool
565-
if useIAMAuthN != nil {
566-
useIAMAuthNDial = *useIAMAuthN
567-
}
568-
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
569-
k, err := d.keyGenerator.rsaKey()
570-
if err != nil {
571-
return monitoredCache{}, err
572-
}
573-
var cache connectionInfoCache
574-
if d.lazyRefresh {
575-
cache = cloudsql.NewLazyRefreshCache(
576-
cn,
577-
d.logger,
578-
d.sqladmin, k,
579-
d.refreshTimeout, d.iamTokenSource,
580-
d.dialerID, useIAMAuthNDial,
581-
)
582-
} else {
583-
cache = cloudsql.NewRefreshAheadCache(
584-
cn,
585-
d.logger,
586-
d.sqladmin, k,
587-
d.refreshTimeout, d.iamTokenSource,
588-
d.dialerID, useIAMAuthNDial,
589-
)
590-
}
591-
var count uint64
592-
c = monitoredCache{openConns: &count, connectionInfoCache: cache}
593-
d.cache[cn] = c
594-
}
584+
585+
c, oldC, err := d.cache.getOrAdd(cn, func() (monitoredCache, error) {
586+
return d.createConnectionInfoCache(ctx, cn, useIAMAuthN)
587+
})
588+
589+
oldC.Close()
590+
c.UpdateRefresh(useIAMAuthN)
591+
592+
return c, err
593+
}
594+
595+
func (d *Dialer) createConnectionInfoCache(
596+
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
597+
) (monitoredCache, error) {
598+
599+
var useIAMAuthNDial bool
600+
if useIAMAuthN != nil {
601+
useIAMAuthNDial = *useIAMAuthN
602+
}
603+
d.logger.Debugf(ctx, "[%v] Connection info created", cn.String())
604+
k, err := d.keyGenerator.rsaKey()
605+
if err != nil {
606+
return monitoredCache{}, err
607+
}
608+
var cache connectionInfoCache
609+
if d.lazyRefresh {
610+
cache = cloudsql.NewLazyRefreshCache(
611+
cn,
612+
d.logger,
613+
d.sqladmin, k,
614+
d.refreshTimeout, d.iamTokenSource,
615+
d.dialerID, useIAMAuthNDial,
616+
)
617+
} else {
618+
cache = cloudsql.NewRefreshAheadCache(
619+
cn,
620+
d.logger,
621+
d.sqladmin, k,
622+
d.refreshTimeout, d.iamTokenSource,
623+
d.dialerID, useIAMAuthNDial,
624+
)
595625
}
626+
var count uint64
627+
c := monitoredCache{openConns: &count, connectionInfoCache: cache}
596628

597629
c.UpdateRefresh(useIAMAuthN)
598630

dialer_cache.go

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cloudsqlconn
16+
17+
import (
18+
"sync"
19+
20+
"cloud.google.com/go/cloudsqlconn/debug"
21+
"cloud.google.com/go/cloudsqlconn/instance"
22+
)
23+
24+
type DialerCache struct {
25+
mu sync.RWMutex
26+
cache map[instance.ConnName]monitoredCache
27+
logger debug.ContextLogger
28+
}
29+
30+
// newDialerCache creates and initializes an instance of the dialer cache
31+
func newDialerCache(logger debug.ContextLogger) *DialerCache {
32+
return &DialerCache{
33+
mu: sync.RWMutex{},
34+
cache: make(map[instance.ConnName]monitoredCache),
35+
logger: logger,
36+
}
37+
}
38+
39+
// replaceAll thread-safe iterate through all cache entries, replace or removing
40+
// the entries. f() provides the replacement values
41+
// - no change: return cn, c
42+
// - replace: return an instance.ConnName and monitoredCache. This will
43+
// replace the old entry
44+
// - remove:should return the empty value instance.ConnName{} if the
45+
//
46+
// entry should be removed.
47+
// This method is not re-entrant.
48+
func (d *DialerCache) replaceAll(f func(cn instance.ConnName, c monitoredCache) (instance.ConnName, monitoredCache)) {
49+
emptyInstance := instance.ConnName{}
50+
d.mu.Lock()
51+
defer d.mu.Unlock()
52+
newCache := make(map[instance.ConnName]monitoredCache)
53+
for cn, c := range d.cache {
54+
newCn, newC := f(cn, c)
55+
// ignore entries that have empty instance names
56+
if newCn == emptyInstance {
57+
continue
58+
}
59+
newCache[newCn] = newC
60+
}
61+
d.cache = newCache
62+
}
63+
64+
// findByDomainName returns the entry that matches the domain name.
65+
// dn - the domain name
66+
// returns:
67+
//
68+
// instance.ConnName the name of the matching instance
69+
// monitoredCache the cached item
70+
// bool true when there is a result.
71+
func (d *DialerCache) findByDomainName(dn string) (instance.ConnName, monitoredCache, bool) {
72+
d.mu.RLock()
73+
defer d.mu.RUnlock()
74+
for cn, c := range d.cache {
75+
if cn.DomainName() == dn {
76+
return cn, c, true
77+
}
78+
}
79+
return instance.ConnName{}, monitoredCache{}, false
80+
}
81+
82+
// get returns the instance matching the cn
83+
func (d *DialerCache) get(cn instance.ConnName) (monitoredCache, bool) {
84+
d.mu.RLock()
85+
defer d.mu.RUnlock()
86+
c, ok := d.cache[cn]
87+
return c, ok
88+
}
89+
90+
// getOrAdd returns the cache entry, creating it if necessary. This will also
91+
// take care to remove entries with the same domain name.
92+
//
93+
// cn - the connection name to getOrAdd
94+
// f - the function to use to create a new cache, may return an error
95+
//
96+
// returns:
97+
//
98+
// monitoredCache - the cached entry
99+
// monitoredCache - the replaced entry if the cache contains an entry with the
100+
// same domain name.
101+
// error - an error if the cache entry could not be created.
102+
func (d *DialerCache) getOrAdd(cn instance.ConnName, f func() (monitoredCache, error)) (monitoredCache, monitoredCache, error) {
103+
var oldC monitoredCache
104+
105+
d.mu.RLock()
106+
c, ok := d.cache[cn]
107+
d.mu.RUnlock()
108+
if ok {
109+
return c, oldC, nil
110+
}
111+
// If not found, acquire write lock.
112+
d.mu.Lock()
113+
defer d.mu.Unlock()
114+
115+
// Look up in the map by CN again
116+
c, ok = d.cache[cn]
117+
if ok {
118+
return c, monitoredCache{}, nil
119+
}
120+
121+
// Try to get an instance with the same domain name but different instance
122+
// Remove this instance from the cache, it will be replaced.
123+
if cn.HasDomainName() {
124+
for oldCn, oc := range d.cache {
125+
if oldCn.DomainName() == cn.DomainName() && oldCn != cn {
126+
oldC = oc
127+
delete(d.cache, oldCn)
128+
break
129+
}
130+
}
131+
}
132+
133+
// Create the new instance and put it in the cache
134+
c, err := f()
135+
if err != nil {
136+
return monitoredCache{}, oldC, err
137+
}
138+
139+
// Instance created successfully. Return it.
140+
d.cache[cn] = c
141+
return c, oldC, nil
142+
}
143+
144+
func (d *DialerCache) remove(cn instance.ConnName) monitoredCache {
145+
// If not found, acquire write lock.
146+
d.mu.Lock()
147+
defer d.mu.Unlock()
148+
149+
// Look up in the map by CN again
150+
c, ok := d.cache[cn]
151+
if ok {
152+
delete(d.cache, cn)
153+
}
154+
155+
return c
156+
}

0 commit comments

Comments
 (0)