@@ -107,13 +107,17 @@ type connectionInfoCache interface {
107
107
ConnectionInfo (context.Context ) (cloudsql.ConnectionInfo , error )
108
108
UpdateRefresh (* bool )
109
109
ForceRefresh ()
110
+ UseIAMAuthN () bool
110
111
io.Closer
111
112
}
112
113
113
114
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
114
115
// number of connections to the associated instance.
115
116
type monitoredCache struct {
116
- openConns * uint64
117
+ openConnsCount * uint64
118
+
119
+ mu sync.Mutex
120
+ openConns []* instrumentedConn
117
121
118
122
connectionInfoCache
119
123
}
@@ -122,6 +126,16 @@ func (c *monitoredCache) Close() error {
122
126
if c == nil || c .connectionInfoCache == nil {
123
127
return nil
124
128
}
129
+
130
+ if atomic .LoadUint64 (c .openConnsCount ) > 0 {
131
+ for _ , socket := range c .openConns {
132
+ if ! socket .isClosed () {
133
+ _ = socket .Close () // force socket closed, ok to ignore error.
134
+ }
135
+ }
136
+ atomic .StoreUint64 (c .openConnsCount , 0 )
137
+ }
138
+
125
139
return c .connectionInfoCache .Close ()
126
140
}
127
141
@@ -145,6 +159,21 @@ func (c *monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.Connectio
145
159
return c .connectionInfoCache .ConnectionInfo (ctx )
146
160
}
147
161
162
+ func (c * monitoredCache ) purgeClosedConns () {
163
+ if c == nil || c .connectionInfoCache == nil {
164
+ return
165
+ }
166
+ c .mu .Lock ()
167
+ var open []* instrumentedConn
168
+ for _ , s := range c .openConns {
169
+ if ! s .isClosed () {
170
+ open = append (open , s )
171
+ }
172
+ }
173
+ c .openConns = open
174
+ c .mu .Unlock ()
175
+ }
176
+
148
177
// A Dialer is used to create connections to Cloud SQL instances.
149
178
//
150
179
// Use NewDialer to initialize a Dialer.
@@ -182,6 +211,10 @@ type Dialer struct {
182
211
183
212
// resolver converts instance names into DNS names.
184
213
resolver instance.ConnectionNameResolver
214
+
215
+ // domainNameTicker periodically checks any domain names to see if they
216
+ // changed.
217
+ domainNameTicker * time.Ticker
185
218
}
186
219
187
220
var (
@@ -205,6 +238,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
205
238
logger : nullLogger {},
206
239
useragents : []string {userAgent },
207
240
serviceUniverse : "googleapis.com" ,
241
+ failoverPeriod : cloudsql .FailoverPeriod ,
208
242
}
209
243
for _ , opt := range opts {
210
244
opt (cfg )
@@ -218,6 +252,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
218
252
if cfg .setIAMAuthNTokenSource && ! cfg .useIAMAuthN {
219
253
return nil , errUseTokenSource
220
254
}
255
+
221
256
// Add this to the end to make sure it's not overridden
222
257
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithUserAgent (strings .Join (cfg .useragents , " " )))
223
258
@@ -231,7 +266,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
231
266
}
232
267
ud , err := c .GetUniverseDomain ()
233
268
if err != nil {
234
- return nil , fmt .Errorf ("failed to getOrAdd universe domain: %v" , err )
269
+ return nil , fmt .Errorf ("failed to get universe domain: %v" , err )
235
270
}
236
271
cfg .credentialsUniverse = ud
237
272
cfg .sqladminOpts = append (cfg .sqladminOpts , option .WithTokenSource (c .TokenSource ))
@@ -301,8 +336,28 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
301
336
dialFunc : cfg .dialFunc ,
302
337
resolver : r ,
303
338
}
339
+
340
+ // If the failover period is set, start a goroutine to periodically
341
+ // check for DNS changes.
342
+ if cfg .failoverPeriod > 0 {
343
+ d .initFailoverRoutine (ctx , cfg .failoverPeriod )
344
+ }
345
+
304
346
return d , nil
305
347
}
348
+ func (d * Dialer ) initFailoverRoutine (ctx context.Context , p time.Duration ) {
349
+ d .domainNameTicker = time .NewTicker (p )
350
+ go func () {
351
+ for {
352
+ select {
353
+ case <- d .domainNameTicker .C :
354
+ d .pollDomainNames (ctx )
355
+ case <- d .closed :
356
+ return
357
+ }
358
+ }
359
+ }()
360
+ }
306
361
307
362
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
308
363
// icn argument must be the instance's connection name, which is in the format
@@ -406,16 +461,23 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
406
461
407
462
latency := time .Since (startTime ).Milliseconds ()
408
463
go func () {
409
- n := atomic .AddUint64 (c .openConns , 1 )
464
+ n := atomic .AddUint64 (c .openConnsCount , 1 )
410
465
trace .RecordOpenConnections (ctx , int64 (n ), d .dialerID , cn .String ())
411
466
trace .RecordDialLatency (ctx , icn , d .dialerID , latency )
412
467
}()
413
468
414
469
iConn := newInstrumentedConn (tlsConn , func () {
415
- n := atomic .AddUint64 (c .openConns , ^ uint64 (0 ))
470
+ n := atomic .AddUint64 (c .openConnsCount , ^ uint64 (0 ))
416
471
trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
417
472
}, d .dialerID , cn .String ())
418
473
474
+ // If this connection was opened using a Domain Name, then store it for later
475
+ // in case it needs to be forcibly closed.
476
+ if cn .DomainName () != "" {
477
+ c .mu .Lock ()
478
+ c .openConns = append (c .openConns , iConn )
479
+ c .mu .Unlock ()
480
+ }
419
481
return iConn , nil
420
482
}
421
483
@@ -520,6 +582,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
520
582
type instrumentedConn struct {
521
583
net.Conn
522
584
closeFunc func ()
585
+ mu sync.RWMutex
523
586
closed bool
524
587
dialerID string
525
588
connName string
@@ -545,6 +608,13 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
545
608
return bytesWritten , err
546
609
}
547
610
611
+ // isClosed returns true if this connection is closing or is already closed.
612
+ func (i * instrumentedConn ) isClosed () bool {
613
+ i .mu .RLock ()
614
+ defer i .mu .RUnlock ()
615
+ return i .closed
616
+ }
617
+
548
618
// Close delegates to the underlying net.Conn interface and reports the close
549
619
// to the provided closeFunc only when Close returns no error.
550
620
func (i * instrumentedConn ) Close () error {
@@ -568,13 +638,56 @@ func (d *Dialer) Close() error {
568
638
}
569
639
close (d .closed )
570
640
571
- d .cache .replaceAll (func (cn instance.ConnName , c monitoredCache ) (instance.ConnName , monitoredCache ) {
572
- c .Close () // close the monitoredCache
573
- return instance.ConnName {}, monitoredCache {} // Remove from cache
641
+ if d .domainNameTicker != nil {
642
+ d .domainNameTicker .Stop ()
643
+ }
644
+
645
+ d .cache .replaceAll (func (cn instance.ConnName , c * monitoredCache ) (instance.ConnName , * monitoredCache ) {
646
+ c .Close () // close the monitoredCache
647
+ return instance.ConnName {}, nil // Remove from cache
574
648
})
575
649
return nil
576
650
}
577
651
652
+ func (d * Dialer ) pollDomainNames (ctx context.Context ) {
653
+ d .cache .replaceAll (func (cn instance.ConnName , cache * monitoredCache ) (instance.ConnName , * monitoredCache ) {
654
+ if cn .DomainName () == "" {
655
+ return cn , cache
656
+ }
657
+
658
+ // Resolve the domain name.
659
+ newCn , err := d .resolver .Resolve (ctx , cn .DomainName ())
660
+
661
+ if err != nil {
662
+ // the domain name no longer resolves to a valid instance
663
+ d .logger .Debugf (ctx , "[failover] unable to resolve DNS for instance %s: %v" , cn .DomainName (), err )
664
+ cache .Close ()
665
+ return instance.ConnName {}, nil
666
+ } else if newCn != cn {
667
+ d .logger .Debugf (ctx , "domain name %s changed from old instance %s to new instance %s" ,
668
+ cn .DomainName (), cn .String (), newCn .String ())
669
+
670
+ useIamAuthn := cache .UseIAMAuthN ()
671
+ // The domain name points to a different instance.
672
+ cache .Close ()
673
+
674
+ newC , err := d .createConnectionInfoCache (ctx , cn , & useIamAuthn )
675
+ if err != nil {
676
+ d .logger .Debugf (ctx , "error connecting to new instance %s, %s: %v" ,
677
+ cn .DomainName (), newCn .String (), err )
678
+ return instance.ConnName {}, nil
679
+ }
680
+ return newCn , newC
681
+ }
682
+
683
+ // Remove closed sockets from cache.openConns
684
+ cache .purgeClosedConns ()
685
+ return cn , cache
686
+
687
+ })
688
+
689
+ }
690
+
578
691
// connectionInfoCache is a helper function for returning the appropriate
579
692
// connection info Cache in a threadsafe way. It will create a new cache,
580
693
// modify the existing one, or leave it unchanged as needed.
@@ -624,7 +737,7 @@ func (d *Dialer) createConnectionInfoCache(
624
737
)
625
738
}
626
739
c := & monitoredCache {
627
- openConns : new (uint64 ),
740
+ openConnsCount : new (uint64 ),
628
741
connectionInfoCache : cache ,
629
742
}
630
743
0 commit comments