From a9cc46344457c2aa597154f63f77a24723a6bc4d Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 7 Jul 2023 12:54:05 -0700 Subject: [PATCH 1/2] Refactor internal MSAL client constructors --- sdk/azidentity/azidentity.go | 29 ++++++++++++------- sdk/azidentity/client_assertion_credential.go | 6 +++- .../client_certificate_credential.go | 10 +++---- sdk/azidentity/client_secret_credential.go | 8 +++-- sdk/azidentity/device_code_credential.go | 8 +++-- .../interactive_browser_credential.go | 6 +++- sdk/azidentity/on_behalf_of_credential.go | 10 +++---- .../username_password_credential.go | 6 +++- 8 files changed, 54 insertions(+), 29 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 49136c6b9861..7b0a0f861f50 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -50,46 +50,55 @@ var ( disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true" ) -var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) { +type msalClientOptions struct { + azcore.ClientOptions + + DisableInstanceDiscovery bool + // SendX5C applies only to confidential clients authenticating with a cert + SendX5C bool +} + +var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, opts msalClientOptions) (confidentialClient, error) { if !validTenantID(tenantID) { return confidential.Client{}, errors.New(tenantIDValidationErr) } - authorityHost, err := setAuthorityHost(co.Cloud) + authorityHost, err := setAuthorityHost(opts.Cloud) if err != nil { return confidential.Client{}, err } authority := runtime.JoinPaths(authorityHost, tenantID) o := []confidential.Option{ confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), - confidential.WithHTTPClient(newPipelineAdapter(co)), + confidential.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)), } if !disableCP1 { o = append(o, confidential.WithClientCapabilities(cp1)) } - o = append(o, additionalOpts...) - if strings.ToLower(tenantID) == "adfs" { + if opts.SendX5C { + o = append(o, confidential.WithX5C()) + } + if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" { o = append(o, confidential.WithInstanceDiscovery(false)) } return confidential.New(authority, clientID, cred, o...) } -var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions, additionalOpts ...public.Option) (public.Client, error) { +var getPublicClient = func(clientID, tenantID string, opts msalClientOptions) (public.Client, error) { if !validTenantID(tenantID) { return public.Client{}, errors.New(tenantIDValidationErr) } - authorityHost, err := setAuthorityHost(co.Cloud) + authorityHost, err := setAuthorityHost(opts.Cloud) if err != nil { return public.Client{}, err } o := []public.Option{ public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - public.WithHTTPClient(newPipelineAdapter(co)), + public.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)), } if !disableCP1 { o = append(o, public.WithClientCapabilities(cp1)) } - o = append(o, additionalOpts...) - if strings.ToLower(tenantID) == "adfs" { + if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" { o = append(o, public.WithInstanceDiscovery(false)) } return public.New(clientID, o...) diff --git a/sdk/azidentity/client_assertion_credential.go b/sdk/azidentity/client_assertion_credential.go index 619955373e99..56bf7e1dd35b 100644 --- a/sdk/azidentity/client_assertion_credential.go +++ b/sdk/azidentity/client_assertion_credential.go @@ -56,7 +56,11 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c return getAssertion(ctx) }, ) - c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery)) + msalOpts := msalClientOptions{ + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + } + c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index d2239d68097a..c78b4c442b30 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -58,12 +58,12 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x if err != nil { return nil, err } - var o []confidential.Option - if options.SendCertificateChain { - o = append(o, confidential.WithX5C()) + msalOpts := msalClientOptions{ + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + SendX5C: options.SendCertificateChain, } - o = append(o, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery)) - c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, o...) + c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index 7d3ec645bf3b..36a84bebe007 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -46,9 +46,11 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st if err != nil { return nil, err } - c, err := getConfidentialClient( - clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery), - ) + msalOpts := msalClientOptions{ + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + } + c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index c746038bb371..90f128db9de6 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -87,9 +87,11 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC cp = *options } cp.init() - c, err := getPublicClient( - cp.ClientID, cp.TenantID, &cp.ClientOptions, public.WithInstanceDiscovery(!cp.DisableInstanceDiscovery), - ) + msalOpts := msalClientOptions{ + ClientOptions: cp.ClientOptions, + DisableInstanceDiscovery: cp.DisableInstanceDiscovery, + } + c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index 3b59263ba92c..c630cca7bf17 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -69,7 +69,11 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption cp = *options } cp.init() - c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions, public.WithInstanceDiscovery(!cp.DisableInstanceDiscovery)) + msalOpts := msalClientOptions{ + ClientOptions: cp.ClientOptions, + DisableInstanceDiscovery: cp.DisableInstanceDiscovery, + } + c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/on_behalf_of_credential.go b/sdk/azidentity/on_behalf_of_credential.go index 440b121600c1..30fa168f3dfe 100644 --- a/sdk/azidentity/on_behalf_of_credential.go +++ b/sdk/azidentity/on_behalf_of_credential.go @@ -72,12 +72,12 @@ func newOnBehalfOfCredential(tenantID, clientID, userAssertion string, cred conf if options == nil { options = &OnBehalfOfCredentialOptions{} } - opts := []confidential.Option{} - if options.SendCertificateChain { - opts = append(opts, confidential.WithX5C()) + msalOpts := msalClientOptions{ + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + SendX5C: options.SendCertificateChain, } - opts = append(opts, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery)) - c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, opts...) + c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) if err != nil { return nil, err } diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index 4e5a5142efae..c698377eec32 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -48,7 +48,11 @@ func NewUsernamePasswordCredential(tenantID string, clientID string, username st if options == nil { options = &UsernamePasswordCredentialOptions{} } - c, err := getPublicClient(clientID, tenantID, &options.ClientOptions, public.WithInstanceDiscovery(!options.DisableInstanceDiscovery)) + msalOpts := msalClientOptions{ + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + } + c, err := getPublicClient(clientID, tenantID, msalOpts) if err != nil { return nil, err } From 896899ed30af3cfb727e5abaf811f930e0930b97 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 7 Jul 2023 12:54:51 -0700 Subject: [PATCH 2/2] fix OBO test: check whether x5c is set --- sdk/azidentity/on_behalf_of_credential_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/azidentity/on_behalf_of_credential_test.go b/sdk/azidentity/on_behalf_of_credential_test.go index b1f0a5c4f34b..510ad57c19f0 100644 --- a/sdk/azidentity/on_behalf_of_credential_test.go +++ b/sdk/azidentity/on_behalf_of_credential_test.go @@ -20,6 +20,7 @@ func TestOnBehalfOfCredential(t *testing.T) { realGetClient := getConfidentialClient t.Cleanup(func() { getConfidentialClient = realGetClient }) expectedAssertion := "user-assertion" + certs, key := allCertTests[0].certs, allCertTests[0].key for _, test := range []struct { ctor func(policy.Transporter) (*OnBehalfOfCredential, error) name string @@ -27,7 +28,6 @@ func TestOnBehalfOfCredential(t *testing.T) { }{ { ctor: func(tp policy.Transporter) (*OnBehalfOfCredential, error) { - certs, key := allCertTests[0].certs, allCertTests[0].key o := OnBehalfOfCredentialOptions{ClientOptions: policy.ClientOptions{Transport: tp}} return NewOnBehalfOfCredentialWithCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, &o) }, @@ -35,7 +35,6 @@ func TestOnBehalfOfCredential(t *testing.T) { }, { ctor: func(tp policy.Transporter) (*OnBehalfOfCredential, error) { - certs, key := allCertTests[0].certs, allCertTests[0].key o := OnBehalfOfCredentialOptions{ClientOptions: policy.ClientOptions{Transport: tp}, SendCertificateChain: true} return NewOnBehalfOfCredentialWithCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, &o) }, @@ -68,6 +67,9 @@ func TestOnBehalfOfCredential(t *testing.T) { if assertion := r.FormValue("assertion"); assertion != expectedAssertion { t.Errorf(`unexpected assertion "%s"`, assertion) } + if test.sendX5C { + validateX5C(t, certs)(r) + } }} cred, err := test.ctor(&srv) if err != nil {