Skip to content

Commit 9344962

Browse files
authored
Add FunctionInvokingChatClient.FunctionInvoker delegate (#6564)
We've had a bunch of requests to be able to customize how function invocation is handled, and while it's already possible today by deriving from FunctionInvokingChatClient and overriding its InvokeFunctionAsync, there's a lot of ceremony involved in that. By having a property on the client instance, that behavior can instead be configured as part of a UseFunctionInvocation call.
1 parent c49b57b commit 9344962

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,15 @@ public int MaximumConsecutiveErrorsPerRequest
205205
set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
206206
}
207207

208+
/// <summary>Gets or sets a delegate used to invoke <see cref="AIFunction"/> instances.</summary>
209+
/// <remarks>
210+
/// By default, the protected <see cref="InvokeFunctionAsync"/> method is called for each <see cref="AIFunction"/> to be invoked,
211+
/// invoking the instance and returning its result. If this delegate is set to a non-<see langword="null"/> value,
212+
/// <see cref="InvokeFunctionAsync"/> will replace its normal invocation with a call to this delegate, enabling
213+
/// this delegate to assume all invocation handling of the function.
214+
/// </remarks>
215+
public Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>>? FunctionInvoker { get; set; }
216+
208217
/// <inheritdoc/>
209218
public override async Task<ChatResponse> GetResponseAsync(
210219
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
@@ -872,7 +881,9 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
872881
{
873882
_ = Throw.IfNull(context);
874883

875-
return context.Function.InvokeAsync(context.Arguments, cancellationToken);
884+
return FunctionInvoker is { } invoker ?
885+
invoker(context, cancellationToken) :
886+
context.Function.InvokeAsync(context.Arguments, cancellationToken);
876887
}
877888

878889
private static TimeSpan GetElapsedTime(long startingTimestamp) =>

src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,10 @@
527527
"Member": "System.IServiceProvider? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvocationServices { get; }",
528528
"Stage": "Stable"
529529
},
530+
{
531+
"Member": "System.Func<Microsoft.Extensions.AI.FunctionInvocationContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<object?>>? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvoker { get; set; }",
532+
"Stage": "Stable"
533+
},
530534
{
531535
"Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.IncludeDetailedErrors { get; set; }",
532536
"Stage": "Stable"

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.Diagnostics;
88
using System.Linq;
9+
using System.Text.Json;
910
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Microsoft.Extensions.DependencyInjection;
@@ -37,6 +38,35 @@ public void Ctor_HasExpectedDefaults()
3738
Assert.False(client.IncludeDetailedErrors);
3839
Assert.Equal(10, client.MaximumIterationsPerRequest);
3940
Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
41+
Assert.Null(client.FunctionInvoker);
42+
}
43+
44+
[Fact]
45+
public void Properties_Roundtrip()
46+
{
47+
using TestChatClient innerClient = new();
48+
using FunctionInvokingChatClient client = new(innerClient);
49+
50+
Assert.False(client.AllowConcurrentInvocation);
51+
client.AllowConcurrentInvocation = true;
52+
Assert.True(client.AllowConcurrentInvocation);
53+
54+
Assert.False(client.IncludeDetailedErrors);
55+
client.IncludeDetailedErrors = true;
56+
Assert.True(client.IncludeDetailedErrors);
57+
58+
Assert.Equal(10, client.MaximumIterationsPerRequest);
59+
client.MaximumIterationsPerRequest = 5;
60+
Assert.Equal(5, client.MaximumIterationsPerRequest);
61+
62+
Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
63+
client.MaximumConsecutiveErrorsPerRequest = 1;
64+
Assert.Equal(1, client.MaximumConsecutiveErrorsPerRequest);
65+
66+
Assert.Null(client.FunctionInvoker);
67+
Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> invoker = (ctx, ct) => new ValueTask<object?>("test");
68+
client.FunctionInvoker = invoker;
69+
Assert.Same(invoker, client.FunctionInvoker);
4070
}
4171

4272
[Fact]
@@ -208,6 +238,49 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync()
208238
await InvokeAndAssertStreamingAsync(options, plan);
209239
}
210240

241+
[Fact]
242+
public async Task FunctionInvokerDelegateOverridesHandlingAsync()
243+
{
244+
var options = new ChatOptions
245+
{
246+
Tools =
247+
[
248+
AIFunctionFactory.Create(() => "Result 1", "Func1"),
249+
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
250+
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
251+
]
252+
};
253+
254+
List<ChatMessage> plan =
255+
[
256+
new ChatMessage(ChatRole.User, "hello"),
257+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
258+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1 from delegate")]),
259+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
260+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42 from delegate")]),
261+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
262+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
263+
new ChatMessage(ChatRole.Assistant, "world"),
264+
];
265+
266+
Func<ChatClientBuilder, ChatClientBuilder> configure = b => b.Use(
267+
s => new FunctionInvokingChatClient(s)
268+
{
269+
FunctionInvoker = async (ctx, cancellationToken) =>
270+
{
271+
Assert.NotNull(ctx);
272+
var result = await ctx.Function.InvokeAsync(ctx.Arguments, cancellationToken);
273+
return result is JsonElement e ?
274+
JsonSerializer.SerializeToElement($"{e.GetString()} from delegate", AIJsonUtilities.DefaultOptions) :
275+
result;
276+
}
277+
});
278+
279+
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
280+
281+
await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
282+
}
283+
211284
[Fact]
212285
public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations()
213286
{

0 commit comments

Comments
 (0)