Skip to content

IMDS retry exponential backoff #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,17 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
if err == nil && !contains(retrylist, resp.StatusCode) {
return resp, nil
}
// For IMDS, use exponential backoff based on attempt number
var waitTime time.Duration
if c.source == DefaultToIMDS {
// Exponential backoff with base of 1 second: 1s, 2s, 4s, 8s, etc.
waitTime = time.Second * time.Duration(1<<uint(attempt))
} else {
// For non-IMDS sources, use the fixed 1 second delay
waitTime = time.Second
}
select {
case <-time.After(time.Second):
case <-time.After(waitTime):
case <-req.Context().Done():
err = req.Context().Err()
return resp, err
Expand Down
73 changes: 68 additions & 5 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func TestRetryFunction(t *testing.T) {
expectedBody string
maxRetries int
source Source
expectedDelays []time.Duration // Expected delays between consecutive requests (IMDS backoff or fixed non-IMDS delay); leave empty to skip the timing check
}{
{
name: "Successful Request",
Expand Down Expand Up @@ -228,27 +229,78 @@ func TestRetryFunction(t *testing.T) {
maxRetries: 2,
source: DefaultToIMDS,
},
{
name: "Successful Request IMDS with Exponential Backoff",
mockResponses: []struct {
body string
statusCode int
}{
{"Failed", http.StatusInternalServerError},
{"Failed", http.StatusInternalServerError},
{"Failed", http.StatusInternalServerError},
{"Success", http.StatusOK},
},
expectedStatus: http.StatusOK,
expectedBody: "Success",
maxRetries: 4,
source: DefaultToIMDS,
expectedDelays: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second},
},
{
name: "Successful Request Non-IMDS with Fixed Delay",
mockResponses: []struct {
body string
statusCode int
}{
{"Failed", http.StatusInternalServerError},
{"Success", http.StatusOK},
},
expectedStatus: http.StatusOK,
expectedBody: "Success",
maxRetries: 3,
source: AzureArc, // Non-IMDS source
expectedDelays: []time.Duration{1 * time.Second},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mock.NewClient()
for _, resp := range tt.mockResponses {
var actualDelays []time.Duration
var lastRequestTime time.Time

for i, resp := range tt.mockResponses {
body := bytes.NewBufferString(resp.body)
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode))
callback := func(r *http.Request) {
if !lastRequestTime.IsZero() {
actualDelays = append(actualDelays, time.Since(lastRequestTime))
}
lastRequestTime = time.Now()
}
// NOTE(review): both branches below are identical and attach the callback, so it fires for every response, including the final one
if i < len(tt.mockResponses)-1 {
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode), mock.WithCallback(callback))
} else {
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode), mock.WithCallback(callback))
}
}
client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled())
client, err := New(SystemAssigned(), WithHTTPClient(mockClient))
if err != nil {
t.Fatal(err)
}
// Manually set the source for testing purposes
client.source = tt.source

reqBody := bytes.NewBufferString("Test Body")
req, err := http.NewRequest("POST", "https://example.com", reqBody)
if err != nil {
t.Fatal(err)
}
finalResp, err := client.retry(tt.maxRetries, req)
if err != nil {
t.Fatal(err)
if tt.expectedStatus != finalResp.StatusCode {
t.Fatal(err)
}
}
if finalResp.StatusCode != tt.expectedStatus {
t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode)
Expand All @@ -261,6 +313,17 @@ func TestRetryFunction(t *testing.T) {
if string(bodyBytes) != tt.expectedBody {
t.Fatalf("Expected body %q, got %q", tt.expectedBody, bodyBytes)
}

if len(tt.expectedDelays) > 0 {
if len(actualDelays) != len(tt.expectedDelays) {
t.Fatalf("Expected %d delays, got %d. Actual delays: %v", len(tt.expectedDelays), len(actualDelays), actualDelays)
}
for i, expectedDelay := range tt.expectedDelays {
if actualDelays[i] < expectedDelay-500*time.Millisecond || actualDelays[i] > expectedDelay+500*time.Millisecond {
t.Fatalf("Expected delay %v at attempt %d, got %v", expectedDelay, i, actualDelays[i])
}
}
}
})
}
}
Expand Down Expand Up @@ -964,7 +1027,7 @@ func TestAzureArcErrors(t *testing.T) {
},
{
name: "Invalid file path",
headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey),
headerValue: basicRealm + filepath.Join("path", "to", secretKey),
expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"),
},
{
Expand Down