@@ -107,6 +107,7 @@ type connectionInfoCache interface {
107
107
ConnectionInfo (context.Context ) (cloudsql.ConnectionInfo , error )
108
108
UpdateRefresh (* bool )
109
109
ForceRefresh ()
110
+ UseIAMAuthN () bool
110
111
io.Closer
111
112
}
112
113
@@ -115,13 +116,26 @@ type connectionInfoCache interface {
115
116
type monitoredCache struct {
116
117
openConnsCount * uint64
117
118
119
+ mu sync.Mutex
120
+ openConns []* instrumentedConn
121
+
118
122
connectionInfoCache
119
123
}
120
124
121
125
func (c * monitoredCache ) Close () error {
122
126
if c == nil || c .connectionInfoCache == nil {
123
127
return nil
124
128
}
129
+
130
+ if atomic .LoadUint64 (c .openConnsCount ) > 0 {
131
+ for _ , socket := range c .openConns {
132
+ if ! socket .isClosed () {
133
+ _ = socket .Close () // force socket closed, ok to ignore error.
134
+ }
135
+ }
136
+ atomic .StoreUint64 (c .openConnsCount , 0 )
137
+ }
138
+
125
139
return c .connectionInfoCache .Close ()
126
140
}
127
141
@@ -145,6 +159,21 @@ func (c *monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.Connectio
145
159
return c .connectionInfoCache .ConnectionInfo (ctx )
146
160
}
147
161
162
+ func (c * monitoredCache ) purgeClosedConns () {
163
+ if c == nil || c .connectionInfoCache == nil {
164
+ return
165
+ }
166
+ c .mu .Lock ()
167
+ var open []* instrumentedConn
168
+ for _ , s := range c .openConns {
169
+ if ! s .isClosed () {
170
+ open = append (open , s )
171
+ }
172
+ }
173
+ c .openConns = open
174
+ c .mu .Unlock ()
175
+ }
176
+
148
177
// A Dialer is used to create connections to Cloud SQL instances.
149
178
//
150
179
// Use NewDialer to initialize a Dialer.
@@ -182,6 +211,10 @@ type Dialer struct {
182
211
183
212
// resolver converts instance names into DNS names.
184
213
resolver instance.ConnectionNameResolver
214
+
215
+ // domainNameTicker periodically checks any domain names to see if they
216
+ // changed.
217
+ domainNameTicker * time.Ticker
185
218
}
186
219
187
220
var (
@@ -205,6 +238,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
205
238
logger : nullLogger {},
206
239
useragents : []string {userAgent },
207
240
serviceUniverse : "googleapis.com" ,
241
+ failoverPeriod : cloudsql .FailoverPeriod ,
208
242
}
209
243
for _ , opt := range opts {
210
244
opt (cfg )
@@ -218,6 +252,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
218
252
if cfg .setIAMAuthNTokenSource && ! cfg .useIAMAuthN {
219
253
return nil , errUseTokenSource
220
254
}
255
+
221
256
// Add this to the end to make sure it's not overridden
222
257
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithUserAgent (strings .Join (cfg .useragents , " " )))
223
258
@@ -231,7 +266,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
231
266
}
232
267
ud , err := c .GetUniverseDomain ()
233
268
if err != nil {
234
- return nil , fmt .Errorf ("failed to getOrAdd universe domain: %v" , err )
269
+ return nil , fmt .Errorf ("failed to get universe domain: %v" , err )
235
270
}
236
271
cfg .credentialsUniverse = ud
237
272
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithTokenSource (c .TokenSource ))
@@ -301,8 +336,28 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
301
336
dialFunc : cfg .dialFunc ,
302
337
resolver : r ,
303
338
}
339
+
340
+ // If the failover period is set, start a goroutine to periodically
341
+ // check for DNS changes.
342
+ if cfg .failoverPeriod > 0 {
343
+ d .initFailoverRoutine (ctx , cfg .failoverPeriod )
344
+ }
345
+
304
346
return d , nil
305
347
}
348
+ func (d * Dialer ) initFailoverRoutine (ctx context.Context , p time.Duration ) {
349
+ d .domainNameTicker = time .NewTicker (p )
350
+ go func () {
351
+ for {
352
+ select {
353
+ case <- d .domainNameTicker .C :
354
+ d .pollDomainNames (ctx )
355
+ case <- d .closed :
356
+ return
357
+ }
358
+ }
359
+ }()
360
+ }
306
361
307
362
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
308
363
// icn argument must be the instance's connection name, which is in the format
@@ -416,6 +471,13 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
416
471
trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
417
472
}, d .dialerID , cn .String ())
418
473
474
+ // If this connection was opened using a Domain Name, then store it for later
475
+ // in case it needs to be forcibly closed.
476
+ if cn .DomainName () != "" {
477
+ c .mu .Lock ()
478
+ c .openConns = append (c .openConns , iConn )
479
+ c .mu .Unlock ()
480
+ }
419
481
return iConn , nil
420
482
}
421
483
@@ -517,6 +579,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
517
579
type instrumentedConn struct {
518
580
net.Conn
519
581
closeFunc func ()
582
+ mu sync.RWMutex
520
583
closed bool
521
584
dialerID string
522
585
connName string
@@ -542,6 +605,13 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
542
605
return bytesWritten , err
543
606
}
544
607
608
+ // isClosed returns true if this connection is closing or is already closed.
609
+ func (i * instrumentedConn ) isClosed () bool {
610
+ i .mu .RLock ()
611
+ defer i .mu .RUnlock ()
612
+ return i .closed
613
+ }
614
+
545
615
// Close delegates to the underlying net.Conn interface and reports the close
546
616
// to the provided closeFunc only when Close returns no error.
547
617
func (i * instrumentedConn ) Close () error {
@@ -565,13 +635,56 @@ func (d *Dialer) Close() error {
565
635
}
566
636
close (d .closed )
567
637
638
+ if d .domainNameTicker != nil {
639
+ d .domainNameTicker .Stop ()
640
+ }
641
+
568
642
d .cache .replaceAll (func (cn instance.ConnName , c * monitoredCache ) (instance.ConnName , * monitoredCache ) {
569
643
c .Close () // close the monitoredCache
570
644
return instance.ConnName {}, nil // Remove from cache
571
645
})
572
646
return nil
573
647
}
574
648
649
+ func (d * Dialer ) pollDomainNames (ctx context.Context ) {
650
+ d .cache .replaceAll (func (cn instance.ConnName , cache * monitoredCache ) (instance.ConnName , * monitoredCache ) {
651
+ if cn .DomainName () == "" {
652
+ return cn , cache
653
+ }
654
+
655
+ // Resolve the domain name.
656
+ newCn , err := d .resolver .Resolve (ctx , cn .DomainName ())
657
+
658
+ if err != nil {
659
+ // the domain name no longer resolves to a valid instance
660
+ d .logger .Debugf (ctx , "[failover] unable to resolve DNS for instance %s: %v" , cn .DomainName (), err )
661
+ cache .Close ()
662
+ return instance.ConnName {}, nil
663
+ } else if newCn != cn {
664
+ d .logger .Debugf (ctx , "domain name %s changed from old instance %s to new instance %s" ,
665
+ cn .DomainName (), cn .String (), newCn .String ())
666
+
667
+ useIamAuthn := cache .UseIAMAuthN ()
668
+ // The domain name points to a different instance.
669
+ cache .Close ()
670
+
671
+ newC , err := d .createConnectionInfoCache (ctx , cn , & useIamAuthn )
672
+ if err != nil {
673
+ d .logger .Debugf (ctx , "error connecting to new instance %s, %s: %v" ,
674
+ cn .DomainName (), newCn .String (), err )
675
+ return instance.ConnName {}, nil
676
+ }
677
+ return newCn , newC
678
+ }
679
+
680
+ // Remove closed sockets from cache.openConns
681
+ cache .purgeClosedConns ()
682
+ return cn , cache
683
+
684
+ })
685
+
686
+ }
687
+
575
688
// connectionInfoCache is a helper function for returning the appropriate
576
689
// connection info Cache in a threadsafe way. It will create a new cache,
577
690
// modify the existing one, or leave it unchanged as needed.
0 commit comments