Skip to content

Commit 9d202e9

Browse files
authored
Refactor internal MSAL client constructors (#21117)
1 parent a09fcfb commit 9d202e9

9 files changed

+58
-31
lines changed

sdk/azidentity/azidentity.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,46 +50,55 @@ var (
5050
disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true"
5151
)
5252

53-
var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) {
53+
type msalClientOptions struct {
54+
azcore.ClientOptions
55+
56+
DisableInstanceDiscovery bool
57+
// SendX5C applies only to confidential clients authenticating with a cert
58+
SendX5C bool
59+
}
60+
61+
var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, opts msalClientOptions) (confidentialClient, error) {
5462
if !validTenantID(tenantID) {
5563
return confidential.Client{}, errors.New(tenantIDValidationErr)
5664
}
57-
authorityHost, err := setAuthorityHost(co.Cloud)
65+
authorityHost, err := setAuthorityHost(opts.Cloud)
5866
if err != nil {
5967
return confidential.Client{}, err
6068
}
6169
authority := runtime.JoinPaths(authorityHost, tenantID)
6270
o := []confidential.Option{
6371
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
64-
confidential.WithHTTPClient(newPipelineAdapter(co)),
72+
confidential.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)),
6573
}
6674
if !disableCP1 {
6775
o = append(o, confidential.WithClientCapabilities(cp1))
6876
}
69-
o = append(o, additionalOpts...)
70-
if strings.ToLower(tenantID) == "adfs" {
77+
if opts.SendX5C {
78+
o = append(o, confidential.WithX5C())
79+
}
80+
if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" {
7181
o = append(o, confidential.WithInstanceDiscovery(false))
7282
}
7383
return confidential.New(authority, clientID, cred, o...)
7484
}
7585

76-
var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions, additionalOpts ...public.Option) (public.Client, error) {
86+
var getPublicClient = func(clientID, tenantID string, opts msalClientOptions) (public.Client, error) {
7787
if !validTenantID(tenantID) {
7888
return public.Client{}, errors.New(tenantIDValidationErr)
7989
}
80-
authorityHost, err := setAuthorityHost(co.Cloud)
90+
authorityHost, err := setAuthorityHost(opts.Cloud)
8191
if err != nil {
8292
return public.Client{}, err
8393
}
8494
o := []public.Option{
8595
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
86-
public.WithHTTPClient(newPipelineAdapter(co)),
96+
public.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)),
8797
}
8898
if !disableCP1 {
8999
o = append(o, public.WithClientCapabilities(cp1))
90100
}
91-
o = append(o, additionalOpts...)
92-
if strings.ToLower(tenantID) == "adfs" {
101+
if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" {
93102
o = append(o, public.WithInstanceDiscovery(false))
94103
}
95104
return public.New(clientID, o...)

sdk/azidentity/client_assertion_credential.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c
5656
return getAssertion(ctx)
5757
},
5858
)
59-
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
59+
msalOpts := msalClientOptions{
60+
ClientOptions: options.ClientOptions,
61+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
62+
}
63+
c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts)
6064
if err != nil {
6165
return nil, err
6266
}

sdk/azidentity/client_certificate_credential.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
5858
if err != nil {
5959
return nil, err
6060
}
61-
var o []confidential.Option
62-
if options.SendCertificateChain {
63-
o = append(o, confidential.WithX5C())
61+
msalOpts := msalClientOptions{
62+
ClientOptions: options.ClientOptions,
63+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
64+
SendX5C: options.SendCertificateChain,
6465
}
65-
o = append(o, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
66-
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, o...)
66+
c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts)
6767
if err != nil {
6868
return nil, err
6969
}

sdk/azidentity/client_secret_credential.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st
4646
if err != nil {
4747
return nil, err
4848
}
49-
c, err := getConfidentialClient(
50-
clientID, tenantID, cred, &options.ClientOptions, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery),
51-
)
49+
msalOpts := msalClientOptions{
50+
ClientOptions: options.ClientOptions,
51+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
52+
}
53+
c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts)
5254
if err != nil {
5355
return nil, err
5456
}

sdk/azidentity/device_code_credential.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,11 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC
8787
cp = *options
8888
}
8989
cp.init()
90-
c, err := getPublicClient(
91-
cp.ClientID, cp.TenantID, &cp.ClientOptions, public.WithInstanceDiscovery(!cp.DisableInstanceDiscovery),
92-
)
90+
msalOpts := msalClientOptions{
91+
ClientOptions: cp.ClientOptions,
92+
DisableInstanceDiscovery: cp.DisableInstanceDiscovery,
93+
}
94+
c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts)
9395
if err != nil {
9496
return nil, err
9597
}

sdk/azidentity/interactive_browser_credential.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption
6969
cp = *options
7070
}
7171
cp.init()
72-
c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions, public.WithInstanceDiscovery(!cp.DisableInstanceDiscovery))
72+
msalOpts := msalClientOptions{
73+
ClientOptions: cp.ClientOptions,
74+
DisableInstanceDiscovery: cp.DisableInstanceDiscovery,
75+
}
76+
c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts)
7377
if err != nil {
7478
return nil, err
7579
}

sdk/azidentity/on_behalf_of_credential.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ func newOnBehalfOfCredential(tenantID, clientID, userAssertion string, cred conf
7272
if options == nil {
7373
options = &OnBehalfOfCredentialOptions{}
7474
}
75-
opts := []confidential.Option{}
76-
if options.SendCertificateChain {
77-
opts = append(opts, confidential.WithX5C())
75+
msalOpts := msalClientOptions{
76+
ClientOptions: options.ClientOptions,
77+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
78+
SendX5C: options.SendCertificateChain,
7879
}
79-
opts = append(opts, confidential.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
80-
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, opts...)
80+
c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts)
8181
if err != nil {
8282
return nil, err
8383
}

sdk/azidentity/on_behalf_of_credential_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@ func TestOnBehalfOfCredential(t *testing.T) {
2020
realGetClient := getConfidentialClient
2121
t.Cleanup(func() { getConfidentialClient = realGetClient })
2222
expectedAssertion := "user-assertion"
23+
certs, key := allCertTests[0].certs, allCertTests[0].key
2324
for _, test := range []struct {
2425
ctor func(policy.Transporter) (*OnBehalfOfCredential, error)
2526
name string
2627
sendX5C bool
2728
}{
2829
{
2930
ctor: func(tp policy.Transporter) (*OnBehalfOfCredential, error) {
30-
certs, key := allCertTests[0].certs, allCertTests[0].key
3131
o := OnBehalfOfCredentialOptions{ClientOptions: policy.ClientOptions{Transport: tp}}
3232
return NewOnBehalfOfCredentialWithCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, &o)
3333
},
3434
name: "certificate",
3535
},
3636
{
3737
ctor: func(tp policy.Transporter) (*OnBehalfOfCredential, error) {
38-
certs, key := allCertTests[0].certs, allCertTests[0].key
3938
o := OnBehalfOfCredentialOptions{ClientOptions: policy.ClientOptions{Transport: tp}, SendCertificateChain: true}
4039
return NewOnBehalfOfCredentialWithCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, &o)
4140
},
@@ -68,6 +67,9 @@ func TestOnBehalfOfCredential(t *testing.T) {
6867
if assertion := r.FormValue("assertion"); assertion != expectedAssertion {
6968
t.Errorf(`unexpected assertion "%s"`, assertion)
7069
}
70+
if test.sendX5C {
71+
validateX5C(t, certs)(r)
72+
}
7173
}}
7274
cred, err := test.ctor(&srv)
7375
if err != nil {

sdk/azidentity/username_password_credential.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ func NewUsernamePasswordCredential(tenantID string, clientID string, username st
4848
if options == nil {
4949
options = &UsernamePasswordCredentialOptions{}
5050
}
51-
c, err := getPublicClient(clientID, tenantID, &options.ClientOptions, public.WithInstanceDiscovery(!options.DisableInstanceDiscovery))
51+
msalOpts := msalClientOptions{
52+
ClientOptions: options.ClientOptions,
53+
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
54+
}
55+
c, err := getPublicClient(clientID, tenantID, msalOpts)
5256
if err != nil {
5357
return nil, err
5458
}

0 commit comments

Comments
 (0)