Skip to content

Added support for Token revocation support #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
26 changes: 26 additions & 0 deletions apps/managedidentity/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package managedidentity_test

import (
mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
)

func ExampleNew() {
// System assigned Client
miSystemassignedClient, err := mi.New(mi.SystemAssigned())
if err != nil {
// TODO: Handle error
}
_ = miSystemassignedClient

// User assigned Client
clientId := "ClientId" // TODO: replace with your Managed Identity Id

miClientIdAssignedClient, err := mi.New(mi.UserAssignedClientID(clientId), mi.WithClientCapabilities([]string{"cp1"}))
if err != nil {
// TODO: Handle error
}
_ = miClientIdAssignedClient
}
93 changes: 67 additions & 26 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package managedidentity

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -82,7 +84,7 @@ const (
tokenName = "Tokens"

// App Service
appServiceAPIVersion = "2019-08-01"
appServiceAPIVersion = "2025-03-30"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks good and matches the app service new version. We just want to make sure not to merge this PR yet, as App Service rollout is still happening.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work in MSAL .net about guarding the release or the version blocking ?


// AzureML
azureMLAPIVersion = "2017-09-01"
Expand Down Expand Up @@ -178,6 +180,7 @@ type Client struct {
authParams authority.AuthParams
retryPolicyEnabled bool
canRefresh *atomic.Value
clientCapabilities []string
}

type AcquireTokenOptions struct {
Expand All @@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
func WithClaims(claims string) AcquireTokenOption {
return func(o *AcquireTokenOptions) {
o.claims = claims
if claims != "" {
o.claims = claims
}
}
}

// WithClientCapabilities sets the client capabilities to be used in the request.
// For details see https://learn.microsoft.com/entra/identity/conditional-access/concept-continuous-access-evaluation
// The capabilities are passed as a slice of strings, and empty strings are filtered out.
func WithClientCapabilities(capabilities []string) ClientOption {
return func(o *Client) {
var filteredCapabilities []string
for _, cap := range capabilities {
if cap != "" {
filteredCapabilities = append(filteredCapabilities, cap)
}
}
o.clientCapabilities = filteredCapabilities
}
}

// WithHTTPClient allows for a custom HTTP client to be set.
// if nil, the default HTTP client will be used.
func WithHTTPClient(httpClient ops.HTTPClient) ClientOption {
return func(c *Client) {
c.httpClient = httpClient
if httpClient != nil {
c.httpClient = httpClient
}
}
}

Expand Down Expand Up @@ -323,28 +346,29 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
}
c.authParams.Scopes = []string{resource}

// ignore cached access tokens when given claims
if o.claims == "" {
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if o.claims != "" {
// When the claims are set, we need to pass on revoked token to MSIv1 (AppService, ServiceFabric)
return c.getToken(ctx, resource, ar.AccessToken)
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if tr, er := c.getToken(ctx, resource); er == nil {
return tr, nil
}
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if tr, er := c.getToken(ctx, resource, ""); er == nil {
return tr, nil
}
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
return c.getToken(ctx, resource)
return c.getToken(ctx, resource, "")
}

func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) {
func (c Client) getToken(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
switch c.source {
case AzureArc:
return c.acquireTokenForAzureArc(ctx, resource)
Expand All @@ -355,16 +379,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
case DefaultToIMDS:
return c.acquireTokenForIMDS(ctx, resource)
case AppService:
return c.acquireTokenForAppService(ctx, resource)
return c.acquireTokenForAppService(ctx, resource, revokedToken)
case ServiceFabric:
return c.acquireTokenForServiceFabric(ctx, resource)
return c.acquireTokenForServiceFabric(ctx, resource, revokedToken)
default:
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
}
}

func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, revokedToken, c.clientCapabilities)
if err != nil {
return AuthResult{}, err
}
Expand Down Expand Up @@ -411,8 +435,8 @@ func (c Client) acquireTokenForAzureML(ctx context.Context, resource string) (Au
return authResultFromToken(c.authParams, tokenResponse)
}

func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string) (AuthResult, error) {
req, err := createServiceFabricAuthRequest(ctx, resource)
func (c Client) acquireTokenForServiceFabric(ctx context.Context, resource string, revokedToken string) (AuthResult, error) {
req, err := createServiceFabricAuthRequest(ctx, resource, revokedToken, c.clientCapabilities)
if err != nil {
return AuthResult{}, err
}
Expand Down Expand Up @@ -569,16 +593,26 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
return r, err
}

func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, revokedToken string, cc []string) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))

q := req.URL.Query()
q.Set("api-version", appServiceAPIVersion)
q.Set("resource", resource)

if revokedToken != "" {
q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken))
}

if len(cc) > 0 {
q.Set("xms_cc", strings.Join(cc, ","))
}

switch t := id.(type) {
case UserAssignedClientID:
q.Set(miQueryParameterClientId, string(t))
Expand All @@ -594,6 +628,13 @@ func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*
return req, nil
}

func convertTokenToSHA256HashString(revokedToken string) string {
hash := sha256.New()
hash.Write([]byte(revokedToken))
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}

func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
msiEndpoint, err := url.Parse(imdsDefaultEndpoint)
if err != nil {
Expand Down
102 changes: 101 additions & 1 deletion apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package managedidentity
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -964,7 +966,7 @@ func TestAzureArcErrors(t *testing.T) {
},
{
name: "Invalid file path",
headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey),
headerValue: basicRealm + filepath.Join("path", "to", secretKey),
expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"),
},
{
Expand Down Expand Up @@ -1209,3 +1211,101 @@ func TestRefreshInMultipleRequests(t *testing.T) {
}
close(ch)
}

// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed
// and a bad access token is retrieved from the cache
func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) {
setEnvVars(t, AppService)
localUrl := &url.URL{}
mockClient := mock.NewClient()
// Second response is a successful token response after retrying with claims
responseBody, err := getSuccessfulResponse(resource, false)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
)
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
localUrl = r.URL
}))
// Reset cache for clean test
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(SystemAssigned(),
WithHTTPClient(mockClient),
WithClientCapabilities([]string{"c1", "c2"}))
if err != nil {
t.Fatal(err)
}

// Call AcquireToken which should trigger token revocation flow
result, err := client.AcquireToken(context.Background(), resource)
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}

// Call AcquireToken which should trigger token revocation flow
result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims"))
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

localUrlQuerry := localUrl.Query()

if localUrlQuerry.Get(apiVersionQueryParameterName) != appServiceAPIVersion {
t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName))
}
if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") {
t.Fatal("suffix /.default was not removed.")
}
if localUrlQuerry.Get("xms_cc") != "c1,c2" {
t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc"))
}
hash := sha256.Sum256([]byte(token))
if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) {
t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh"))
}
// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}
}

func TestConvertTokenToSHA256HashString(t *testing.T) {
tests := []struct {
token string
expectedHash string
}{
{
token: "test_token",
expectedHash: "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656",
},
{
token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~",
expectedHash: "01588d5a948b6c4facd47866877491b42866b5c10a4d342cf168e994101d352a",
},
{
token: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~",
expectedHash: "29c538690068a8ad1797a391bfe23e7fb817b601fc7b78288cb499ab8fd37947",
},
}

for _, test := range tests {
hash := convertTokenToSHA256HashString(test.token)
if hash != test.expectedHash {
t.Fatalf("for token %q, expected %q, got %q", test.token, test.expectedHash, hash)
}
}
}
11 changes: 10 additions & 1 deletion apps/managedidentity/servicefabric.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
"context"
"net/http"
"os"
"strings"
)

func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http.Request, error) {
func createServiceFabricAuthRequest(ctx context.Context, resource string, revokedToken string, cc []string) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
Expand All @@ -20,6 +21,14 @@ func createServiceFabricAuthRequest(ctx context.Context, resource string) (*http
q := req.URL.Query()
q.Set("api-version", serviceFabricAPIVersion)
q.Set("resource", resource)
if revokedToken != "" {
q.Set("token_sha256_to_refresh", convertTokenToSHA256HashString(revokedToken))
}

if len(cc) > 0 {
q.Set("xms_cc", strings.Join(cc, ","))
}

req.URL.RawQuery = q.Encode()
return req, nil
}
Loading