Skip to content

Commit af8b703

Browse files
authored
feat: support static connection info (#572)
In development environments it can be useful to seed the Dialer with static connection info properties. This commit makes it possible to replace the internal lazy or refresh ahead cache with a static cache that is populated from JSON data. The JSON format looks like this: { "publicKey": "<PEM Encoded public RSA key>", "privateKey": "<PEM Encoded private RSA key>", "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>": { "ipAddress": "<PSA-based private IP address>", "publicIpAddress": "<public IP address>", "pscInstanceConfig": { "pscDnsName": "<PSC DNS name>" }, "pemCertificateChain": [ "<client cert>", "<intermediate cert>", "<CA cert>" ], "caCert": "<CA cert>" } }
1 parent 34de206 commit af8b703

File tree

8 files changed

+354
-34
lines changed

8 files changed

+354
-34
lines changed

dialer.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ type Dialer struct {
108108
// ahead cache assumes a background goroutine may run consistently.
109109
lazyRefresh bool
110110

111+
staticConnInfo io.Reader
112+
111113
client *alloydbadmin.AlloyDBAdminClient
112114
logger debug.Logger
113115

@@ -194,6 +196,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
194196
closed: make(chan struct{}),
195197
cache: make(map[alloydb.InstanceURI]monitoredCache),
196198
lazyRefresh: cfg.lazyRefresh,
199+
staticConnInfo: cfg.staticConnInfo,
197200
key: cfg.rsaKey,
198201
refreshTimeout: cfg.refreshTimeout,
199202
client: client,
@@ -563,14 +566,25 @@ func (d *Dialer) connectionInfoCache(
563566
uri.String(),
564567
)
565568
var cache connectionInfoCache
566-
if d.lazyRefresh {
569+
switch {
570+
case d.lazyRefresh:
567571
cache = alloydb.NewLazyRefreshCache(
568572
uri,
569573
d.logger,
570574
d.client, d.key,
571575
d.refreshTimeout, d.dialerID,
572576
)
573-
} else {
577+
case d.staticConnInfo != nil:
578+
var err error
579+
cache, err = alloydb.NewStaticConnectionInfoCache(
580+
uri,
581+
d.logger,
582+
d.staticConnInfo,
583+
)
584+
if err != nil {
585+
return monitoredCache{}, err
586+
}
587+
default:
574588
cache = alloydb.NewRefreshAheadCache(
575589
uri,
576590
d.logger,

dialer_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
package alloydbconn
1616

1717
import (
18+
"bytes"
1819
"context"
20+
"crypto/rand"
21+
"crypto/rsa"
22+
"crypto/x509"
23+
"encoding/json"
24+
"encoding/pem"
1925
"errors"
2026
"fmt"
2127
"io"
@@ -89,7 +95,81 @@ func TestDialerCanConnectToInstance(t *testing.T) {
8995
}
9096
})
9197
}
98+
}
99+
100+
func writeStaticInfo(t *testing.T, i mock.FakeAlloyDBInstance) io.Reader {
101+
t.Helper()
102+
key, err := rsa.GenerateKey(rand.Reader, 2048)
103+
if err != nil {
104+
t.Fatal(err)
105+
}
106+
107+
pub := x509.MarshalPKCS1PublicKey(&key.PublicKey)
108+
pubPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: pub})
109+
if pubPEM == nil {
110+
t.Fatal("public key encoding failed")
111+
}
112+
priv := x509.MarshalPKCS1PrivateKey(key)
113+
privPEM := pem.EncodeToMemory(
114+
&pem.Block{Type: "OPENSSH PRIVATE KEY", Bytes: priv},
115+
)
116+
if privPEM == nil {
117+
t.Fatal("private key encoding failed")
118+
}
119+
120+
static := map[string]interface{}{}
121+
static["publicKey"] = string(pubPEM)
122+
static["privateKey"] = string(privPEM)
123+
info := make(map[string]interface{})
124+
info["ipAddress"] = "127.0.0.1" // "private" IP is localhost in testing
125+
chain, err := i.GeneratePEMCertificateChain(&key.PublicKey)
126+
if err != nil {
127+
t.Fatal(err)
128+
}
129+
info["pemCertificateChain"] = chain
130+
info["caCert"] = chain[len(chain)-1] // CA cert is last in chain
131+
static[i.String()] = info
132+
133+
data, err := json.Marshal(static)
134+
if err != nil {
135+
t.Fatal(err)
136+
}
92137

138+
return bytes.NewReader(data)
139+
}
140+
141+
func TestDialerWorksWithStaticConnectionInfo(t *testing.T) {
142+
ctx := context.Background()
143+
inst := mock.NewFakeInstance(
144+
"my-project", "my-region", "my-cluster", "my-instance",
145+
)
146+
stop := mock.StartServerProxy(t, inst)
147+
t.Cleanup(stop)
148+
149+
staticPath := writeStaticInfo(t, inst)
150+
151+
d, err := NewDialer(
152+
ctx,
153+
WithTokenSource(stubTokenSource{}),
154+
WithStaticConnectionInfo(staticPath),
155+
)
156+
if err != nil {
157+
t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
158+
}
159+
160+
conn, err := d.Dial(ctx, testInstanceURI)
161+
if err != nil {
162+
t.Fatalf("expected Dial to succeed, but got error: %v", err)
163+
}
164+
defer conn.Close()
165+
166+
data, err := io.ReadAll(conn)
167+
if err != nil {
168+
t.Fatalf("expected ReadAll to succeed, got error %v", err)
169+
}
170+
if string(data) != "my-instance" {
171+
t.Fatalf("expected known response from the server, but got %v", string(data))
172+
}
93173
}
94174

95175
func TestDialWithAdminAPIErrors(t *testing.T) {

internal/alloydb/instance.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ type InstanceURI struct {
6262
name string
6363
}
6464

65+
// URI returns the full URI specifying an instance.
66+
func (i *InstanceURI) URI() string {
67+
return fmt.Sprintf(
68+
"projects/%s/locations/%s/clusters/%s/instances/%s",
69+
i.project, i.region, i.cluster, i.name,
70+
)
71+
}
72+
73+
// String returns a short-hand representation of an instance URI.
6574
func (i *InstanceURI) String() string {
6675
return fmt.Sprintf("%s/%s/%s/%s", i.project, i.region, i.cluster, i.name)
6776
}

internal/alloydb/refresh.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,25 @@ func fetchClientCertificate(
149149
)
150150
}
151151

152-
certPEMBlock := []byte(strings.Join(resp.PemCertificateChain, "\n"))
153152
keyPEMBlock := &pem.Block{
154153
Type: "RSA PRIVATE KEY",
155154
Bytes: x509.MarshalPKCS1PrivateKey(key),
156155
}
156+
keyPEM := pem.EncodeToMemory(keyPEMBlock)
157157

158-
cert, err := tls.X509KeyPair(certPEMBlock, pem.EncodeToMemory(keyPEMBlock))
158+
return newClientCertificate(
159+
inst, keyPEM, resp.PemCertificateChain, resp.CaCert,
160+
)
161+
}
162+
163+
func newClientCertificate(
164+
inst InstanceURI,
165+
keyPEM []byte,
166+
chain []string,
167+
caCertRaw string,
168+
) (cc *clientCertificate, err error) {
169+
certPEMBlock := []byte(strings.Join(chain, "\n"))
170+
cert, err := tls.X509KeyPair(certPEMBlock, keyPEM)
159171
if err != nil {
160172
return nil, errtype.NewRefreshError(
161173
"create ephemeral cert failed",
@@ -164,7 +176,7 @@ func fetchClientCertificate(
164176
)
165177
}
166178

167-
caCertPEMBlock, _ := pem.Decode([]byte(resp.CaCert))
179+
caCertPEMBlock, _ := pem.Decode([]byte(caCertRaw))
168180
if caCertPEMBlock == nil {
169181
return nil, errtype.NewRefreshError(
170182
"create ephemeral cert failed",
@@ -182,7 +194,7 @@ func fetchClientCertificate(
182194
}
183195

184196
// Extract expiry from client certificate.
185-
clientCertPEMBlock, _ := pem.Decode([]byte(resp.PemCertificateChain[0]))
197+
clientCertPEMBlock, _ := pem.Decode([]byte(chain[0]))
186198
if clientCertPEMBlock == nil {
187199
return nil, errtype.NewRefreshError(
188200
"create ephemeral cert failed",

internal/alloydb/static.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package alloydb
16+
17+
import (
18+
"context"
19+
"crypto/x509"
20+
"encoding/json"
21+
"io"
22+
23+
"cloud.google.com/go/alloydbconn/debug"
24+
"cloud.google.com/go/alloydbconn/errtype"
25+
)
26+
27+
type staticPSCConfig struct {
28+
PSCDNSName string `json:"pscDnsName"`
29+
}
30+
31+
// staticConnectionInfo is an amalgamation of the generate ephemeral
32+
// certificate and instance metadata endpoints. Its structure concatenates the
33+
// IP address information with certificate information. As such it provides all
34+
// the necessary properties needed for the Dialer to connect to an instance's
35+
// Auth Proxy server.
36+
type staticConnectionInfo struct {
37+
IPAddress string `json:"ipAddress"`
38+
PublicIPAddress string `json:"publicIPAddress"`
39+
PSCInstanceConfig staticPSCConfig `json:"pscInstanceConfig"`
40+
PEMCertificateChain []string `json:"pemCertificateChain"`
41+
CACert string `json:"caCert"`
42+
}
43+
44+
// staticInstanceInfo correlates instance URIs with static connection info.
45+
type staticInstanceInfo map[string]staticConnectionInfo
46+
47+
// staticData represent a collection of static connection info.
48+
type staticData struct {
49+
PublicKey string
50+
PrivateKey string
51+
InstanceInfo staticInstanceInfo
52+
}
53+
54+
func (s *staticData) UnmarshalJSON(data []byte) error {
55+
inner := map[string]json.RawMessage{}
56+
if err := json.Unmarshal(data, &inner); err != nil {
57+
return err
58+
}
59+
if err := json.Unmarshal(inner["privateKey"], &s.PrivateKey); err != nil {
60+
return err
61+
}
62+
delete(inner, "privateKey")
63+
if err := json.Unmarshal(inner["publicKey"], &s.PublicKey); err != nil {
64+
return err
65+
}
66+
delete(inner, "publicKey")
67+
68+
s.InstanceInfo = staticInstanceInfo{}
69+
for k, v := range inner {
70+
var sci staticConnectionInfo
71+
if err := json.Unmarshal(v, &sci); err != nil {
72+
return err
73+
}
74+
s.InstanceInfo[k] = sci
75+
}
76+
return nil
77+
}
78+
79+
// StaticConnectionInfoCache provides connection info that is never refreshed.
80+
type StaticConnectionInfoCache struct {
81+
logger debug.Logger
82+
info ConnectionInfo
83+
}
84+
85+
// NewStaticConnectionInfoCache creates a connection info cache that will
86+
// always return the predefined connection info within the provided io.Reader
87+
func NewStaticConnectionInfoCache(
88+
inst InstanceURI,
89+
l debug.Logger,
90+
r io.Reader,
91+
) (*StaticConnectionInfoCache, error) {
92+
data, err := io.ReadAll(r)
93+
if err != nil {
94+
return nil, err
95+
}
96+
var d staticData
97+
if err := json.Unmarshal(data, &d); err != nil {
98+
return nil, err
99+
}
100+
static, ok := d.InstanceInfo[inst.URI()]
101+
if !ok {
102+
return nil, errtype.NewConfigError("unknown instance", inst.String())
103+
}
104+
cc, err := newClientCertificate(
105+
inst, []byte(d.PrivateKey), static.PEMCertificateChain, static.CACert,
106+
)
107+
if err != nil {
108+
return nil, err
109+
}
110+
pool := x509.NewCertPool()
111+
pool.AddCert(cc.caCert)
112+
info := ConnectionInfo{
113+
Instance: inst,
114+
IPAddrs: map[string]string{
115+
PublicIP: static.PublicIPAddress,
116+
PrivateIP: static.IPAddress,
117+
PSC: static.PSCInstanceConfig.PSCDNSName,
118+
},
119+
ClientCert: cc.certChain,
120+
RootCAs: pool,
121+
Expiration: cc.expiry,
122+
}
123+
return &StaticConnectionInfoCache{
124+
logger: l,
125+
info: info,
126+
}, nil
127+
}
128+
129+
// ConnectionInfo returns the connection info for the specified instance URI as
130+
// loaded from the provided io.Reader.
131+
func (c *StaticConnectionInfoCache) ConnectionInfo(
132+
_ context.Context,
133+
) (ConnectionInfo, error) {
134+
return c.info, nil
135+
}
136+
137+
// ForceRefresh is a no-op as the cache holds only static connection
138+
// information and does no refresh.
139+
func (*StaticConnectionInfoCache) ForceRefresh() {}
140+
141+
// Close is a no-op.
142+
func (*StaticConnectionInfoCache) Close() error { return nil }

0 commit comments

Comments
 (0)