Skip to content

Commit 068a5f5

Browse files
committed
feat: Automatically check for DNS changes periodically. On change, close all connections and create a new dialer.
chore: Expose the refresh strategy UseIAMAuthN() value to the dialer. Part of #842 chore: Add domain name to the cloudsql.ConnName struct Feat: Check for DNS changes on connect. On change, close all connections and create a new dialer. feat: Automatically check for DNS changes periodically. On change, close all connections and create a new dialer. wip: eno changes wip: eno interface cleanup wip: convert monitoredInstance to *monitoredInstance
1 parent 135ec39 commit 068a5f5

File tree

9 files changed

+698
-86
lines changed

9 files changed

+698
-86
lines changed

README.md

+37-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ func connect() {
234234
// ... etc
235235
}
236236
```
237-
### Using DNS to identify an instance
237+
238+
### Using DNS domain names to identify instances
238239

239240
The connector can be configured to use DNS to look up an instance. This would
240241
allow you to configure your application to connect to a database instance, and
@@ -292,6 +293,41 @@ func connect() {
292293
}
293294
```
294295

296+
### Automatic fail-over using DNS domain names
297+
298+
When the connector is configured using a domain name, the connector will
299+
periodically check if the DNS record for an instance changes. When the connector
300+
detects that the domain name refers to a different instance, the connector will
301+
close all open connections to the old instance. Subsequent connection attempts
302+
will be directed to the new instance.
303+
304+
For example: suppose an application is configured to connect using the
305+
domain name `prod-db.mycompany.example.com`. Initially the corporate DNS
306+
zone has a TXT record with the value `my-project:region:my-instance`. The
307+
application establishes connections to the `my-project:region:my-instance`
308+
Cloud SQL instance.
309+
310+
Then, to reconfigure the application to use a different database
311+
instance, `my-project:other-region:my-instance-2`, you update the DNS record
312+
for `prod-db.mycompany.example.com` with the target
313+
`my-project:other-region:my-instance-2`.
314+
315+
The connector inside the application detects the change to this
316+
DNS entry. Now, when the application connects to its database using the
317+
domain name `prod-db.mycompany.example.com`, it will connect to the
318+
`my-project:other-region:my-instance-2` Cloud SQL instance.
319+
320+
The connector will automatically close all existing connections to
321+
`my-project:region:my-instance`. This will force the connection pools to
322+
establish new connections. Also, it may cause database queries in progress
323+
to fail.
324+
325+
The connector will poll for changes to the DNS name every 30 seconds by default.
326+
You may configure the frequency of the DNS checks using the option
327+
`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will
328+
disable polling and only check if the DNS record changed when it is
329+
creating a new connection.
330+
295331

296332
### Using Options
297333

dialer.go

+128-56
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,12 @@ type connectionInfoCache interface {
110110
io.Closer
111111
}
112112

113-
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
114-
// number of connections to the associated instance.
115-
type monitoredCache struct {
116-
openConns *uint64
117-
118-
connectionInfoCache
119-
}
120-
121113
// A Dialer is used to create connections to Cloud SQL instances.
122114
//
123115
// Use NewDialer to initialize a Dialer.
124116
type Dialer struct {
125117
lock sync.RWMutex
126-
cache map[instance.ConnName]monitoredCache
118+
cache map[instance.ConnName]*monitoredCache
127119
keyGenerator *keyGenerator
128120
refreshTimeout time.Duration
129121
// closed reports if the dialer has been closed.
@@ -155,7 +147,8 @@ type Dialer struct {
155147
iamTokenSource oauth2.TokenSource
156148

157149
// resolver converts instance names into DNS names.
158-
resolver instance.ConnectionNameResolver
150+
resolver instance.ConnectionNameResolver
151+
failoverPeriod time.Duration
159152
}
160153

161154
var (
@@ -179,6 +172,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
179172
logger: nullLogger{},
180173
useragents: []string{userAgent},
181174
serviceUniverse: "googleapis.com",
175+
failoverPeriod: cloudsql.FailoverPeriod,
182176
}
183177
for _, opt := range opts {
184178
opt(cfg)
@@ -192,6 +186,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
192186
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
193187
return nil, errUseTokenSource
194188
}
189+
195190
// Add this to the end to make sure it's not overridden
196191
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
197192

@@ -263,7 +258,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
263258

264259
d := &Dialer{
265260
closed: make(chan struct{}),
266-
cache: make(map[instance.ConnName]monitoredCache),
261+
cache: make(map[instance.ConnName]*monitoredCache),
267262
lazyRefresh: cfg.lazyRefresh,
268263
keyGenerator: g,
269264
refreshTimeout: cfg.refreshTimeout,
@@ -274,7 +269,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
274269
iamTokenSource: cfg.iamLoginTokenSource,
275270
dialFunc: cfg.dialFunc,
276271
resolver: r,
272+
failoverPeriod: cfg.failoverPeriod,
277273
}
274+
278275
return d, nil
279276
}
280277

@@ -380,15 +377,24 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
380377

381378
latency := time.Since(startTime).Milliseconds()
382379
go func() {
383-
n := atomic.AddUint64(c.openConns, 1)
380+
n := atomic.AddUint64(c.openConnsCount, 1)
384381
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
385382
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
386383
}()
387384

388-
return newInstrumentedConn(tlsConn, func() {
389-
n := atomic.AddUint64(c.openConns, ^uint64(0))
385+
iConn := newInstrumentedConn(tlsConn, func() {
386+
n := atomic.AddUint64(c.openConnsCount, ^uint64(0))
390387
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
391-
}, d.dialerID, cn.String()), nil
388+
}, d.dialerID, cn.String())
389+
390+
// If this connection was opened using a Domain Name, then store it for later
391+
// in case it needs to be forcibly closed.
392+
if cn.DomainName() != "" {
393+
c.mu.Lock()
394+
c.openConns = append(c.openConns, iConn)
395+
c.mu.Unlock()
396+
}
397+
return iConn, nil
392398
}
393399

394400
// removeCached stops all background refreshes and deletes the connection
@@ -448,7 +454,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
448454
}
449455
ci, err := c.ConnectionInfo(ctx)
450456
if err != nil {
451-
d.removeCached(ctx, cn, c, err)
457+
d.removeCached(ctx, cn, c.connectionInfoCache, err)
452458
return "", err
453459
}
454460
return ci.DBVersion, nil
@@ -472,7 +478,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
472478
}
473479
_, err = c.ConnectionInfo(ctx)
474480
if err != nil {
475-
d.removeCached(ctx, cn, c, err)
481+
d.removeCached(ctx, cn, c.connectionInfoCache, err)
476482
}
477483
return err
478484
}
@@ -493,6 +499,8 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
493499
type instrumentedConn struct {
494500
net.Conn
495501
closeFunc func()
502+
mu sync.RWMutex
503+
closed bool
496504
dialerID string
497505
connName string
498506
}
@@ -517,9 +525,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
517525
return bytesWritten, err
518526
}
519527

528+
// isClosed returns true if this connection is closing or is already closed.
529+
func (i *instrumentedConn) isClosed() bool {
530+
i.mu.RLock()
531+
defer i.mu.RUnlock()
532+
return i.closed
533+
}
534+
520535
// Close delegates to the underlying net.Conn interface and reports the close
521536
// to the provided closeFunc only when Close returns no error.
522537
func (i *instrumentedConn) Close() error {
538+
i.mu.Lock()
539+
defer i.mu.Unlock()
540+
i.closed = true
523541
err := i.Conn.Close()
524542
if err != nil {
525543
return err
@@ -551,50 +569,104 @@ func (d *Dialer) Close() error {
551569
// modify the existing one, or leave it unchanged as needed.
552570
func (d *Dialer) connectionInfoCache(
553571
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
554-
) (monitoredCache, error) {
572+
) (*monitoredCache, error) {
555573
d.lock.RLock()
556574
c, ok := d.cache[cn]
557575
d.lock.RUnlock()
558-
if !ok {
559-
d.lock.Lock()
560-
defer d.lock.Unlock()
561-
// Recheck to ensure instance wasn't created or changed between locks
562-
c, ok = d.cache[cn]
563-
if !ok {
564-
var useIAMAuthNDial bool
565-
if useIAMAuthN != nil {
566-
useIAMAuthNDial = *useIAMAuthN
567-
}
568-
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
569-
k, err := d.keyGenerator.rsaKey()
570-
if err != nil {
571-
return monitoredCache{}, err
572-
}
573-
var cache connectionInfoCache
574-
if d.lazyRefresh {
575-
cache = cloudsql.NewLazyRefreshCache(
576-
cn,
577-
d.logger,
578-
d.sqladmin, k,
579-
d.refreshTimeout, d.iamTokenSource,
580-
d.dialerID, useIAMAuthNDial,
581-
)
582-
} else {
583-
cache = cloudsql.NewRefreshAheadCache(
584-
cn,
585-
d.logger,
586-
d.sqladmin, k,
587-
d.refreshTimeout, d.iamTokenSource,
588-
d.dialerID, useIAMAuthNDial,
589-
)
590-
}
591-
var count uint64
592-
c = monitoredCache{openConns: &count, connectionInfoCache: cache}
593-
d.cache[cn] = c
594-
}
576+
577+
// recheck the domain name, this may close the cache.
578+
if ok {
579+
c.checkDomainName(ctx)
580+
}
581+
582+
if ok && !c.isClosed() {
583+
c.UpdateRefresh(useIAMAuthN)
584+
return c, nil
595585
}
596586

597-
c.UpdateRefresh(useIAMAuthN)
587+
d.lock.Lock()
588+
defer d.lock.Unlock()
589+
590+
// Recheck to ensure instance wasn't created or changed between locks
591+
c, ok = d.cache[cn]
592+
593+
// c exists and is not closed
594+
if ok && !c.isClosed() {
595+
c.UpdateRefresh(useIAMAuthN)
596+
return c, nil
597+
}
598+
599+
// c exists and is closed, remove it from the cache
600+
if ok {
601+
// remove it.
602+
_ = c.Close()
603+
delete(d.cache, cn)
604+
}
605+
606+
// c does not exist, check for matching domain and close it
607+
oldCn, old, ok := d.findByDn(cn)
608+
if ok {
609+
_ = old.Close()
610+
delete(d.cache, oldCn)
611+
}
612+
613+
// Create a new instance of monitoredCache
614+
var useIAMAuthNDial bool
615+
if useIAMAuthN != nil {
616+
useIAMAuthNDial = *useIAMAuthN
617+
}
618+
d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
619+
k, err := d.keyGenerator.rsaKey()
620+
if err != nil {
621+
return nil, err
622+
}
623+
var cache connectionInfoCache
624+
if d.lazyRefresh {
625+
cache = cloudsql.NewLazyRefreshCache(
626+
cn,
627+
d.logger,
628+
d.sqladmin, k,
629+
d.refreshTimeout, d.iamTokenSource,
630+
d.dialerID, useIAMAuthNDial,
631+
)
632+
} else {
633+
cache = cloudsql.NewRefreshAheadCache(
634+
cn,
635+
d.logger,
636+
d.sqladmin, k,
637+
d.refreshTimeout, d.iamTokenSource,
638+
d.dialerID, useIAMAuthNDial,
639+
)
640+
}
641+
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
642+
d.cache[cn] = c
598643

599644
return c, nil
600645
}
646+
647+
// findByDn looks for a cache entry that has the same domain name as cn but
648+
// refers to a different instance. It does not modify the cache.
649+
//
650+
// cn - the connection name whose domain name is matched
651+
//
652+
// returns:
653+
//
654+
// instance.ConnName - the key to the old entry with the same domain name
655+
// monitoredCache - the old cached entry
656+
// bool ok - a matching entry exists
657+
//
658+
// This method does not manage locks.
659+
func (d *Dialer) findByDn(cn instance.ConnName) (instance.ConnName, *monitoredCache, bool) {
660+
661+
// Try to get an entry with the same domain name but a different instance.
662+
// The caller is responsible for removing it from the cache so it can be replaced.
663+
if cn.HasDomainName() {
664+
for oldCn, oc := range d.cache {
665+
if oldCn.DomainName() == cn.DomainName() && oldCn != cn {
666+
return oldCn, oc, true
667+
}
668+
}
669+
}
670+
671+
return instance.ConnName{}, nil, false
672+
}

0 commit comments

Comments
 (0)