Skip to content

Commit 34cdd3a

Browse files
authored
Bring back AsIChatClient for OpenAI AssistantClient (#6501)
* Bring back AsIChatClient for OpenAI Assistantclient * Address PR feedback
1 parent 3f61ada commit 34cdd3a

File tree

9 files changed

+648
-26
lines changed

9 files changed

+648
-26
lines changed

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantChatClient.cs

Lines changed: 446 additions & 0 deletions
Large diffs are not rendered by default.

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ internal sealed partial class OpenAIChatClient : IChatClient
3434
MoveDefaultKeywordToDescription = true,
3535
});
3636

37-
/// <summary>Gets the default OpenAI endpoint.</summary>
38-
private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
39-
4037
/// <summary>Metadata about the client.</summary>
4138
private readonly ChatClientMetadata _metadata;
4239

@@ -57,7 +54,7 @@ public OpenAIChatClient(ChatClient chatClient)
5754
// implement the abstractions directly rather than providing adapters on top of the public APIs,
5855
// the package can provide such implementations separate from what's exposed in the public API.
5956
Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
60-
?.GetValue(chatClient) as Uri ?? DefaultOpenAIEndpoint;
57+
?.GetValue(chatClient) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint;
6158
string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
6259
?.GetValue(chatClient) as string;
6360

@@ -113,8 +110,6 @@ void IDisposable.Dispose()
113110
// Nothing to dispose. Implementation required for the IChatClient interface.
114111
}
115112

116-
private static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");
117-
118113
/// <summary>Converts an Extensions chat message enumerable to an OpenAI chat message enumerable.</summary>
119114
private static IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs, JsonSerializerOptions options)
120115
{
@@ -125,12 +120,12 @@ void IDisposable.Dispose()
125120
{
126121
if (input.Role == ChatRole.System ||
127122
input.Role == ChatRole.User ||
128-
input.Role == ChatRoleDeveloper)
123+
input.Role == OpenAIResponseChatClient.ChatRoleDeveloper)
129124
{
130125
var parts = ToOpenAIChatContent(input.Contents);
131126
yield return
132127
input.Role == ChatRole.System ? new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
133-
input.Role == ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
128+
input.Role == OpenAIResponseChatClient.ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
134129
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
135130
}
136131
else if (input.Role == ChatRole.Tool)
@@ -622,7 +617,7 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) =>
622617
ChatMessageRole.User => ChatRole.User,
623618
ChatMessageRole.Assistant => ChatRole.Assistant,
624619
ChatMessageRole.Tool => ChatRole.Tool,
625-
ChatMessageRole.Developer => ChatRoleDeveloper,
620+
ChatMessageRole.Developer => OpenAIResponseChatClient.ChatRoleDeveloper,
626621
_ => new ChatRole(role.ToString()),
627622
};
628623

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Diagnostics.CodeAnalysis;
55
using OpenAI;
6+
using OpenAI.Assistants;
67
using OpenAI.Audio;
78
using OpenAI.Chat;
89
using OpenAI.Embeddings;
@@ -25,6 +26,19 @@ public static IChatClient AsIChatClient(this ChatClient chatClient) =>
2526
public static IChatClient AsIChatClient(this OpenAIResponseClient responseClient) =>
2627
new OpenAIResponseChatClient(responseClient);
2728

29+
/// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="AssistantClient"/>.</summary>
30+
/// <param name="assistantClient">The <see cref="AssistantClient"/> instance to be accessed as an <see cref="IChatClient"/>.</param>
31+
/// <param name="assistantId">The unique identifier of the assistant with which to interact.</param>
32+
/// <param name="threadId">
33+
/// An optional existing thread identifier for the chat session. This serves as a default, and may be overridden per call to
34+
/// <see cref="IChatClient.GetResponseAsync"/> or <see cref="IChatClient.GetStreamingResponseAsync"/> via the <see cref="ChatOptions.ConversationId"/>
35+
/// property. If no thread ID is provided via either mechanism, a new thread will be created for the request.
36+
/// </param>
37+
/// <returns>An <see cref="IChatClient"/> instance configured to interact with the specified agent and thread.</returns>
38+
[Experimental("OPENAI001")]
39+
public static IChatClient AsIChatClient(this AssistantClient assistantClient, string assistantId, string? threadId = null) =>
40+
new OpenAIAssistantChatClient(assistantClient, assistantId, threadId);
41+
2842
/// <summary>Gets an <see cref="ISpeechToTextClient"/> for use with this <see cref="AudioClient"/>.</summary>
2943
/// <param name="audioClient">The client.</param>
3044
/// <returns>An <see cref="ISpeechToTextClient"/> that can be used to transcribe audio via the <see cref="AudioClient"/>.</returns>

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ namespace Microsoft.Extensions.AI;
2727
internal sealed partial class OpenAIResponseChatClient : IChatClient
2828
{
2929
/// <summary>Gets the default OpenAI endpoint.</summary>
30-
private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
30+
internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
3131

32-
/// <summary>A <see cref="ChatRole"/> for "developer".</summary>
33-
private static readonly ChatRole _chatRoleDeveloper = new("developer");
32+
/// <summary>Gets a <see cref="ChatRole"/> for "developer".</summary>
33+
internal static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");
3434

3535
/// <summary>Metadata about the client.</summary>
3636
private readonly ChatClientMetadata _metadata;
@@ -88,7 +88,7 @@ public async Task<ChatResponse> GetResponseAsync(
8888
// Convert and return the results.
8989
ChatResponse response = new()
9090
{
91-
ConversationId = openAIResponse.Id,
91+
ConversationId = openAIOptions.StoredOutputEnabled is false ? null : openAIResponse.Id,
9292
CreatedAt = openAIResponse.CreatedAt,
9393
FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason),
9494
Messages = [new(ChatRole.Assistant, [])],
@@ -167,6 +167,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
167167
// Make the call to the OpenAIResponseClient and process the streaming results.
168168
DateTimeOffset? createdAt = null;
169169
string? responseId = null;
170+
string? conversationId = null;
170171
string? modelId = null;
171172
string? lastMessageId = null;
172173
ChatRole? lastRole = null;
@@ -179,18 +180,19 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
179180
case StreamingResponseCreatedUpdate createdUpdate:
180181
createdAt = createdUpdate.Response.CreatedAt;
181182
responseId = createdUpdate.Response.Id;
183+
conversationId = openAIOptions.StoredOutputEnabled is false ? null : responseId;
182184
modelId = createdUpdate.Response.Model;
183185
goto default;
184186

185187
case StreamingResponseCompletedUpdate completedUpdate:
186188
yield return new()
187189
{
190+
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
191+
ConversationId = conversationId,
192+
CreatedAt = createdAt,
188193
FinishReason =
189194
ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ??
190195
(functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop),
191-
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
192-
ConversationId = responseId,
193-
CreatedAt = createdAt,
194196
MessageId = lastMessageId,
195197
ModelId = modelId,
196198
RawRepresentation = streamingUpdate,
@@ -223,7 +225,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
223225
lastRole = ToChatRole(messageItem?.Role);
224226
yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta)
225227
{
226-
ConversationId = responseId,
228+
ConversationId = conversationId,
227229
CreatedAt = createdAt,
228230
MessageId = lastMessageId,
229231
ModelId = modelId,
@@ -258,7 +260,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
258260
lastRole = ChatRole.Assistant;
259261
yield return new ChatResponseUpdate(lastRole, [fci])
260262
{
261-
ConversationId = responseId,
263+
ConversationId = conversationId,
262264
CreatedAt = createdAt,
263265
MessageId = lastMessageId,
264266
ModelId = modelId,
@@ -275,7 +277,6 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
275277
case StreamingResponseErrorUpdate errorUpdate:
276278
yield return new ChatResponseUpdate
277279
{
278-
ConversationId = responseId,
279280
Contents =
280281
[
281282
new ErrorContent(errorUpdate.Message)
@@ -284,6 +285,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
284285
Details = errorUpdate.Param,
285286
}
286287
],
288+
ConversationId = conversationId,
287289
CreatedAt = createdAt,
288290
MessageId = lastMessageId,
289291
ModelId = modelId,
@@ -296,21 +298,21 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
296298
case StreamingResponseRefusalDoneUpdate refusalDone:
297299
yield return new ChatResponseUpdate
298300
{
301+
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
302+
ConversationId = conversationId,
299303
CreatedAt = createdAt,
300304
MessageId = lastMessageId,
301305
ModelId = modelId,
302306
RawRepresentation = streamingUpdate,
303307
ResponseId = responseId,
304308
Role = lastRole,
305-
ConversationId = responseId,
306-
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
307309
};
308310
break;
309311

310312
default:
311313
yield return new ChatResponseUpdate
312314
{
313-
ConversationId = responseId,
315+
ConversationId = conversationId,
314316
CreatedAt = createdAt,
315317
MessageId = lastMessageId,
316318
ModelId = modelId,
@@ -334,7 +336,7 @@ private static ChatRole ToChatRole(MessageRole? role) =>
334336
role switch
335337
{
336338
MessageRole.System => ChatRole.System,
337-
MessageRole.Developer => _chatRoleDeveloper,
339+
MessageRole.Developer => ChatRoleDeveloper,
338340
MessageRole.User => ChatRole.User,
339341
_ => ChatRole.Assistant,
340342
};
@@ -452,7 +454,7 @@ private static IEnumerable<ResponseItem> ToOpenAIResponseItems(
452454
foreach (ChatMessage input in inputs)
453455
{
454456
if (input.Role == ChatRole.System ||
455-
input.Role == _chatRoleDeveloper)
457+
input.Role == ChatRoleDeveloper)
456458
{
457459
string text = input.Text;
458460
if (!string.IsNullOrWhiteSpace(text))

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,9 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange
618618

619619
// Second time, the calls to the LLM don't happen, but the function is called again
620620
var secondResponse = await chatClient.GetResponseAsync([message]);
621-
Assert.Equal(response.Text, secondResponse.Text);
622621
Assert.Equal(2, functionCallCount);
623622
Assert.Equal(FunctionInvokingChatClientSetsConversationId ? 3 : 2, llmCallCount!.CallCount);
623+
Assert.Equal(response.Text, secondResponse.Text);
624624
}
625625

626626
public virtual bool FunctionInvokingChatClientSetsConversationId => false;

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ private static BinaryEmbedding QuantizeToBinary(Embedding<float> embedding)
5252
{
5353
if (vector[i] > 0)
5454
{
55-
result[i / 8] = true;
55+
result[i] = true;
5656
}
5757
}
5858

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
5+
#pragma warning disable CA1822 // Mark members as static
6+
#pragma warning disable CA2000 // Dispose objects before losing scope
7+
#pragma warning disable S1135 // Track uses of "TODO" tags
8+
#pragma warning disable xUnit1013 // Public method should be marked as test
9+
10+
using System;
11+
using System.Linq;
12+
using System.Net;
13+
using System.Net.Http;
14+
using System.Text.RegularExpressions;
15+
using System.Threading.Tasks;
16+
using OpenAI.Assistants;
17+
using Xunit;
18+
19+
namespace Microsoft.Extensions.AI;
20+
21+
public class OpenAIAssistantChatClientIntegrationTests : ChatClientIntegrationTests
22+
{
23+
protected override IChatClient? CreateChatClient()
24+
{
25+
var openAIClient = IntegrationTestHelpers.GetOpenAIClient();
26+
if (openAIClient is null)
27+
{
28+
return null;
29+
}
30+
31+
AssistantClient ac = openAIClient.GetAssistantClient();
32+
var assistant =
33+
ac.GetAssistants().FirstOrDefault() ??
34+
ac.CreateAssistant("gpt-4o-mini");
35+
36+
return ac.AsIChatClient(assistant.Id);
37+
}
38+
39+
public override bool FunctionInvokingChatClientSetsConversationId => true;
40+
41+
// These tests aren't written in a way that works well with threads.
42+
public override Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync() => Task.CompletedTask;
43+
public override Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() => Task.CompletedTask;
44+
45+
// Assistants doesn't support data URIs.
46+
public override Task MultiModal_DescribeImage() => Task.CompletedTask;
47+
public override Task MultiModal_DescribePdf() => Task.CompletedTask;
48+
49+
// [Fact] // uncomment and run to clear out _all_ threads in your OpenAI account
50+
public async Task DeleteAllThreads()
51+
{
52+
using HttpClient client = new(new HttpClientHandler
53+
{
54+
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
55+
});
56+
57+
// These values need to be filled in. The bearer token needs to be sniffed from a browser
58+
// session interacting with the dashboard (e.g. use F12 networking tools to look at request headers
59+
// made to "https://api.openai.com/v1/threads?limit=10" after clicking on Assistants | Threads in the
60+
// OpenAI portal dashboard).
61+
client.DefaultRequestHeaders.Add("authorization", $"Bearer sess-ENTERYOURSESSIONTOKEN");
62+
client.DefaultRequestHeaders.Add("openai-organization", "org-ENTERYOURORGID");
63+
client.DefaultRequestHeaders.Add("openai-project", "proj_ENTERYOURPROJECTID");
64+
65+
AssistantClient ac = new AssistantClient(Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!);
66+
while (true)
67+
{
68+
string listing = await client.GetStringAsync("https://api.openai.com/v1/threads?limit=100");
69+
70+
var matches = Regex.Matches(listing, @"thread_\w+");
71+
if (matches.Count == 0)
72+
{
73+
break;
74+
}
75+
76+
foreach (Match m in matches)
77+
{
78+
var dr = await ac.DeleteThreadAsync(m.Value);
79+
Assert.True(dr.Value.Deleted);
80+
}
81+
}
82+
}
83+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.ClientModel;
6+
using Azure.AI.OpenAI;
7+
using Microsoft.Extensions.Caching.Distributed;
8+
using Microsoft.Extensions.Caching.Memory;
9+
using OpenAI;
10+
using OpenAI.Assistants;
11+
using Xunit;
12+
13+
#pragma warning disable S103 // Lines should not be too long
14+
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
15+
16+
namespace Microsoft.Extensions.AI;
17+
18+
public class OpenAIAssistantChatClientTests
19+
{
20+
[Fact]
21+
public void AsIChatClient_InvalidArgs_Throws()
22+
{
23+
Assert.Throws<ArgumentNullException>("assistantClient", () => ((AssistantClient)null!).AsIChatClient("assistantId"));
24+
Assert.Throws<ArgumentNullException>("assistantId", () => new AssistantClient("ignored").AsIChatClient(null!));
25+
}
26+
27+
[Theory]
28+
[InlineData(false)]
29+
[InlineData(true)]
30+
public void AsIChatClient_OpenAIClient_ProducesExpectedMetadata(bool useAzureOpenAI)
31+
{
32+
Uri endpoint = new("http://localhost/some/endpoint");
33+
34+
var client = useAzureOpenAI ?
35+
new AzureOpenAIClient(endpoint, new ApiKeyCredential("key")) :
36+
new OpenAIClient(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint });
37+
38+
IChatClient[] clients =
39+
[
40+
client.GetAssistantClient().AsIChatClient("assistantId"),
41+
client.GetAssistantClient().AsIChatClient("assistantId", "threadId"),
42+
];
43+
44+
foreach (var chatClient in clients)
45+
{
46+
var metadata = chatClient.GetService<ChatClientMetadata>();
47+
Assert.Equal("openai", metadata?.ProviderName);
48+
Assert.Equal(endpoint, metadata?.ProviderUri);
49+
}
50+
}
51+
52+
[Fact]
53+
public void GetService_AssistantClient_SuccessfullyReturnsUnderlyingClient()
54+
{
55+
AssistantClient assistantClient = new OpenAIClient("key").GetAssistantClient();
56+
IChatClient chatClient = assistantClient.AsIChatClient("assistantId");
57+
58+
Assert.Same(assistantClient, chatClient.GetService<AssistantClient>());
59+
60+
Assert.Null(chatClient.GetService<OpenAIClient>());
61+
62+
using IChatClient pipeline = chatClient
63+
.AsBuilder()
64+
.UseFunctionInvocation()
65+
.UseOpenTelemetry()
66+
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
67+
.Build();
68+
69+
Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
70+
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
71+
Assert.NotNull(pipeline.GetService<CachingChatClient>());
72+
Assert.NotNull(pipeline.GetService<OpenTelemetryChatClient>());
73+
74+
Assert.Same(assistantClient, pipeline.GetService<AssistantClient>());
75+
Assert.IsType<FunctionInvokingChatClient>(pipeline.GetService<IChatClient>());
76+
}
77+
}

0 commit comments

Comments
 (0)