Skip to content

Commit 225fa6b

Browse files
authored
internal: Refactor cert logic to support OAuth2 token exchange over mTLS (#1886)
With Context Aware Access enabled, users must use the endpoint "https://oauth2.mtls.googleapis.com/token" for token exchange. This PR refactors the cert logic currently used by the transport layer to be reused by the internal credentials layer in order to inject an mTLS-enabled HTTPClient (via the "context" mechanism) for use by the OAuth2 transport (along with the mTLS OAuth2 endpoint if so).
1 parent 8d4d70d commit 225fa6b

22 files changed

+69
-19
lines changed
File renamed without changes.
File renamed without changes.

transport/cert/testdata/signer.sh renamed to internal/cert/testdata/signer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
# Use of this source code is governed by a BSD-style
55
# license that can be found in the LICENSE file.
66

7-
go run ../internal/ecp/test_signer.go testdata/rsa2048bit.pem
7+
go run ../ecp/test_signer.go testdata/rsa2048bit.pem

transport/cert/testdata/signer_invalid_pem.sh renamed to internal/cert/testdata/signer_invalid_pem.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
# Use of this source code is governed by a BSD-style
55
# license that can be found in the LICENSE file.
66

7-
go run ../internal/ecp/test_signer.go testdata/invalid.pem
7+
go run ../ecp/test_signer.go testdata/invalid.pem

internal/creds.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ package internal
66

77
import (
88
"context"
9+
"crypto/tls"
910
"encoding/json"
1011
"errors"
1112
"fmt"
1213
"io/ioutil"
14+
"net"
15+
"net/http"
16+
"time"
1317

1418
"golang.org/x/oauth2"
1519
"google.golang.org/api/internal/impersonate"
@@ -80,8 +84,25 @@ const (
8084
// - Otherwise, executes standard OAuth 2.0 flow
8185
// More details: google.aip.dev/auth/4111
8286
func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*google.Credentials, error) {
87+
var params google.CredentialsParams
88+
params.Scopes = ds.GetScopes()
89+
90+
// Determine configurations for the OAuth2 transport, which is separate from the API transport.
91+
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
92+
clientCertSource, oauth2Endpoint, err := GetClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
93+
if err != nil {
94+
return nil, err
95+
}
96+
params.TokenURL = oauth2Endpoint
97+
if clientCertSource != nil {
98+
tlsConfig := &tls.Config{
99+
GetClientCertificate: clientCertSource,
100+
}
101+
ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig))
102+
}
103+
83104
// By default, a standard OAuth 2.0 token source is created
84-
cred, err := google.CredentialsFromJSON(ctx, data, ds.GetScopes()...)
105+
cred, err := google.CredentialsFromJSONWithParams(ctx, data, params)
85106
if err != nil {
86107
return nil, err
87108
}
@@ -157,3 +178,35 @@ func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *
157178
ProjectID: creds.ProjectID,
158179
}, nil
159180
}
181+
182+
// oauth2DialSettings returns the settings to be used by the OAuth2 transport, which is separate from the API transport.
183+
func oauth2DialSettings(ds *DialSettings) *DialSettings {
184+
var ods DialSettings
185+
ods.DefaultEndpoint = google.Endpoint.TokenURL
186+
ods.DefaultMTLSEndpoint = google.MTLSTokenURL
187+
ods.ClientCertSource = ds.ClientCertSource
188+
return &ods
189+
}
190+
191+
// customHTTPClient constructs an HTTPClient using the provided tlsConfig, to support mTLS.
192+
func customHTTPClient(tlsConfig *tls.Config) *http.Client {
193+
trans := baseTransport()
194+
trans.TLSClientConfig = tlsConfig
195+
return &http.Client{Transport: trans}
196+
}
197+
198+
func baseTransport() *http.Transport {
199+
return &http.Transport{
200+
Proxy: http.ProxyFromEnvironment,
201+
DialContext: (&net.Dialer{
202+
Timeout: 30 * time.Second,
203+
KeepAlive: 30 * time.Second,
204+
DualStack: true,
205+
}).DialContext,
206+
MaxIdleConns: 100,
207+
MaxIdleConnsPerHost: 100,
208+
IdleConnTimeout: 90 * time.Second,
209+
TLSHandshakeTimeout: 10 * time.Second,
210+
ExpectContinueTimeout: 1 * time.Second,
211+
}
212+
}

transport/internal/dca/dca.go renamed to internal/dca.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,16 @@
2323
//
2424
// This package is not intended for use by end developers. Use the
2525
// google.golang.org/api/option package to configure API clients.
26-
package dca
26+
27+
// Package internal supports the options and transport packages.
28+
package internal
2729

2830
import (
2931
"net/url"
3032
"os"
3133
"strings"
3234

33-
"google.golang.org/api/internal"
34-
"google.golang.org/api/transport/cert"
35+
"google.golang.org/api/internal/cert"
3536
)
3637

3738
const (
@@ -43,7 +44,7 @@ const (
4344
// GetClientCertificateSourceAndEndpoint is a convenience function that invokes
4445
// getClientCertificateSource and getEndpoint sequentially and returns the client
4546
// cert source and endpoint as a tuple.
46-
func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cert.Source, string, error) {
47+
func GetClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, error) {
4748
clientCertSource, err := getClientCertificateSource(settings)
4849
if err != nil {
4950
return nil, "", err
@@ -65,7 +66,7 @@ func GetClientCertificateSourceAndEndpoint(settings *internal.DialSettings) (cer
6566
// Important Note: For now, the environment variable GOOGLE_API_USE_CLIENT_CERTIFICATE
6667
// must be set to "true" to allow certificate to be used (including user provided
6768
// certificates). For details, see AIP-4114.
68-
func getClientCertificateSource(settings *internal.DialSettings) (cert.Source, error) {
69+
func getClientCertificateSource(settings *DialSettings) (cert.Source, error) {
6970
if !isClientCertificateEnabled() {
7071
return nil, nil
7172
} else if settings.ClientCertSource != nil {
@@ -94,7 +95,7 @@ func isClientCertificateEnabled() bool {
9495
// URL (ex. https://...), then the user-provided address will be merged into
9596
// the default endpoint. For example, WithEndpoint("myhost:8000") and
9697
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
97-
func getEndpoint(settings *internal.DialSettings, clientCertSource cert.Source) (string, error) {
98+
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
9899
if settings.Endpoint == "" {
99100
mtlsMode := getMTLSMode()
100101
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {

transport/internal/dca/dca_test.go renamed to internal/dca_test.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
package dca
5+
package internal
66

77
import (
88
"testing"
99

1010
"crypto/tls"
11-
12-
"google.golang.org/api/internal"
1311
)
1412

1513
func TestGetEndpoint(t *testing.T) {
@@ -51,7 +49,7 @@ func TestGetEndpoint(t *testing.T) {
5149
}
5250

5351
for _, tc := range testCases {
54-
got, err := getEndpoint(&internal.DialSettings{
52+
got, err := getEndpoint(&DialSettings{
5553
Endpoint: tc.UserEndpoint,
5654
DefaultEndpoint: tc.DefaultEndpoint,
5755
}, nil)
@@ -106,7 +104,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
106104
}
107105

108106
for _, tc := range testCases {
109-
got, err := getEndpoint(&internal.DialSettings{
107+
got, err := getEndpoint(&DialSettings{
110108
Endpoint: tc.UserEndpoint,
111109
DefaultEndpoint: tc.DefaultEndpoint,
112110
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
File renamed without changes.

transport/grpc/dial.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"golang.org/x/oauth2"
2222
"google.golang.org/api/internal"
2323
"google.golang.org/api/option"
24-
"google.golang.org/api/transport/internal/dca"
2524
"google.golang.org/grpc"
2625
"google.golang.org/grpc/credentials"
2726
grpcgoogle "google.golang.org/grpc/credentials/google"
@@ -123,7 +122,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
123122
if o.GRPCConn != nil {
124123
return o.GRPCConn, nil
125124
}
126-
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(o)
125+
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(o)
127126
if err != nil {
128127
return nil, err
129128
}

transport/http/dial.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ import (
2020
"golang.org/x/oauth2"
2121
"google.golang.org/api/googleapi/transport"
2222
"google.golang.org/api/internal"
23+
"google.golang.org/api/internal/cert"
2324
"google.golang.org/api/option"
24-
"google.golang.org/api/transport/cert"
2525
"google.golang.org/api/transport/http/internal/propagation"
26-
"google.golang.org/api/transport/internal/dca"
2726
)
2827

2928
// NewClient returns an HTTP client for use communicating with a Google cloud
@@ -34,7 +33,7 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*http.Client,
3433
if err != nil {
3534
return nil, "", err
3635
}
37-
clientCertSource, endpoint, err := dca.GetClientCertificateSourceAndEndpoint(settings)
36+
clientCertSource, endpoint, err := internal.GetClientCertificateSourceAndEndpoint(settings)
3837
if err != nil {
3938
return nil, "", err
4039
}

0 commit comments

Comments
 (0)