Skip to content

Commit dcad42f

Browse files
committed
feat: Automatically check for DNS changes periodically. On change, close all connections and create a new dialer.
chore: Expose the refresh strategy UseIAMAuthN() value to the dialer. Part of #842 chore: Add domain name to the cloudsql.ConnName struct Feat: Check for DNS changes on connect. On change, close all connections and create a new dialer. feat: Automatically check for DNS changes periodically. On change, close all connections and create a new dialer. wip: eno changes wip: eno interface cleanup wip: convert monitoredInstance to *monitoredInstance
1 parent 135ec39 commit dcad42f

10 files changed

+716
-93
lines changed

README.md

+37-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ func connect() {
234234
// ... etc
235235
}
236236
```
237-
### Using DNS to identify an instance
237+
238+
### Using DNS domain names to identify instances
238239

239240
The connector can be configured to use DNS to look up an instance. This would
240241
allow you to configure your application to connect to a database instance, and
@@ -292,6 +293,41 @@ func connect() {
292293
}
293294
```
294295

296+
### Automatic fail-over using DNS domain names
297+
298+
When the connector is configured using a domain name, the connector will
299+
periodically check if the DNS record for an instance changes. When the connector
300+
detects that the domain name refers to a different instance, the connector will
301+
close all open connections to the old instance. Subsequent connection attempts
302+
will be directed to the new instance.
303+
304+
For example: suppose an application is configured to connect using the
305+
domain name `prod-db.mycompany.example.com`. Initially the corporate DNS
306+
zone has a TXT record with the value `my-project:region:my-instance`. The
307+
application establishes connections to the `my-project:region:my-instance`
308+
Cloud SQL instance.
309+
310+
Then, to reconfigure the application to use a different database
311+
instance, `my-project:other-region:my-instance-2`, you update the DNS record
312+
for `prod-db.mycompany.example.com` with the target
313+
`my-project:other-region:my-instance-2`.
314+
315+
The connector inside the application detects the change to this
316+
DNS entry. Now, when the application connects to its database using the
317+
domain name `prod-db.mycompany.example.com`, it will connect to the
318+
`my-project:other-region:my-instance-2` Cloud SQL instance.
319+
320+
The connector will automatically close all existing connections to
321+
`my-project:region:my-instance`. This will force the connection pools to
322+
establish new connections. Also, it may cause database queries in progress
323+
to fail.
324+
325+
The connector will poll for changes to the DNS name every 30 seconds by default.
326+
You may configure the frequency of the DNS checks using the option
327+
`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will
328+
disable polling and only check if the DNS record changed when it is
329+
creating a new connection.
330+
295331

296332
### Using Options
297333

dialer.go

+136-63
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ type keyGenerator struct {
7676
// - generate an RSA key lazily when it's requested, or
7777
// - (default) immediately generate an RSA key as part of the initializer.
7878
func newKeyGenerator(
79-
k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error),
79+
k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error),
8080
) (*keyGenerator, error) {
8181
g := &keyGenerator{genFunc: genFunc}
8282
switch {
@@ -107,23 +107,16 @@ type connectionInfoCache interface {
107107
ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
108108
UpdateRefresh(*bool)
109109
ForceRefresh()
110+
UseIAMAuthN() bool
110111
io.Closer
111112
}
112113

113-
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
114-
// number of connections to the associated instance.
115-
type monitoredCache struct {
116-
openConns *uint64
117-
118-
connectionInfoCache
119-
}
120-
121114
// A Dialer is used to create connections to Cloud SQL instances.
122115
//
123116
// Use NewDialer to initialize a Dialer.
124117
type Dialer struct {
125118
lock sync.RWMutex
126-
cache map[instance.ConnName]monitoredCache
119+
cache map[instance.ConnName]*monitoredCache
127120
keyGenerator *keyGenerator
128121
refreshTimeout time.Duration
129122
// closed reports if the dialer has been closed.
@@ -155,7 +148,8 @@ type Dialer struct {
155148
iamTokenSource oauth2.TokenSource
156149

157150
// resolver converts instance names into DNS names.
158-
resolver instance.ConnectionNameResolver
151+
resolver instance.ConnectionNameResolver
152+
failoverPeriod time.Duration
159153
}
160154

161155
var (
@@ -179,6 +173,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
179173
logger: nullLogger{},
180174
useragents: []string{userAgent},
181175
serviceUniverse: "googleapis.com",
176+
failoverPeriod: cloudsql.FailoverPeriod,
182177
}
183178
for _, opt := range opts {
184179
opt(cfg)
@@ -192,6 +187,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
192187
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
193188
return nil, errUseTokenSource
194189
}
190+
195191
// Add this to the end to make sure it's not overridden
196192
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
197193

@@ -219,7 +215,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
219215
if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint {
220216
return nil, errors.New(
221217
"can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
222-
"use WithAdminAPIEndpoint (it already contains the universe domain)",
218+
"use WithAdminAPIEndpoint (it already contains the universe domain)",
223219
)
224220
}
225221

@@ -263,7 +259,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263259

264260
d := &Dialer{
265261
closed: make(chan struct{}),
266-
cache: make(map[instance.ConnName]monitoredCache),
262+
cache: make(map[instance.ConnName]*monitoredCache),
267263
lazyRefresh: cfg.lazyRefresh,
268264
keyGenerator: g,
269265
refreshTimeout: cfg.refreshTimeout,
@@ -274,7 +270,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
274270
iamTokenSource: cfg.iamLoginTokenSource,
275271
dialFunc: cfg.dialFunc,
276272
resolver: r,
273+
failoverPeriod: cfg.failoverPeriod,
277274
}
275+
278276
return d, nil
279277
}
280278

@@ -380,22 +378,31 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
380378

381379
latency := time.Since(startTime).Milliseconds()
382380
go func() {
383-
n := atomic.AddUint64(c.openConns, 1)
381+
n := atomic.AddUint64(c.openConnsCount, 1)
384382
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
385383
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
386384
}()
387385

388-
return newInstrumentedConn(tlsConn, func() {
389-
n := atomic.AddUint64(c.openConns, ^uint64(0))
386+
iConn := newInstrumentedConn(tlsConn, func() {
387+
n := atomic.AddUint64(c.openConnsCount, ^uint64(0))
390388
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
391-
}, d.dialerID, cn.String()), nil
389+
}, d.dialerID, cn.String())
390+
391+
// If this connection was opened using a Domain Name, then store it for later
392+
// in case it needs to be forcibly closed.
393+
if cn.DomainName() != "" {
394+
c.mu.Lock()
395+
c.openConns = append(c.openConns, iConn)
396+
c.mu.Unlock()
397+
}
398+
return iConn, nil
392399
}
393400

394401
// removeCached stops all background refreshes and deletes the connection
395402
// info cache from the map of caches.
396403
func (d *Dialer) removeCached(
397-
ctx context.Context,
398-
i instance.ConnName, c connectionInfoCache, err error,
404+
ctx context.Context,
405+
i instance.ConnName, c connectionInfoCache, err error,
399406
) {
400407
d.logger.Debugf(
401408
ctx,
@@ -413,8 +420,8 @@ func (d *Dialer) removeCached(
413420
// the cache is unexpired. The time comparisons strip the monotonic clock value
414421
// to ensure an accurate result, even after laptop sleep.
415422
func validClientCert(
416-
ctx context.Context, cn instance.ConnName,
417-
l debug.ContextLogger, expiration time.Time,
423+
ctx context.Context, cn instance.ConnName,
424+
l debug.ContextLogger, expiration time.Time,
418425
) bool {
419426
// Use UTC() to strip monotonic clock value to guard against inaccurate
420427
// comparisons, especially after laptop sleep.
@@ -448,7 +455,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
448455
}
449456
ci, err := c.ConnectionInfo(ctx)
450457
if err != nil {
451-
d.removeCached(ctx, cn, c, err)
458+
d.removeCached(ctx, cn, c.connectionInfoCache, err)
452459
return "", err
453460
}
454461
return ci.DBVersion, nil
@@ -472,7 +479,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
472479
}
473480
_, err = c.ConnectionInfo(ctx)
474481
if err != nil {
475-
d.removeCached(ctx, cn, c, err)
482+
d.removeCached(ctx, cn, c.connectionInfoCache, err)
476483
}
477484
return err
478485
}
@@ -493,6 +500,8 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493500
type instrumentedConn struct {
494501
net.Conn
495502
closeFunc func()
503+
mu sync.RWMutex
504+
closed bool
496505
dialerID string
497506
connName string
498507
}
@@ -517,9 +526,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
517526
return bytesWritten, err
518527
}
519528

529+
// isClosed returns true if this connection is closing or is already closed.
530+
func (i *instrumentedConn) isClosed() bool {
531+
i.mu.RLock()
532+
defer i.mu.RUnlock()
533+
return i.closed
534+
}
535+
520536
// Close delegates to the underlying net.Conn interface and reports the close
521537
// to the provided closeFunc only when Close returns no error.
522538
func (i *instrumentedConn) Close() error {
539+
i.mu.Lock()
540+
defer i.mu.Unlock()
541+
i.closed = true
523542
err := i.Conn.Close()
524543
if err != nil {
525544
return err
@@ -550,51 +569,105 @@ func (d *Dialer) Close() error {
550569
// connection info Cache in a threadsafe way. It will create a new cache,
551570
// modify the existing one, or leave it unchanged as needed.
552571
func (d *Dialer) connectionInfoCache(
553-
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
554-
) (monitoredCache, error) {
572+
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
573+
) (*monitoredCache, error) {
555574
d.lock.RLock()
556575
c, ok := d.cache[cn]
557576
d.lock.RUnlock()
558-
if !ok {
559-
d.lock.Lock()
560-
defer d.lock.Unlock()
561-
// Recheck to ensure instance wasn't created or changed between locks
562-
c, ok = d.cache[cn]
563-
if !ok {
564-
var useIAMAuthNDial bool
565-
if useIAMAuthN != nil {
566-
useIAMAuthNDial = *useIAMAuthN
567-
}
568-
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
569-
k, err := d.keyGenerator.rsaKey()
570-
if err != nil {
571-
return monitoredCache{}, err
572-
}
573-
var cache connectionInfoCache
574-
if d.lazyRefresh {
575-
cache = cloudsql.NewLazyRefreshCache(
576-
cn,
577-
d.logger,
578-
d.sqladmin, k,
579-
d.refreshTimeout, d.iamTokenSource,
580-
d.dialerID, useIAMAuthNDial,
581-
)
582-
} else {
583-
cache = cloudsql.NewRefreshAheadCache(
584-
cn,
585-
d.logger,
586-
d.sqladmin, k,
587-
d.refreshTimeout, d.iamTokenSource,
588-
d.dialerID, useIAMAuthNDial,
589-
)
590-
}
591-
var count uint64
592-
c = monitoredCache{openConns: &count, connectionInfoCache: cache}
593-
d.cache[cn] = c
594-
}
577+
578+
// recheck the domain name, this may close the cache.
579+
if ok {
580+
c.checkDomainName(ctx)
581+
}
582+
583+
if ok && !c.isClosed() {
584+
c.UpdateRefresh(useIAMAuthN)
585+
return c, nil
586+
}
587+
588+
d.lock.Lock()
589+
defer d.lock.Unlock()
590+
591+
// Recheck to ensure instance wasn't created or changed between locks
592+
c, ok = d.cache[cn]
593+
594+
// c exists and is not closed
595+
if ok && !c.isClosed() {
596+
c.UpdateRefresh(useIAMAuthN)
597+
return c, nil
598+
}
599+
600+
// c exists and is closed, remove it from the cache
601+
if ok {
602+
// remove it.
603+
_ = c.Close()
604+
delete(d.cache, cn)
595605
}
596606

597-
c.UpdateRefresh(useIAMAuthN)
607+
// c does not exist, check for matching domain and close it
608+
oldCn, old, ok := d.findByDn(cn)
609+
if ok {
610+
_ = old.Close()
611+
delete(d.cache, oldCn)
612+
}
598613

614+
// Create a new instance of monitoredCache
615+
var useIAMAuthNDial bool
616+
if useIAMAuthN != nil {
617+
useIAMAuthNDial = *useIAMAuthN
618+
}
619+
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
620+
k, err := d.keyGenerator.rsaKey()
621+
if err != nil {
622+
return nil, err
623+
}
624+
var cache connectionInfoCache
625+
if d.lazyRefresh {
626+
cache = cloudsql.NewLazyRefreshCache(
627+
cn,
628+
d.logger,
629+
d.sqladmin, k,
630+
d.refreshTimeout, d.iamTokenSource,
631+
d.dialerID, useIAMAuthNDial,
632+
)
633+
} else {
634+
cache = cloudsql.NewRefreshAheadCache(
635+
cn,
636+
d.logger,
637+
d.sqladmin, k,
638+
d.refreshTimeout, d.iamTokenSource,
639+
d.dialerID, useIAMAuthNDial,
640+
)
641+
}
642+
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
643+
d.cache[cn] = c
644+
599645
return c, nil
600646
}
647+
648+
// findByDn looks for an existing cache entry that has the same domain name
649+
// as cn but a different connection name. It does not modify the cache.
650+
//
651+
// cn - the connection name whose domain name is matched
652+
//
653+
// returns:
654+
//
655+
// instance.ConnName - the key to the old entry with the same domain name
656+
// monitoredCache - the old cached entry, or nil when there is none
657+
// bool ok - true when a matching old entry exists
658+
//
659+
// This method does not manage locks.
660+
func (d *Dialer) findByDn(cn instance.ConnName) (instance.ConnName, *monitoredCache, bool) {
661+
662+
// Try to get an entry with the same domain name but a different instance.
663+
// The caller removes this entry from the cache so that it can be replaced.
664+
if cn.HasDomainName() {
665+
for oldCn, oc := range d.cache {
666+
if oldCn.DomainName() == cn.DomainName() && oldCn != cn {
667+
return oldCn, oc, true
668+
}
669+
}
670+
}
671+
672+
return instance.ConnName{}, nil, false
673+
}

0 commit comments

Comments
 (0)