Skip to content

Commit a8c842f

Browse files
authored
chore(auth): add a bunch of missing validation logic (#8718)
- made the look of validation consistent across packages - moved refresh token into the options struct for 3lo (this is the flagged breaking change.)
1 parent 3a4ec65 commit a8c842f

File tree

15 files changed

+371
-82
lines changed

15 files changed

+371
-82
lines changed

auth/auth.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package auth
1717
import (
1818
"context"
1919
"encoding/json"
20+
"errors"
2021
"fmt"
2122
"net/http"
2223
"net/url"
@@ -224,17 +225,17 @@ type Options2LO struct {
224225
// contents of a PEM file that contains a private key. It is used to sign
225226
// the JWT created.
226227
PrivateKey []byte
228+
// TokenURL is th URL the JWT is sent to. Required.
229+
TokenURL string
227230
// PrivateKeyID is the ID of the key used to sign the JWT. It is used as the
228-
// "kid" in the JWT header.
231+
// "kid" in the JWT header. Optional.
229232
PrivateKeyID string
230233
// Subject is the used for to impersonate a user. It is used as the "sub" in
231234
// the JWT.m Optional.
232235
Subject string
233236
// Scopes specifies requested permissions for the token. Optional.
234237
Scopes []string
235-
// TokenURL is th URL the JWT is sent to.
236-
TokenURL string
237-
// Expires specifies the lifetime of the token.
238+
// Expires specifies the lifetime of the token. Optional.
238239
Expires time.Duration
239240
// Audience specifies the "aud" in the JWT. Optional.
240241
Audience string
@@ -249,16 +250,34 @@ type Options2LO struct {
249250
UseIDToken bool
250251
}
251252

252-
func (c *Options2LO) client() *http.Client {
253-
if c.Client != nil {
254-
return c.Client
253+
func (o *Options2LO) client() *http.Client {
254+
if o.Client != nil {
255+
return o.Client
255256
}
256257
return internal.CloneDefaultClient()
257258
}
258259

260+
func (o *Options2LO) validate() error {
261+
if o == nil {
262+
return errors.New("auth: options must be provided")
263+
}
264+
if o.Email == "" {
265+
return errors.New("auth: email must be provided")
266+
}
267+
if len(o.PrivateKey) == 0 {
268+
return errors.New("auth: private key must be provided")
269+
}
270+
if o.TokenURL == "" {
271+
return errors.New("auth: token URL must be provided")
272+
}
273+
return nil
274+
}
275+
259276
// New2LOTokenProvider returns a [TokenProvider] from the provided options.
260277
func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
261-
// TODO(codyoss): add validation
278+
if err := opts.validate(); err != nil {
279+
return nil, err
280+
}
262281
return tokenProvider2LO{opts: opts, Client: opts.client()}, nil
263282
}
264283

auth/auth_test.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func TestError_Error(t *testing.T) {
178178
}
179179
}
180180

181-
func TestConfigJWT2LO_JSONResponse(t *testing.T) {
181+
func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
182182
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
183183
w.Header().Set("Content-Type", "application/json")
184184
w.Write([]byte(`{
@@ -221,7 +221,7 @@ func TestConfigJWT2LO_JSONResponse(t *testing.T) {
221221
}
222222
}
223223

224-
func TestConfigJWT2LO_BadResponse(t *testing.T) {
224+
func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
225225
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
226226
w.Header().Set("Content-Type", "application/json")
227227
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
@@ -259,7 +259,7 @@ func TestConfigJWT2LO_BadResponse(t *testing.T) {
259259
}
260260
}
261261

262-
func TestConfigJWT2LO_BadResponseType(t *testing.T) {
262+
func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
263263
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
264264
w.Header().Set("Content-Type", "application/json")
265265
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
@@ -283,7 +283,7 @@ func TestConfigJWT2LO_BadResponseType(t *testing.T) {
283283
}
284284
}
285285

286-
func TestConfigJWT2LO_Assertion(t *testing.T) {
286+
func TestNew2LOTokenProvider_Assertion(t *testing.T) {
287287
var assertion string
288288
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
289289
r.ParseForm()
@@ -339,7 +339,7 @@ func TestConfigJWT2LO_Assertion(t *testing.T) {
339339
}
340340
}
341341

342-
func TestConfigJWT2LO_AssertionPayload(t *testing.T) {
342+
func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
343343
var assertion string
344344
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
345345
r.ParseForm()
@@ -436,7 +436,7 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) {
436436
}
437437
}
438438

439-
func TestConfigJWT2LO_TokenError(t *testing.T) {
439+
func TestNew2LOTokenProvider_TokenError(t *testing.T) {
440440
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
441441
w.Header().Set("Content-type", "application/json")
442442
w.WriteHeader(http.StatusBadRequest)
@@ -467,3 +467,42 @@ func TestConfigJWT2LO_TokenError(t *testing.T) {
467467
t.Fatalf("got %#v, expected %#v", errStr, expected)
468468
}
469469
}
470+
471+
func TestNew2LOTokenProvider_Validate(t *testing.T) {
472+
tests := []struct {
473+
name string
474+
opts *Options2LO
475+
}{
476+
{
477+
name: "missing options",
478+
},
479+
{
480+
name: "missing email",
481+
opts: &Options2LO{
482+
PrivateKey: []byte("key"),
483+
TokenURL: "url",
484+
},
485+
},
486+
{
487+
name: "missing key",
488+
opts: &Options2LO{
489+
Email: "email",
490+
TokenURL: "url",
491+
},
492+
},
493+
{
494+
name: "missing URL",
495+
opts: &Options2LO{
496+
Email: "email",
497+
PrivateKey: []byte("key"),
498+
},
499+
},
500+
}
501+
for _, tt := range tests {
502+
t.Run(tt.name, func(t *testing.T) {
503+
if _, err := New2LOTokenProvider(tt.opts); err == nil {
504+
t.Error("got nil, want an error")
505+
}
506+
})
507+
}
508+
}

auth/detect/detect.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package detect
1616

1717
import (
1818
"encoding/json"
19+
"errors"
1920
"fmt"
2021
"net/http"
2122
"os"
@@ -119,7 +120,9 @@ func OnGCE() bool {
119120
// runtimes, and Google App Engine flexible environment, it fetches
120121
// credentials from the metadata server.
121122
func DefaultCredentials(opts *Options) (*Credentials, error) {
122-
// TODO(codyoss): add some validation logic here.
123+
if err := opts.validate(); err != nil {
124+
return nil, err
125+
}
123126
if opts.CredentialsJSON != nil {
124127
return readCredentialsFileJSON(opts.CredentialsJSON, opts)
125128
}
@@ -145,7 +148,8 @@ func DefaultCredentials(opts *Options) (*Credentials, error) {
145148
// Options provides configuration for [DefaultCredentials].
146149
type Options struct {
147150
// Scopes that credentials tokens should have. Example:
148-
// https://www.googleapis.com/auth/cloud-platform
151+
// https://www.googleapis.com/auth/cloud-platform. Required if Audience is
152+
// not provided.
149153
Scopes []string
150154
// Audience that credentials tokens should have. Only applicable for 2LO
151155
// flows with service accounts. If specified, scopes should not be provided.
@@ -168,10 +172,12 @@ type Options struct {
168172
// Currently this only used for GDCH auth flow, for which it is required.
169173
STSAudience string
170174
// CredentialsFile overrides detection logic and sources a credential file
171-
// from the provided filepath. Optional.
175+
// from the provided filepath. If provided, CredentialsJSON must not be.
176+
// Optional.
172177
CredentialsFile string
173178
// CredentialsJSON overrides detection logic and uses the JSON bytes as the
174-
// source for the credential. Optional.
179+
// source for the credential. If provided, CredentialsFile must not be.
180+
// Optional.
175181
CredentialsJSON []byte
176182
// UseSelfSignedJWT directs service account based credentials to create a
177183
// self-signed JWT with the private key found in the file, skipping any
@@ -182,6 +188,19 @@ type Options struct {
182188
Client *http.Client
183189
}
184190

191+
func (o *Options) validate() error {
192+
if o == nil {
193+
return errors.New("detect: options must be provided")
194+
}
195+
if len(o.Scopes) > 0 && o.Audience != "" {
196+
return errors.New("detect: both scopes and audience were provided")
197+
}
198+
if len(o.CredentialsJSON) > 0 && o.CredentialsFile != "" {
199+
return errors.New("detect: both credentials file and JSON were provided")
200+
}
201+
return nil
202+
}
203+
185204
func (o *Options) tokenURL() string {
186205
if o.TokenURL != "" {
187206
return o.TokenURL
@@ -214,7 +233,10 @@ func readCredentialsFileJSON(b []byte, opts *Options) (*Credentials, error) {
214233
// attempt to parse jsonData as a Google Developers Console client_credentials.json.
215234
config := clientCredConfigFromJSON(b, opts)
216235
if config != nil {
217-
tp, err := auth.New3LOTokenProvider("", config)
236+
if config.AuthHandlerOpts == nil {
237+
return nil, errors.New("detect: auth handler must be specified for this credential filetype")
238+
}
239+
tp, err := auth.New3LOTokenProvider(config)
218240
if err != nil {
219241
return nil, err
220242
}
@@ -242,7 +264,6 @@ func clientCredConfigFromJSON(b []byte, opts *Options) *auth.Options3LO {
242264
}
243265
var handleOpts *auth.AuthorizationHandlerOptions
244266
if opts.AuthHandlerOptions != nil {
245-
// TODO(codyoss): these have to be here for this flow, validate that
246267
handleOpts = &auth.AuthorizationHandlerOptions{
247268
Handler: opts.AuthHandlerOptions.Handler,
248269
State: opts.AuthHandlerOptions.State,

auth/detect/detect_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,36 @@ func TestDefaultCredentials_BadFiletype(t *testing.T) {
560560
t.Fatal("got nil, want non-nil err")
561561
}
562562
}
563+
564+
func TestDefaultCredentials_Validate(t *testing.T) {
565+
tests := []struct {
566+
name string
567+
opts *Options
568+
}{
569+
{
570+
name: "missing options",
571+
},
572+
{
573+
name: "scope and audience provided",
574+
opts: &Options{
575+
Scopes: []string{"scope"},
576+
Audience: "aud",
577+
},
578+
},
579+
{
580+
name: "file and json provided",
581+
opts: &Options{
582+
Scopes: []string{"scope"},
583+
CredentialsFile: "path",
584+
CredentialsJSON: []byte(`{"some":"json"}`),
585+
},
586+
},
587+
}
588+
for _, tt := range tests {
589+
t.Run(tt.name, func(t *testing.T) {
590+
if _, err := DefaultCredentials(tt.opts); err == nil {
591+
t.Error("got nil, want an error")
592+
}
593+
})
594+
}
595+
}

auth/detect/filetypes.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ func handleUserCredential(f *internaldetect.UserCredentialsFile, opts *Options)
121121
TokenURL: opts.tokenURL(),
122122
AuthStyle: auth.StyleInParams,
123123
EarlyTokenExpiry: opts.EarlyTokenRefresh,
124+
RefreshToken: f.RefreshToken,
124125
}
125-
return auth.New3LOTokenProvider(f.RefreshToken, opts3LO)
126+
return auth.New3LOTokenProvider(opts3LO)
126127
}
127128

128129
func handleExternalAccount(f *internaldetect.ExternalAccountFile, opts *Options) (auth.TokenProvider, error) {

auth/detect/internal/externalaccount/sts_exchange.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ type stsTokenExchangeResponse struct {
116116
TokenType string `json:"token_type"`
117117
ExpiresIn int `json:"expires_in"`
118118
Scope string `json:"scope"`
119-
// TODO(codyoss): original impl parsed but did not use a refresh token here, do we need it?
120119
}
121120

122121
// clientAuthentication represents an OAuth client ID and secret and the

auth/detect/internal/impersonate/impersonate.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"bytes"
1919
"context"
2020
"encoding/json"
21+
"errors"
2122
"fmt"
2223
"net/http"
2324
"time"
@@ -46,7 +47,9 @@ type impersonateTokenResponse struct {
4647
// NewTokenProvider uses a source credential, stored in Ts, to request an access token to the provided URL.
4748
// Scopes can be defined when the access token is requested.
4849
func NewTokenProvider(opts *Options) (auth.TokenProvider, error) {
49-
// TODO(codyoss): add validation
50+
if err := opts.validate(); err != nil {
51+
return nil, err
52+
}
5053
return opts, nil
5154
}
5255

@@ -73,6 +76,16 @@ type Options struct {
7376
Client *http.Client
7477
}
7578

79+
func (o *Options) validate() error {
80+
if o.Tp == nil {
81+
return errors.New("detect: missing required 'source_credentials' field in impersonated credentials")
82+
}
83+
if o.URL == "" {
84+
return errors.New("detect: missing required 'service_account_impersonation_url' field in impersonated credentials")
85+
}
86+
return nil
87+
}
88+
7689
// Token performs the exchange to get a temporary service account token to allow access to GCP.
7790
func (tp *Options) Token(ctx context.Context) (*auth.Token, error) {
7891
lifetime := defaultTokenLifetime

auth/detect/internal/impersonate/impersonate_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,34 @@ func (tp mockProvider) Token(context.Context) (*auth.Token, error) {
3333
}, nil
3434
}
3535

36+
func TestNewImpersonatedTokenProvider_Validation(t *testing.T) {
37+
tests := []struct {
38+
name string
39+
opt *Options
40+
}{
41+
{
42+
name: "missing source creds",
43+
opt: &Options{
44+
URL: "some-url",
45+
},
46+
},
47+
{
48+
name: "missing url",
49+
opt: &Options{
50+
Tp: &Options{},
51+
},
52+
},
53+
}
54+
for _, tt := range tests {
55+
t.Run(tt.name, func(t *testing.T) {
56+
_, err := NewTokenProvider(tt.opt)
57+
if err == nil {
58+
t.Errorf("got nil, want an error")
59+
}
60+
})
61+
}
62+
}
63+
3664
func TestNewImpersonatedTokenProvider(t *testing.T) {
3765
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3866
if got, want := r.Header.Get("Authorization"), "Bearer fake_token_base"; got != want {

0 commit comments

Comments
 (0)