diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 12deaf84ad46..9be3a774e6cc 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -35,10 +35,8 @@ const ( ) var ( - accessTokenRespSuccess = []byte(fmt.Sprintf(`{"access_token": "%s", "expires_in": %d}`, tokenValue, tokenExpiresIn)) - instanceDiscoveryResponse = getInstanceDiscoveryResponse(fakeTenantID) - tenantDiscoveryResponse = getTenantDiscoveryResponse(fakeTenantID) - testTRO = policy.TokenRequestOptions{Scopes: []string{liveTestScope}} + accessTokenRespSuccess = []byte(fmt.Sprintf(`{"access_token": "%s", "expires_in": %d}`, tokenValue, tokenExpiresIn)) + testTRO = policy.TokenRequestOptions{Scopes: []string{liveTestScope}} ) // constants for this file @@ -46,97 +44,8 @@ const ( testHost = "https://localhost" ) -func getInstanceDiscoveryResponse(tenant string) []byte { - return []byte(strings.ReplaceAll(`{ - "tenant_discovery_endpoint": "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration", - "api-version": "1.1", - "metadata": [ - { - "preferred_network": "login.microsoftonline.com", - "preferred_cache": "login.windows.net", - "aliases": [ - "login.microsoftonline.com", - "login.windows.net", - "login.microsoft.com", - "sts.windows.net" - ] - } - ] - }`, "{tenant}", tenant)) -} - -func getTenantDiscoveryResponse(tenant string) []byte { - return []byte(strings.ReplaceAll(`{ - "token_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token", - "token_endpoint_auth_methods_supported": [ - "client_secret_post", - "private_key_jwt", - "client_secret_basic" - ], - "jwks_uri": "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys", - "response_modes_supported": [ - "query", - "fragment", - "form_post" - ], - "subject_types_supported": [ - "pairwise" - ], - "id_token_signing_alg_values_supported": [ - "RS256" - ], - "response_types_supported": [ - "code", - "id_token", - "code id_token", - "id_token token" - ], - "scopes_supported": [ - "openid", - "profile", - "email", - "offline_access" - ], - "issuer": "https://login.microsoftonline.com/{tenant}/v2.0", - "request_uri_parameter_supported": false, - "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", - "authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize", - "device_authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/devicecode", - "http_logout_supported": true, - "frontchannel_logout_supported": true, - "end_session_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/logout", - "claims_supported": [ - "sub", - "iss", - "cloud_instance_name", - "cloud_instance_host_name", - "cloud_graph_host_name", - "msgraph_host", - "aud", - "exp", - "iat", - "auth_time", - "acr", - "nonce", - "preferred_username", - "name", - "tid", - "ver", - "at_hash", - "c_hash", - "email" - ], - "kerberos_endpoint": "https://login.microsoftonline.com/{tenant}/kerberos", - "tenant_region_scope": "NA", - "cloud_instance_name": "microsoftonline.com", - "cloud_graph_host_name": "graph.windows.net", - "msgraph_host": "graph.microsoft.com", - "rbac_url": "https://pas.windows.net" - }`, "{tenant}", tenant)) -} - -func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate { - return func(req *http.Request) bool { +func validateX5C(t *testing.T, certs []*x509.Certificate) func(*http.Request) *http.Response { + return func(req *http.Request) *http.Response { err := req.ParseForm() if err != nil { t.Fatal("expected a form body") @@ -157,7 +66,7 @@ func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate } else if actual := len(v); actual != len(certs) { t.Fatalf("expected %d certs, got %d", len(certs), actual) } - return true + return nil } } diff --git a/sdk/azidentity/client_assertion_credential_test.go b/sdk/azidentity/client_assertion_credential_test.go index 577c4e919595..497f48eb2c7f 100644 --- a/sdk/azidentity/client_assertion_credential_test.go +++ b/sdk/azidentity/client_assertion_credential_test.go @@ -14,16 +14,9 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) func TestClientAssertionCredential(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(accessTokenRespSuccess)) - key := struct{}{} calls := 0 getAssertion := func(c context.Context) (string, error) { @@ -34,7 +27,7 @@ func TestClientAssertionCredential(t *testing.T) { return "assertion", nil } cred, err := NewClientAssertionCredential("tenant", "clientID", getAssertion, &ClientAssertionCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, + ClientOptions: azcore.ClientOptions{Transport: &mockSTS{}}, }) if err != nil { t.Fatal(err) @@ -58,16 +51,10 @@ func TestClientAssertionCredential(t *testing.T) { } func TestClientAssertionCredentialCallbackError(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(accessTokenRespSuccess)) - expectedError := errors.New("it didn't work") getAssertion := func(c context.Context) (string, error) { return "", expectedError } cred, err := NewClientAssertionCredential("tenant", "clientID", getAssertion, &ClientAssertionCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, + ClientOptions: azcore.ClientOptions{Transport: &mockSTS{}}, }) if err != nil { t.Fatal(err) @@ -97,7 +84,7 @@ func TestClientAssertionCredential_Live(t *testing.T) { defer stop() cred, err := NewClientAssertionCredential(liveSP.tenantID, liveSP.clientID, func(context.Context) (string, error) { - return getAssertion(certs[0], key) + return assertion(certs[0], key) }, &ClientAssertionCredentialOptions{ClientOptions: o, DisableInstanceDiscovery: d}, ) diff --git a/sdk/azidentity/client_certificate_credential_test.go b/sdk/azidentity/client_certificate_credential_test.go index d3859057aaab..7372ba2fbffc 100644 --- a/sdk/azidentity/client_certificate_credential_test.go +++ b/sdk/azidentity/client_certificate_credential_test.go @@ -17,7 +17,6 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" ) @@ -111,14 +110,7 @@ func TestClientCertificateCredential_GetTokenSuccess_withCertificateChain(t *tes func TestClientCertificateCredential_SendCertificateChain(t *testing.T) { for _, test := range allCertTests { t.Run(test.name, func(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithPredicate(validateX5C(t, test.certs)), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - - options := ClientCertificateCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}, SendCertificateChain: true} + options := ClientCertificateCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: &mockSTS{}}, SendCertificateChain: true} cred, err := NewClientCertificateCredential(fakeTenantID, fakeClientID, test.certs, test.key, &options) if err != nil { t.Fatal(err) @@ -165,14 +157,8 @@ func TestClientCertificateCredential_NoCertificate(t *testing.T) { func TestClientCertificateCredential_NoPrivateKey(t *testing.T) { test := allCertTests[0] - srv, close := mock.NewTLSServer() - defer close() - srv.AppendResponse(mock.WithBody(accessTokenRespSuccess)) - options := ClientCertificateCredentialOptions{} - options.Cloud.ActiveDirectoryAuthorityHost = srv.URL() - options.Transport = srv var key crypto.PrivateKey - _, err := NewClientCertificateCredential(fakeTenantID, fakeClientID, test.certs, key, &options) + _, err := NewClientCertificateCredential(fakeTenantID, fakeClientID, test.certs, key, nil) if err == nil { t.Fatalf("Expected an error but received nil") } diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 660269bd3b39..1945ea82a68a 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -212,7 +212,7 @@ func TestDefaultAzureCredential_Workload(t *testing.T) { if err := os.WriteFile(tempFile, []byte(expectedAssertion), os.ModePerm); err != nil { t.Fatalf(`failed to write temporary file "%s": %v`, tempFile, err) } - pred := func(req *http.Request) bool { + sts := mockSTS{tokenRequestCallback: func(req *http.Request) *http.Response { if err := req.ParseForm(); err != nil { t.Fatal(err) } @@ -225,14 +225,8 @@ func TestDefaultAzureCredential_Workload(t *testing.T) { if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID { t.Fatalf(`unexpected tenant "%s"`, actual) } - return true - } - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithPredicate(pred), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() + return nil + }} for k, v := range map[string]string{ azureAuthorityHost: cloud.AzurePublic.ActiveDirectoryAuthorityHost, azureClientID: fakeClientID, @@ -241,7 +235,7 @@ func TestDefaultAzureCredential_Workload(t *testing.T) { } { t.Setenv(k, v) } - cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: srv}}) + cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: &sts}}) if err != nil { t.Fatal(err) } @@ -266,17 +260,13 @@ func (p *delayPolicy) Do(req *policy.Request) (resp *http.Response, err error) { } func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.SetResponse(mock.WithBody(accessTokenRespSuccess)) - timeout := 100 * time.Millisecond dp := delayPolicy{2 * timeout} mic, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ ClientOptions: policy.ClientOptions{ PerCallPolicies: []policy.Policy{&dp}, Retry: policy.RetryOptions{MaxRetries: -1}, - Transport: srv, + Transport: &mockSTS{}, }, }) if err != nil { diff --git a/sdk/azidentity/device_code_credential_test.go b/sdk/azidentity/device_code_credential_test.go index a587543e01ea..d680fa35f049 100644 --- a/sdk/azidentity/device_code_credential_test.go +++ b/sdk/azidentity/device_code_credential_test.go @@ -87,7 +87,7 @@ func TestDeviceCodeCredential_UserPromptError(t *testing.T) { func TestDeviceCodeCredential_Live(t *testing.T) { if recording.GetRecordMode() != recording.PlaybackMode && !runManualTests { - t.Skip("set AZIDENTITY_RUN_MANUAL_TESTS to run this test") + t.Skipf("set %s to run this test", azidentityRunManualTests) } for _, test := range []struct { clientID, desc, tenantID string @@ -123,7 +123,7 @@ func TestDeviceCodeCredential_Live(t *testing.T) { func TestDeviceCodeCredentialADFS_Live(t *testing.T) { if recording.GetRecordMode() != recording.PlaybackMode && !runManualTests { - t.Skip("set AZIDENTITY_RUN_MANUAL_TESTS to run this test") + t.Skipf("set %s to run this test", azidentityRunManualTests) } if adfsLiveSP.clientID == "" { t.Skip("set ADFS_SP_* environment variables to run this test") diff --git a/sdk/azidentity/environment_credential_test.go b/sdk/azidentity/environment_credential_test.go index 2ef51055d6a5..3d4b7d2d1e5b 100644 --- a/sdk/azidentity/environment_credential_test.go +++ b/sdk/azidentity/environment_credential_test.go @@ -15,7 +15,6 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" ) @@ -204,13 +203,7 @@ func TestEnvironmentCredential_SendCertificateChain(t *testing.T) { t.Fatal(err) } resetEnvironmentVarsForTest() - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse)) - srv.AppendResponse(mock.WithBody(tenantDiscoveryResponse)) - srv.AppendResponse(mock.WithPredicate(validateX5C(t, certs)), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - + sts := mockSTS{tokenRequestCallback: validateX5C(t, certs)} vars := map[string]string{ azureClientID: liveSP.clientID, azureClientCertificatePath: liveSP.pfxPath, @@ -218,7 +211,7 @@ func TestEnvironmentCredential_SendCertificateChain(t *testing.T) { envVarSendCertChain: "true", } setEnvironmentVariables(t, vars) - cred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}}) + cred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: &sts}}) if err != nil { t.Fatal(err) } diff --git a/sdk/azidentity/interactive_browser_credential_test.go b/sdk/azidentity/interactive_browser_credential_test.go index 178dc2c5c8f6..a47323a06cec 100644 --- a/sdk/azidentity/interactive_browser_credential_test.go +++ b/sdk/azidentity/interactive_browser_credential_test.go @@ -8,6 +8,7 @@ package azidentity import ( "context" + "fmt" "net/http" "strings" "testing" @@ -77,7 +78,7 @@ func (p *instanceDiscoveryPolicy) Do(req *policy.Request) (resp *http.Response, func TestInteractiveBrowserCredential_Live(t *testing.T) { if !runManualTests { - t.Skip("set AZIDENTITY_RUN_MANUAL_TESTS to run this test") + t.Skipf("set %s to run this test", azidentityRunManualTests) } t.Run("defaults", func(t *testing.T) { cred, err := NewInteractiveBrowserCredential(nil) @@ -88,7 +89,7 @@ func TestInteractiveBrowserCredential_Live(t *testing.T) { }) t.Run("LoginHint", func(t *testing.T) { upn := "test@pass" - t.Logf("consider this test passing when %q appears in the login prompt", upn) + fmt.Printf("\t%s: consider this test passing when %q appears in the login prompt", t.Name(), upn) cred, err := NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{LoginHint: upn}) if err != nil { t.Fatal(err) @@ -97,7 +98,7 @@ func TestInteractiveBrowserCredential_Live(t *testing.T) { }) t.Run("RedirectURL", func(t *testing.T) { url := "http://localhost:8180" - t.Logf("consider this test passing when AAD redirects to %s", url) + fmt.Printf("\t%s: consider this test passing when AAD redirects to %s", t.Name(), url) cred, err := NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{RedirectURL: url}) if err != nil { t.Fatal(err) @@ -122,7 +123,7 @@ func TestInteractiveBrowserCredential_Live(t *testing.T) { func TestInteractiveBrowserCredentialADFS_Live(t *testing.T) { if !runManualTests { - t.Skip("set AZIDENTITY_RUN_MANUAL_TESTS to run this test") + t.Skipf("set %s to run this test", azidentityRunManualTests) } if adfsLiveUser.clientID == fakeClientID { t.Skip("set ADFS_IDENTITY_TEST_CLIENT_ID environment variables to run this test live") diff --git a/sdk/azidentity/live_test.go b/sdk/azidentity/live_test.go index 36266579e727..4569b0b73e3d 100644 --- a/sdk/azidentity/live_test.go +++ b/sdk/azidentity/live_test.go @@ -57,12 +57,14 @@ var liveUser = struct { } const ( - fakeClientID = "fake-client-id" - fakeResourceID = "/fake/resource/ID" - fakeTenantID = "fake-tenant" - fakeUsername = "fake@user" - fakeAdfsAuthority = "fake.adfs.local" - fakeAdfsScope = "fake.adfs.local/fake-scope/.default" + azidentityRunManualTests = "AZIDENTITY_RUN_MANUAL_TESTS" + fakeClientID = "fake-client-id" + fakeResourceID = "/fake/resource/ID" + fakeTenantID = "fake-tenant" + fakeUsername = "fake@user" + fakeAdfsAuthority = "fake.adfs.local" + fakeAdfsScope = "fake.adfs.local/fake-scope/.default" + liveTestScope = "https://management.core.windows.net//.default" ) var adfsLiveSP = struct { @@ -90,8 +92,7 @@ var adfsLiveUser = struct { var ( adfsAuthority = os.Getenv("ADFS_AUTHORITY_HOST") adfsScope = os.Getenv("ADFS_SCOPE") - liveTestScope = "https://management.core.windows.net//.default" - _, runManualTests = os.LookupEnv("AZIDENTITY_RUN_MANUAL_TESTS") + _, runManualTests = os.LookupEnv(azidentityRunManualTests) ) func setFakeValues() { diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index 87a4d91640cd..aa604b266269 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -15,7 +15,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) type userAgentValidatingPolicy struct { @@ -42,13 +41,9 @@ func TestIMDSEndpointParse(t *testing.T) { } func TestManagedIdentityClient_UserAgent(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.AppendResponse(mock.WithBody(accessTokenRespSuccess)) - setEnvironmentVariables(t, map[string]string{msiEndpoint: srv.URL()}) options := ManagedIdentityCredentialOptions{ ClientOptions: azcore.ClientOptions{ - Transport: srv, PerCallPolicies: []policy.Policy{userAgentValidatingPolicy{t: t}}, + Transport: &mockSTS{}, PerCallPolicies: []policy.Policy{userAgentValidatingPolicy{t: t}}, }, } client, err := newManagedIdentityClient(&options) @@ -59,20 +54,13 @@ func TestManagedIdentityClient_UserAgent(t *testing.T) { if err != nil { t.Fatal(err) } - if count := srv.Requests(); count != 1 { - t.Fatalf("expected 1 token request, got %d", count) - } } func TestManagedIdentityClient_ApplicationID(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.AppendResponse(mock.WithBody(accessTokenRespSuccess)) - setEnvironmentVariables(t, map[string]string{msiEndpoint: srv.URL()}) appID := "customvalue" options := ManagedIdentityCredentialOptions{ ClientOptions: azcore.ClientOptions{ - Transport: srv, PerCallPolicies: []policy.Policy{userAgentValidatingPolicy{t: t, appID: appID}}, + Transport: &mockSTS{}, PerCallPolicies: []policy.Policy{userAgentValidatingPolicy{t: t, appID: appID}}, }, } options.Telemetry.ApplicationID = appID @@ -84,7 +72,4 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) { if err != nil { t.Fatal(err) } - if count := srv.Requests(); count != 1 { - t.Fatalf("expected 1 token request, got %d", count) - } } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 5b25f247e7d0..64c6013611a3 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -86,7 +86,7 @@ func TestManagedIdentityCredential_AzureArc(t *testing.T) { } func TestManagedIdentityCredential_CloudShell(t *testing.T) { - validateReq := func(req *http.Request) bool { + validateReq := func(req *http.Request) *http.Response { err := req.ParseForm() if err != nil { t.Fatal(err) @@ -97,15 +97,10 @@ func TestManagedIdentityCredential_CloudShell(t *testing.T) { if h := req.Header.Get("metadata"); h != "true" { t.Fatalf("unexpected metadata header: %s", h) } - return true + return nil } - srv, close := mock.NewServer() - defer close() - srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody(accessTokenRespSuccess)) - srv.AppendResponse() - setEnvironmentVariables(t, map[string]string{msiEndpoint: srv.URL()}) options := ManagedIdentityCredentialOptions{} - options.Transport = srv + options.Transport = &mockSTS{tokenRequestCallback: validateReq} msiCred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatal(err) @@ -301,11 +296,7 @@ func TestManagedIdentityCredential_GetTokenScopes(t *testing.T) { } func TestManagedIdentityCredential_ScopesImmutable(t *testing.T) { - srv, close := mock.NewServer() - defer close() - srv.AppendResponse(mock.WithBody([]byte(expiresOnIntResp))) - setEnvironmentVariables(t, map[string]string{msiEndpoint: srv.URL()}) - options := ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}} + options := ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: &mockSTS{}}} cred, err := NewManagedIdentityCredential(&options) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/sdk/azidentity/mock_test.go b/sdk/azidentity/mock_test.go index 76c14742b6dd..c4e1fee5fde4 100644 --- a/sdk/azidentity/mock_test.go +++ b/sdk/azidentity/mock_test.go @@ -35,9 +35,9 @@ func (m *mockSTS) Do(req *http.Request) (*http.Response, error) { } switch s := strings.Split(req.URL.Path, "/"); s[len(s)-1] { case "instance": - res.Body = io.NopCloser(bytes.NewReader(getInstanceDiscoveryResponse(tenant))) + res.Body = io.NopCloser(bytes.NewReader(instanceMetadata(tenant))) case "openid-configuration": - res.Body = io.NopCloser(bytes.NewReader(getTenantDiscoveryResponse(tenant))) + res.Body = io.NopCloser(bytes.NewReader(tenantMetadata(tenant))) case "devicecode": res.Body = io.NopCloser(strings.NewReader(`{"device_code":"...","expires_in":600,"interval":60}`)) case "token": @@ -60,3 +60,92 @@ func (m *mockSTS) Do(req *http.Request) (*http.Response, error) { } return res, nil } + +func instanceMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "tenant_discovery_endpoint": "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration", + "api-version": "1.1", + "metadata": [ + { + "preferred_network": "login.microsoftonline.com", + "preferred_cache": "login.windows.net", + "aliases": [ + "login.microsoftonline.com", + "login.windows.net", + "login.microsoft.com", + "sts.windows.net" + ] + } + ] + }`, "{tenant}", tenant)) +} + +func tenantMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "token_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "private_key_jwt", + "client_secret_basic" + ], + "jwks_uri": "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys", + "response_modes_supported": [ + "query", + "fragment", + "form_post" + ], + "subject_types_supported": [ + "pairwise" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "response_types_supported": [ + "code", + "id_token", + "code id_token", + "id_token token" + ], + "scopes_supported": [ + "openid", + "profile", + "email", + "offline_access" + ], + "issuer": "https://login.microsoftonline.com/{tenant}/v2.0", + "request_uri_parameter_supported": false, + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + "authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize", + "device_authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/devicecode", + "http_logout_supported": true, + "frontchannel_logout_supported": true, + "end_session_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/logout", + "claims_supported": [ + "sub", + "iss", + "cloud_instance_name", + "cloud_instance_host_name", + "cloud_graph_host_name", + "msgraph_host", + "aud", + "exp", + "iat", + "auth_time", + "acr", + "nonce", + "preferred_username", + "name", + "tid", + "ver", + "at_hash", + "c_hash", + "email" + ], + "kerberos_endpoint": "https://login.microsoftonline.com/{tenant}/kerberos", + "tenant_region_scope": "NA", + "cloud_instance_name": "microsoftonline.com", + "cloud_graph_host_name": "graph.windows.net", + "msgraph_host": "graph.microsoft.com", + "rbac_url": "https://pas.windows.net" + }`, "{tenant}", tenant)) +} diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index ac2af7befe47..20ba3fde5b62 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -27,7 +27,7 @@ import ( "github.com/google/uuid" ) -func getAssertion(cert *x509.Certificate, key crypto.PrivateKey) (string, error) { +func assertion(cert *x509.Certificate, key crypto.PrivateKey) (string, error) { j := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", liveSP.tenantID), "exp": json.Number(strconv.FormatInt(time.Now().Add(10*time.Minute).Unix(), 10)), @@ -54,7 +54,7 @@ func TestWorkloadIdentityCredential_Live(t *testing.T) { if err != nil { t.Fatal(err) } - a, err := getAssertion(certs[0], key) + a, err := assertion(certs[0], key) if err != nil { t.Fatal(err) }