diff --git a/sdk/cognitiveservices/azopenai/assets.json b/sdk/cognitiveservices/azopenai/assets.json index 08529e3efde5..d0419e0d7f21 100644 --- a/sdk/cognitiveservices/azopenai/assets.json +++ b/sdk/cognitiveservices/azopenai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/cognitiveservices/azopenai", - "Tag": "go/cognitiveservices/azopenai_25f5951837" + "Tag": "go/cognitiveservices/azopenai_2b6f93a94d" } diff --git a/sdk/cognitiveservices/azopenai/client_chat_completions_test.go b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go index 1dace8679f04..a0d60c6fa557 100644 --- a/sdk/cognitiveservices/azopenai/client_chat_completions_test.go +++ b/sdk/cognitiveservices/azopenai/client_chat_completions_test.go @@ -31,7 +31,7 @@ var chatCompletionsRequest = azopenai.ChatCompletionsOptions{ }, MaxTokens: to.Ptr(int32(1024)), Temperature: to.Ptr(float32(0.0)), - Model: &openAIChatCompletionsModelDeployment, + Model: &openAIChatCompletionsModel, } var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10." 
@@ -192,3 +192,26 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) { require.ErrorAs(t, err, &respErr) require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) } + +func TestClient_GetChatCompletionsStream_Error(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skip() + } + + doTest := func(t *testing.T, client *azopenai.Client) { + t.Helper() + streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil) + require.Empty(t, streamResp) + assertResponseIsError(t, err) + } + + t.Run("AzureOpenAI", func(t *testing.T) { + client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment) + doTest(t, client) + }) + + t.Run("OpenAI", func(t *testing.T) { + client := newBogusOpenAIClient(t) + doTest(t, client) + }) +} diff --git a/sdk/cognitiveservices/azopenai/client_shared_test.go b/sdk/cognitiveservices/azopenai/client_shared_test.go index 79ce02791afc..543175a171a7 100644 --- a/sdk/cognitiveservices/azopenai/client_shared_test.go +++ b/sdk/cognitiveservices/azopenai/client_shared_test.go @@ -12,6 +12,7 @@ import ( "strings" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" @@ -25,10 +26,10 @@ var ( completionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT - openAIKey string // env: OPENAI_API_KEY - openAIEndpoint string // env: OPENAI_ENDPOINT - openAICompletionsModelDeployment string // env: OPENAI_CHAT_COMPLETIONS_MODEL - openAIChatCompletionsModelDeployment string // env: OPENAI_COMPLETIONS_MODEL + openAIKey string // env: OPENAI_API_KEY + openAIEndpoint string // env: OPENAI_ENDPOINT + openAICompletionsModel string // env: OPENAI_COMPLETIONS_MODEL + 
openAIChatCompletionsModel string // env: OPENAI_CHAT_COMPLETIONS_MODEL ) const fakeEndpoint = "https://recordedhost/" @@ -42,10 +43,10 @@ func init() { openAIEndpoint = fakeEndpoint completionsModelDeployment = "text-davinci-003" - openAICompletionsModelDeployment = "text-davinci-003" + openAICompletionsModel = "text-davinci-003" chatCompletionsModelDeployment = "gpt-4" - openAIChatCompletionsModelDeployment = "gpt-4" + openAIChatCompletionsModel = "gpt-4" } else { if err := godotenv.Load(); err != nil { fmt.Printf("Failed to load .env file: %s\n", err) @@ -67,8 +68,8 @@ func init() { openAIKey = os.Getenv("OPENAI_API_KEY") openAIEndpoint = os.Getenv("OPENAI_ENDPOINT") - openAICompletionsModelDeployment = os.Getenv("OPENAI_COMPLETIONS_MODEL") - openAIChatCompletionsModelDeployment = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL") + openAICompletionsModel = os.Getenv("OPENAI_COMPLETIONS_MODEL") + openAIChatCompletionsModel = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL") if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") { // (this just makes recording replacement easier) @@ -88,6 +89,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter { err = recording.AddHeaderRegexSanitizer("Api-Key", fakeAPIKey, "", nil) require.NoError(t, err) + err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", ".*", nil) + require.NoError(t, err) + // "RequestUri": "https://openai-shared.openai.azure.com/openai/deployments/text-davinci-003/completions?api-version=2023-03-15-preview", err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil) require.NoError(t, err) @@ -138,3 +142,35 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions { return co } + +// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return +// a failure.
+func newBogusAzureOpenAIClient(t *testing.T, modelDeploymentID string) *azopenai.Client { + cred, err := azopenai.NewKeyCredential("bogus-api-key") + require.NoError(t, err) + + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, modelDeploymentID, newClientOptionsForTest(t)) + require.NoError(t, err) + return client +} + +// newBogusOpenAIClient creates a client that uses an invalid key, which will cause OpenAI to return +// a failure. +func newBogusOpenAIClient(t *testing.T) *azopenai.Client { + cred, err := azopenai.NewKeyCredential("bogus-api-key") + require.NoError(t, err) + + client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) + require.NoError(t, err) + return client +} + +func assertResponseIsError(t *testing.T, err error) { + t.Helper() + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + + // we sometimes get rate limited but (for this kind of test) it's actually okay + require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode) +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_image.go b/sdk/cognitiveservices/azopenai/custom_client_image.go index 0c7127618ea8..e244eb38ea0c 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_image.go +++ b/sdk/cognitiveservices/azopenai/custom_client_image.go @@ -72,6 +72,10 @@ func generateImageWithOpenAI(ctx context.Context, client *Client, body ImageGene return CreateImageResponse{}, err } + if !runtime.HasStatusCode(resp, http.StatusOK) { + return CreateImageResponse{}, runtime.NewResponseError(resp) + } + var gens *ImageGenerations if err := runtime.UnmarshalAsJSON(resp, &gens); err != nil { diff --git a/sdk/cognitiveservices/azopenai/custom_client_image_test.go b/sdk/cognitiveservices/azopenai/custom_client_image_test.go index 4f100ce15815..cb30377a4c8b 100644 --- 
a/sdk/cognitiveservices/azopenai/custom_client_image_test.go +++ b/sdk/cognitiveservices/azopenai/custom_client_image_test.go @@ -38,6 +38,24 @@ func TestImageGeneration_OpenAI(t *testing.T) { testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatURL) } +func TestImageGeneration_AzureOpenAI_WithError(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skip() + } + + client := newBogusAzureOpenAIClient(t, "") + testImageGenerationFailure(t, client) +} + +func TestImageGeneration_OpenAI_WithError(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skip() + } + + client := newBogusOpenAIClient(t) + testImageGenerationFailure(t, client) +} + func TestImageGeneration_OpenAI_Base64(t *testing.T) { client := newOpenAIClientForTest(t) testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatB64JSON) @@ -67,3 +85,17 @@ func testImageGeneration(t *testing.T, client *azopenai.Client, responseFormat a } } } + +func testImageGenerationFailure(t *testing.T, bogusClient *azopenai.Client) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + resp, err := bogusClient.CreateImage(ctx, azopenai.ImageGenerationOptions{ + Prompt: to.Ptr("a cat"), + Size: to.Ptr(azopenai.ImageSize256x256), + ResponseFormat: to.Ptr(azopenai.ImageGenerationResponseFormatURL), + }, nil) + require.Empty(t, resp) + + assertResponseIsError(t, err) +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_test.go b/sdk/cognitiveservices/azopenai/custom_client_test.go index 3ae4941316ea..2319f93378a1 100644 --- a/sdk/cognitiveservices/azopenai/custom_client_test.go +++ b/sdk/cognitiveservices/azopenai/custom_client_test.go @@ -16,6 +16,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" 
"github.com/stretchr/testify/require" ) @@ -88,12 +89,7 @@ func TestGetCompletionsStream_AzureOpenAI(t *testing.T) { } func TestGetCompletionsStream_OpenAI(t *testing.T) { - cred, err := azopenai.NewKeyCredential(openAIKey) - require.NoError(t, err) - - client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t)) - require.NoError(t, err) - + client := newOpenAIClientForTest(t) testGetCompletionsStream(t, client, false) } @@ -102,7 +98,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo Prompt: []string{"What is Azure OpenAI?"}, MaxTokens: to.Ptr(int32(2048)), Temperature: to.Ptr(float32(0.0)), - Model: to.Ptr(openAICompletionsModelDeployment), + Model: to.Ptr(openAICompletionsModel), } response, err := client.GetCompletionsStream(context.TODO(), body, nil) @@ -142,3 +138,30 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo require.Equal(t, want, got) require.Equal(t, 86, eventCount) } + +func TestClient_GetCompletions_Error(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skip() + } + + doTest := func(t *testing.T, client *azopenai.Client) { + streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{ + Prompt: []string{"What is Azure OpenAI?"}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + Model: &openAICompletionsModel, + }, nil) + require.Empty(t, streamResp) + assertResponseIsError(t, err) + } + + t.Run("AzureOpenAI", func(t *testing.T) { + client := newBogusAzureOpenAIClient(t, completionsModelDeployment) + doTest(t, client) + }) + + t.Run("OpenAI", func(t *testing.T) { + client := newBogusOpenAIClient(t) + doTest(t, client) + }) +}