Skip to content

WorkloadIdentityCredential enhancement for AKS FIC limit #24442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Bugs Fixed

### Other Changes
* enhanced support for AKS workloads

## 1.9.0 (2025-04-08)

Expand Down
134 changes: 133 additions & 1 deletion sdk/azidentity/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,31 @@
package azidentity

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
)

const credNameWorkloadIdentity = "WorkloadIdentityCredential"
const (
aksCAData = "AZURE_KUBERNETES_CA_DATA"
aksCAFile = "AZURE_KUBERNETES_CA_FILE"
aksSNIName = "AZURE_KUBERNETES_SNI_NAME"
aksTokenEndpoint = "AZURE_KUBERNETES_TOKEN_ENDPOINT"
Comment on lines +29 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the SDK changes are generic and are meant to be usable on any Kubernetes cluster running on Azure, e.g. AKS or CAPZ.

Suggested change
aksCAData = "AZURE_KUBERNETES_CA_DATA"
aksCAFile = "AZURE_KUBERNETES_CA_FILE"
aksSNIName = "AZURE_KUBERNETES_SNI_NAME"
aksTokenEndpoint = "AZURE_KUBERNETES_TOKEN_ENDPOINT"
azureCAData = "AZURE_KUBERNETES_CA_DATA"
azureCAFile = "AZURE_KUBERNETES_CA_FILE"
azureSNIName = "AZURE_KUBERNETES_SNI_NAME"
azureTokenEndpoint = "AZURE_KUBERNETES_TOKEN_ENDPOINT"

credNameWorkloadIdentity = "WorkloadIdentityCredential"
)

// WorkloadIdentityCredential supports Azure workload identity on Kubernetes.
// See [Azure Kubernetes Service documentation] for more information.
Expand Down Expand Up @@ -94,6 +107,13 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
}
if p, err := newAKSTokenRequestPolicy(); err != nil {
return nil, err
} else if p != nil {
// add the policy to the end of the pipeline. It will run
// after all other policies, including any added by the caller
caco.ClientOptions.PerRetryPolicies = append(caco.ClientOptions.PerRetryPolicies, p)
}
cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, &caco)
if err != nil {
return nil, err
Expand Down Expand Up @@ -139,3 +159,115 @@ func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, erro
}
return w.assertion, nil
}

// aksTokenRequestPolicy redirects token requests to the AKS token endpoint, sending them via its
// own HTTP client. It sends all other requests unchanged, via the pipeline's configured transport.
//
// newAKSTokenRequestPolicy populates the fields from the AZURE_KUBERNETES_* environment variables.
type aksTokenRequestPolicy struct {
// c is configured for the AKS token endpoint
// (it trusts only the CA in ca and uses serverName as the TLS server name).
// It is nil until the client method builds it.
c *http.Client
Comment on lines +166 to +167
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the request seen by this client still contain the correct headers such as user agent, etc?

// ca trusted by c
// These are PEM bytes. When they came from AZURE_KUBERNETES_CA_DATA, the client
// method sets ca to nil once c exists; when they come from caFile, ca holds the
// file content the client method last acted on so it can detect changes.
ca []byte
// caFile is the value of AZURE_KUBERNETES_CA_FILE (empty when the CA bytes came
// from AZURE_KUBERNETES_CA_DATA). host is the value of AZURE_KUBERNETES_TOKEN_ENDPOINT,
// which Do assigns to a token request's URL host. serverName is the value of
// AZURE_KUBERNETES_SNI_NAME, used as the TLS ServerName.
caFile, host, serverName string
}

// newAKSTokenRequestPolicy builds a policy from the process environment.
//
// It returns (nil, nil) when AZURE_KUBERNETES_TOKEN_ENDPOINT or AZURE_KUBERNETES_SNI_NAME
// is unset, meaning the feature is disabled and the caller should add no policy.
// It returns an error when neither or both of AZURE_KUBERNETES_CA_DATA and
// AZURE_KUBERNETES_CA_FILE have a value, or when an HTTP client can't be built
// from the configured CA.
func newAKSTokenRequestPolicy() (*aksTokenRequestPolicy, error) {
host := os.Getenv(aksTokenEndpoint)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a host, it is a full URL that can contain a path.

serverName := os.Getenv(aksSNIName)
if host == "" || serverName == "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Server name should not be required. Only host should be needed to use this code path.

// the AKS feature isn't enabled for this process
return nil, nil
}
// exactly one CA source must be configured: inline PEM data or a file path
b := []byte(os.Getenv(aksCAData))
f := os.Getenv(aksCAFile)
switch {
case len(b) == 0 && len(f) == 0:
return nil, fmt.Errorf("no value found for %s or %s", aksCAData, aksCAFile)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are optional, system roots should be used if these are not given.

case len(b) > 0 && len(f) > 0:
return nil, fmt.Errorf("found values for both %s and %s", aksCAData, aksCAFile)
}
p := &aksTokenRequestPolicy{caFile: f, ca: b, host: host, serverName: serverName}
// build the client now so an unreadable or unparsable CA is
// reported at construction time rather than on the first token request
if _, err := p.client(); err != nil {
return nil, err
}
return p, nil
}

// Do sends token requests to the AKS token endpoint and passes every other request on.
//
// A request counts as a token request when its URL path ends with "/token". For such a
// request, Do rewrites the URL host to a.host and sends the request with the policy's own
// HTTP client, so later policies and the pipeline's transport never see it. All other
// requests continue through the pipeline via req.Next().
func (a *aksTokenRequestPolicy) Do(req *policy.Request) (*http.Response, error) {
if r := req.Raw(); strings.HasSuffix(r.URL.Path, "/token") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not believe this assumption is correct.

// get the current client; with a CA file this may re-read the file and rebuild the client
c, err := a.client()
if err != nil {
// wrap the configuration error as non-retriable
return nil, errorinfo.NonRetriableError(err)
}
r.URL.Host = a.host
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct.

// clear any Host override so the Host header is derived from r.URL.Host
r.Host = ""
res, err := c.Do(r)
if err != nil {
return nil, err
}
if res == nil {
// this policy is effectively a transport, so it must handle
// this rare case. Returning an error makes the retry policy
// try the request again
err = errors.New("received nil response")
}
return res, err
}
return req.Next()
}

func (a *aksTokenRequestPolicy) client() (*http.Client, error) {
// this function doesn't need synchronization because
// it's called under confidentialClient's lock

if a.caFile == "" {
// host provided CA bytes in AZURE_KUBERNETES_CA_DATA and can't change
// them now, so we need to create a client only if we haven't done so yet
if a.c == nil {
if len(a.ca) == 0 {
return nil, errors.New("no value found for " + aksCAData)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(a.ca) {
return nil, errors.New("couldn't parse " + aksCAData)
}
a.c = &http.Client{
Transport: &http.Transport{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other transport level defaults that the SDK normally uses?

TLSClientConfig: &tls.Config{
RootCAs: cp,
ServerName: a.serverName,
},
},
}
// this copy of the CA bytes is redundant because we've
// configured the client and won't execute this block again
a.ca = nil
}
return a.c, nil
}

// host provided the CA bytes in a file whose contents it can change,
// so we must read that file and maybe create a new client
b, err := os.ReadFile(a.caFile)
if err != nil {
return nil, fmt.Errorf("couldn't parse %s: %s", aksCAFile, err)
}
if len(b) == 0 {
return nil, errors.New(aksCAFile + " specifies an empty file")
}
if !bytes.Equal(b, a.ca) {
a.ca = b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should only be set after a.c is successfully set.

cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(a.ca) {
return nil, errors.New("couldn't parse " + aksCAFile)
}
a.c = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cp,
ServerName: a.serverName,
},
},
}
}
return a.c, nil
}
178 changes: 178 additions & 0 deletions sdk/azidentity/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@ import (
"context"
"crypto"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
Expand Down Expand Up @@ -216,6 +222,178 @@ func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
}
}

// TestWorkloadIdentityCredential_SNIPolicy verifies the AKS token endpoint feature of
// WorkloadIdentityCredential: with AZURE_KUBERNETES_TOKEN_ENDPOINT and AZURE_KUBERNETES_SNI_NAME
// set, token requests must go to a local TLS test server (never to the pipeline's transport),
// the TLS handshake must carry the configured server name, and the CA must come from exactly
// one of AZURE_KUBERNETES_CA_DATA or AZURE_KUBERNETES_CA_FILE.
func TestWorkloadIdentityCredential_SNIPolicy(t *testing.T) {
// called records whether the test server's GetCertificate callback ran,
// i.e. whether a TLS handshake reached the server
called := false
// expected is the server name the test server requires in the ClientHello
expected := ""
// newServer starts a TLS server that answers every request with a successful token
// response. It returns the server's certificate, PEM encoded, and the server's URL.
newServer := func() ([]byte, *url.URL) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintln(w, string(accessTokenRespSuccess))
}))
t.Cleanup(ts.Close)
ts.TLS = &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
called = true
if expected == "" {
t.Error("test bug: expected server name not set; should match test server's DNS name")
} else if actual := info.ServerName; actual != expected {
t.Errorf("expected %q, got %q", expected, actual)
}
// this callback only inspects the handshake; returning nil lets
// crypto/tls fall back to the server's configured certificate
return nil, nil
},
}
ts.StartTLS()
cert := ts.Certificate()
expected = cert.DNSNames[0]
pemData := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
u, err := url.Parse(ts.URL)
require.NoError(t, err)
return pemData, u
}

pemData, u := newServer()
// each t.TempDir() call returns a new directory, so caFile and f below don't collide
caFile := filepath.Join(t.TempDir(), t.Name())
require.NoError(t, os.WriteFile(caFile, pemData, 0600))

f := filepath.Join(t.TempDir(), t.Name())
require.NoError(t, os.WriteFile(f, []byte(tokenValue), 0600))

// baseline environment: the feature is enabled but no CA source is configured yet
for k, v := range map[string]string{
aksSNIName: expected,
aksTokenEndpoint: u.Host,
azureClientID: fakeClientID,
azureFederatedTokenFile: f,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}

// invalid CA configurations: the constructor's error must name both CA variables
for _, test := range []struct {
name string
vars map[string]string
}{
{"no cert specified", nil},
{"two certs specified", map[string]string{aksCAData: "...", aksCAFile: "..."}},
} {
t.Run(test.name, func(t *testing.T) {
for k, v := range test.vars {
t.Setenv(k, v)
}
_, err := NewWorkloadIdentityCredential(nil)
require.ErrorContains(t, err, aksCAData)
require.ErrorContains(t, err, aksCAFile)
})
}
// the mock transport fails the test if a token request reaches it
// instead of the endpoint specified in the environment
o := WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: &mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
t.Fatal("credential should have sent token request to endpoint specified in " + aksTokenEndpoint)
return nil
},
},
},
}
// happy path, once per CA source; the subtest is named after the variable it sets
for k, v := range map[string]string{
aksCAData: string(pemData),
aksCAFile: caFile,
} {
called = false
t.Run(k, func(t *testing.T) {
t.Setenv(k, v)
cred, err := NewWorkloadIdentityCredential(&o)
require.NoError(t, err)

tk, err := cred.GetToken(ctx, testTRO)
require.NoError(t, err)
require.Equal(t, tokenValue, tk.Token)
require.True(t, called, "test bug: test server's GetCertificate function wasn't called")

// 100 concurrent GetToken calls on a fresh credential; run the tests
// with -race so the detector can flag unsynchronized access
t.Run("race", func(t *testing.T) {
cred, err := NewWorkloadIdentityCredential(&o)
require.NoError(t, err)
wg := sync.WaitGroup{}
// capacity 1 plus the non-blocking send keeps only the first error
ch := make(chan error, 1)
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if _, err := cred.GetToken(ctx, testTRO); err != nil {
select {
case ch <- err:
default:
}
}
}()
}
wg.Wait()
select {
case err := <-ch:
t.Fatal(err)
default:
}
})
})
}

// the subtests below mutate caFile in order: replaced, invalid, truncated, removed
t.Run("file", func(t *testing.T) {
t.Setenv(aksCAFile, caFile)
t.Run("updated", func(t *testing.T) {
p, err := newAKSTokenRequestPolicy()
require.NoError(t, err)
pl := runtime.NewPipeline("", "", runtime.PipelineOptions{}, &policy.ClientOptions{
PerRetryPolicies: []policy.Policy{p},
Transport: &mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
t.Fatal("policy should have sent this request to the AKS endpoint")
return nil
},
},
})

called = false
r, err := runtime.NewRequest(ctx, http.MethodGet, u.String()+"/tenant/token")
require.NoError(t, err)
_, err = pl.Do(r)
require.NoError(t, err)
require.True(t, called, "test bug: test server's GetCertificate function wasn't called")

// need a new server because a started one's TLS cert is immutable. Unfortunately, a new
// server listens on a different port, so we need to update the policy's host. This is
// why this test exercises the policy directly rather than through a credential instance
pemData, u := newServer()
p.host = u.Host
require.NoError(t, os.WriteFile(caFile, pemData, 0600))

// this request succeeds only if the policy picked up the new CA from caFile
called = false
r, err = runtime.NewRequest(ctx, http.MethodGet, u.String()+"/tenant/token")
require.NoError(t, err)
_, err = pl.Do(r)
require.NoError(t, err)
require.True(t, called, "test bug: test server's GetCertificate function wasn't called")
})
t.Run("invalid", func(t *testing.T) {
require.NoError(t, os.WriteFile(caFile, []byte("not a cert"), 0600))
_, err := NewWorkloadIdentityCredential(nil)
require.ErrorContains(t, err, "couldn't parse")
require.ErrorContains(t, err, aksCAFile)
})
t.Run("empty", func(t *testing.T) {
require.NoError(t, os.Truncate(caFile, 0))
_, err := NewWorkloadIdentityCredential(nil)
require.ErrorContains(t, err, "empty file")
})
t.Run("not found", func(t *testing.T) {
require.NoError(t, os.Remove(caFile))
_, err := NewWorkloadIdentityCredential(nil)
require.ErrorContains(t, err, aksCAFile)
require.ErrorContains(t, err, caFile)
})
})
}

func TestWorkloadIdentityCredential_NoFile(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
Expand Down