Skip to content

[azopenai] Errors weren't propagating properly in image generation for OpenAI #21125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_25f5951837"
"Tag": "go/cognitiveservices/azopenai_2b6f93a94d"
}
25 changes: 24 additions & 1 deletion sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var chatCompletionsRequest = azopenai.ChatCompletionsOptions{
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAIChatCompletionsModelDeployment,
Model: &openAIChatCompletionsModel,
}

var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
Expand Down Expand Up @@ -192,3 +192,26 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
require.ErrorAs(t, err, &respErr)
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
}

func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
t.Helper()
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
})
}
52 changes: 44 additions & 8 deletions sdk/cognitiveservices/azopenai/client_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
Expand All @@ -25,10 +26,10 @@ var (
completionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT
chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT

openAIKey string // env: OPENAI_API_KEY
openAIEndpoint string // env: OPENAI_ENDPOINT
openAICompletionsModelDeployment string // env: OPENAI_CHAT_COMPLETIONS_MODEL
openAIChatCompletionsModelDeployment string // env: OPENAI_COMPLETIONS_MODEL
openAIKey string // env: OPENAI_API_KEY
openAIEndpoint string // env: OPENAI_ENDPOINT
openAICompletionsModel string // env: OPENAI_CHAT_COMPLETIONS_MODEL
openAIChatCompletionsModel string // env: OPENAI_COMPLETIONS_MODEL
)

const fakeEndpoint = "https://recordedhost/"
Expand All @@ -42,10 +43,10 @@ func init() {
openAIEndpoint = fakeEndpoint

completionsModelDeployment = "text-davinci-003"
openAICompletionsModelDeployment = "text-davinci-003"
openAICompletionsModel = "text-davinci-003"

chatCompletionsModelDeployment = "gpt-4"
openAIChatCompletionsModelDeployment = "gpt-4"
openAIChatCompletionsModel = "gpt-4"
} else {
if err := godotenv.Load(); err != nil {
fmt.Printf("Failed to load .env file: %s\n", err)
Expand All @@ -67,8 +68,8 @@ func init() {

openAIKey = os.Getenv("OPENAI_API_KEY")
openAIEndpoint = os.Getenv("OPENAI_ENDPOINT")
openAICompletionsModelDeployment = os.Getenv("OPENAI_COMPLETIONS_MODEL")
openAIChatCompletionsModelDeployment = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")
openAICompletionsModel = os.Getenv("OPENAI_COMPLETIONS_MODEL")
openAIChatCompletionsModel = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")

if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") {
// (this just makes recording replacement easier)
Expand All @@ -88,6 +89,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddHeaderRegexSanitizer("Api-Key", fakeAPIKey, "", nil)
require.NoError(t, err)

err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", ".*", nil)
require.NoError(t, err)

// "RequestUri": "https://openai-shared.openai.azure.com/openai/deployments/text-davinci-003/completions?api-version=2023-03-15-preview",
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil)
require.NoError(t, err)
Expand Down Expand Up @@ -138,3 +142,35 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {

return co
}

// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
// a failure.
func newBogusAzureOpenAIClient(t *testing.T, modelDeploymentID string) *azopenai.Client {
cred, err := azopenai.NewKeyCredential("bogus-api-key")
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, modelDeploymentID, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}

// newBogusOpenAIClient creates a client that uses an invalid key, which will cause OpenAI to return
// a failure.
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
cred, err := azopenai.NewKeyCredential("bogus-api-key")
require.NoError(t, err)

client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}

func assertResponseIsError(t *testing.T, err error) {
t.Helper()

var respErr *azcore.ResponseError
require.ErrorAs(t, err, &respErr)

// we sometimes get rate limited but (for this kind of test) it's actually okay
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
}
4 changes: 4 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ func generateImageWithOpenAI(ctx context.Context, client *Client, body ImageGene
return CreateImageResponse{}, err
}

if !runtime.HasStatusCode(resp, http.StatusOK) {
return CreateImageResponse{}, runtime.NewResponseError(resp)
}

var gens *ImageGenerations

if err := runtime.UnmarshalAsJSON(resp, &gens); err != nil {
Expand Down
32 changes: 32 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ func TestImageGeneration_OpenAI(t *testing.T) {
testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatURL)
}

func TestImageGeneration_AzureOpenAI_WithError(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

client := newBogusAzureOpenAIClient(t, "")
testImageGenerationFailure(t, client)
}

func TestImageGeneration_OpenAI_WithError(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

client := newBogusOpenAIClient(t)
testImageGenerationFailure(t, client)
}

func TestImageGeneration_OpenAI_Base64(t *testing.T) {
client := newOpenAIClientForTest(t)
testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatB64JSON)
Expand Down Expand Up @@ -67,3 +85,17 @@ func testImageGeneration(t *testing.T, client *azopenai.Client, responseFormat a
}
}
}

func testImageGenerationFailure(t *testing.T, bogusClient *azopenai.Client) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

resp, err := bogusClient.CreateImage(ctx, azopenai.ImageGenerationOptions{
Prompt: to.Ptr("a cat"),
Size: to.Ptr(azopenai.ImageSize256x256),
ResponseFormat: to.Ptr(azopenai.ImageGenerationResponseFormatURL),
}, nil)
require.Empty(t, resp)

assertResponseIsError(t, err)
}
37 changes: 30 additions & 7 deletions sdk/cognitiveservices/azopenai/custom_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -88,12 +89,7 @@ func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
}

func TestGetCompletionsStream_OpenAI(t *testing.T) {
cred, err := azopenai.NewKeyCredential(openAIKey)
require.NoError(t, err)

client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

client := newOpenAIClientForTest(t)
testGetCompletionsStream(t, client, false)
}

Expand All @@ -102,7 +98,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
Model: to.Ptr(openAICompletionsModelDeployment),
Model: to.Ptr(openAICompletionsModel),
}

response, err := client.GetCompletionsStream(context.TODO(), body, nil)
Expand Down Expand Up @@ -142,3 +138,30 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
require.Equal(t, want, got)
require.Equal(t, 86, eventCount)
}

func TestClient_GetCompletions_Error(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAICompletionsModel,
}, nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, completionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
})
}