Skip to content

Added support for Token revocation support #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
72 changes: 54 additions & 18 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package managedidentity

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -82,7 +84,7 @@ const (
tokenName = "Tokens"

// App Service
appServiceAPIVersion = "2019-08-01"
appServiceAPIVersion = "2025-03-30"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks good and matches the app service new version. We just want to make sure not to merge this PR yet, as App Service rollout is still happening.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work in MSAL .net about guarding the release or the version blocking ?


// AzureML
azureMLAPIVersion = "2017-09-01"
Expand Down Expand Up @@ -178,6 +180,7 @@ type Client struct {
authParams authority.AuthParams
retryPolicyEnabled bool
canRefresh *atomic.Value
clientCapabilities []string
}

type AcquireTokenOptions struct {
Expand All @@ -192,14 +195,34 @@ type AcquireTokenOption func(o *AcquireTokenOptions)
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
func WithClaims(claims string) AcquireTokenOption {
return func(o *AcquireTokenOptions) {
o.claims = claims
if claims != "" {
o.claims = claims
}
}
}

// WithClientCapabilities sets the client capabilities to be used in the request.
// This is used to enable specific features or behaviors in the token request.
// The capabilities are passed as a slice of strings, and empty strings are filtered out.
func WithClientCapabilities(capabilities []string) ClientOption {
return func(o *Client) {
var filteredCapabilities []string
for _, cap := range capabilities {
if cap != "" {
filteredCapabilities = append(filteredCapabilities, cap)
}
}
o.clientCapabilities = filteredCapabilities
}
}

// WithHTTPClient allows for a custom HTTP client to be set.
// if nil, the default HTTP client will be used.
func WithHTTPClient(httpClient ops.HTTPClient) ClientOption {
return func(c *Client) {
c.httpClient = httpClient
if httpClient != nil {
c.httpClient = httpClient
}
}
}

Expand Down Expand Up @@ -323,28 +346,30 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
}
c.authParams.Scopes = []string{resource}

// ignore cached access tokens when given claims
if o.claims == "" {
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
stResp, err := cacheManager.Read(ctx, c.authParams)
if err != nil {
return AuthResult{}, err
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if o.claims != "" {
// When the claims are set, we need to passon bad/old token
return c.getToken(ctx, resource, ar.AccessToken)
} else {
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if tr, er := c.getToken(ctx, resource); er == nil {
if tr, er := c.getToken(ctx, resource, o.claims); er == nil {
return tr, nil
}
}
ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
}
return c.getToken(ctx, resource)
return c.getToken(ctx, resource, "")
}

func (c Client) getToken(ctx context.Context, resource string) (AuthResult, error) {
func (c Client) getToken(ctx context.Context, resource string, badToken string) (AuthResult, error) {
switch c.source {
case AzureArc:
return c.acquireTokenForAzureArc(ctx, resource)
Expand All @@ -355,16 +380,16 @@ func (c Client) getToken(ctx context.Context, resource string) (AuthResult, erro
case DefaultToIMDS:
return c.acquireTokenForIMDS(ctx, resource)
case AppService:
return c.acquireTokenForAppService(ctx, resource)
return c.acquireTokenForAppService(ctx, resource, badToken)
case ServiceFabric:
return c.acquireTokenForServiceFabric(ctx, resource)
default:
return AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
}
}

func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
func (c Client) acquireTokenForAppService(ctx context.Context, resource string, badToken string) (AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, badToken, c.clientCapabilities)
if err != nil {
return AuthResult{}, err
}
Expand Down Expand Up @@ -569,16 +594,27 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
return r, err
}

func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, badToken string, cc []string) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))

q := req.URL.Query()
q.Set("api-version", appServiceAPIVersion)
q.Set("resource", resource)

if badToken != "" {
hash := sha256.Sum256([]byte(badToken))
q.Set("token_sha256_to_refresh", hex.EncodeToString(hash[:]))
}

if len(cc) > 0 {
q.Set("xms_cc", strings.Join(cc, ","))
}

switch t := id.(type) {
case UserAssignedClientID:
q.Set(miQueryParameterClientId, string(t))
Expand Down
73 changes: 73 additions & 0 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package managedidentity
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -1209,3 +1211,74 @@ func TestRefreshInMultipleRequests(t *testing.T) {
}
close(ch)
}

// TestAppServiceWithClaimsAndBadAccessToken tests the scenario where claims are passed
// and a bad access token is retrieved from the cache
func TestAppServiceWithClaimsAndBadAccessToken(t *testing.T) {
setEnvVars(t, AppService)
localUrl := &url.URL{}
mockClient := mock.NewClient()
// Second response is a successful token response after retrying with claims
responseBody, err := getSuccessfulResponse(resource, false)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
)
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
localUrl = r.URL
}))
// Reset cache for clean test
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(SystemAssigned(),
WithHTTPClient(mockClient),
WithClientCapabilities([]string{"c1", "c2"}))
if err != nil {
t.Fatal(err)
}

// Call AcquireToken which should trigger token revocation flow
result, err := client.AcquireToken(context.Background(), resource)
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}

// Call AcquireToken which should trigger token revocation flow
result, err = client.AcquireToken(context.Background(), resource, WithClaims("dummyClaims"))
if err != nil {
t.Fatalf("AcquireToken failed: %v", err)
}

localUrlQuerry := localUrl.Query()

if localUrlQuerry.Get(apiVersionQueryParameterName) != appServiceAPIVersion {
t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, localUrlQuerry.Get(apiVersionQueryParameterName))
}
if r := localUrlQuerry.Get(resourceQueryParameterName); strings.HasSuffix(r, "/.default") {
t.Fatal("suffix /.default was not removed.")
}
if localUrlQuerry.Get("xms_cc") != "c1,c2" {
t.Fatalf("Expected client capabilities %q, got %q", "c1,c2", localUrlQuerry.Get("xms_cc"))
}
hash := sha256.Sum256([]byte(token))
if localUrlQuerry.Get("token_sha256_to_refresh") != hex.EncodeToString(hash[:]) {
t.Fatalf("Expected token_sha256_to_refresh %q, got %q", hex.EncodeToString(hash[:]), localUrlQuerry.Get("token_sha256_to_refresh"))
}
// Verify token was obtained successfully
if result.AccessToken != token {
t.Fatalf("Expected access token %q, got %q", token, result.AccessToken)
}
}