Skip to content

Commit adb96df

Browse files
authored
Allow a CachingChatClient to control per-request caching (#6524)
1 parent 694b95e commit adb96df

File tree

3 files changed

+55
-14
lines changed

3 files changed

+55
-14
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public override Task<ChatResponse> GetResponseAsync(
5151
{
5252
_ = Throw.IfNull(messages);
5353

54-
return UseCaching(options) ?
54+
return EnableCaching(messages, options) ?
5555
GetCachedResponseAsync(messages, options, cancellationToken) :
5656
base.GetResponseAsync(messages, options, cancellationToken);
5757
}
@@ -79,7 +79,7 @@ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
7979
{
8080
_ = Throw.IfNull(messages);
8181

82-
return UseCaching(options) ?
82+
return EnableCaching(messages, options) ?
8383
GetCachedStreamingResponseAsync(messages, options, cancellationToken) :
8484
base.GetStreamingResponseAsync(messages, options, cancellationToken);
8585
}
@@ -196,12 +196,25 @@ private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsy
196196
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
197197
protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);
198198

199-
/// <summary>Determine whether to use caching with the request.</summary>
200-
private static bool UseCaching(ChatOptions? options)
199+
/// <summary>Determines whether caching should be used with the specified request.</summary>
200+
/// <param name="messages">The sequence of chat messages included in the request. The default implementation does not inspect them.</param>
201+
/// <param name="options">The chat options included in the request. May be <see langword="null"/>.</param>
202+
/// <returns>
203+
/// <see langword="true"/> if caching should be used for the request, such that the <see cref="CachingChatClient"/>
204+
/// will try to satisfy the request from the cache and, if it can't, will try to cache the fetched response.
205+
/// <see langword="false"/> if caching should not be used for the request, such that the request will
206+
/// be passed through to the inner <see cref="IChatClient"/> without attempting to read from or write to the cache.
207+
/// </returns>
208+
/// <remarks>
209+
/// The default implementation returns <see langword="true"/> when <paramref name="options"/> is
210+
/// <see langword="null"/> or does not have a <see cref="ChatOptions.ConversationId"/> set.
211+
/// </remarks>
212+
protected virtual bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options)
201213
{
202214
// We want to skip caching if options.ConversationId is set. If it's set, that implies there's
203215
// some state that will impact the response and that's not represented in the messages. Since
204-
// that state could change even with the same ID, we have to assume caching isn't valid.
216+
// that state could change even with the same ID (e.g. if it's a thread ID representing the
217+
// mutable state of a conversation), we have to assume caching isn't valid.
205218
return options?.ConversationId is null;
206219
}
207220
}

src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
"Member": "abstract string Microsoft.Extensions.AI.CachingChatClient.GetCacheKey(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options, params System.ReadOnlySpan<object?> additionalValues);",
1717
"Stage": "Stable"
1818
},
19+
{
20+
"Member": "virtual bool Microsoft.Extensions.AI.CachingChatClient.EnableCaching(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options);",
21+
"Stage": "Stable"
22+
},
1923
{
2024
"Member": "override System.Threading.Tasks.Task<Microsoft.Extensions.AI.ChatResponse> Microsoft.Extensions.AI.CachingChatClient.GetResponseAsync(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));",
2125
"Stage": "Stable"

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ public void Ctor_ExpectedDefaults()
3333
}
3434

3535
[Theory]
36-
[InlineData(false)]
37-
[InlineData(true)]
38-
public async Task CachesSuccessResultsAsync(bool conversationIdSet)
36+
[InlineData(false, false)]
37+
[InlineData(false, true)]
38+
[InlineData(true, false)]
39+
[InlineData(true, true)]
40+
public async Task CachesSuccessResultsAsync(bool conversationIdSet, bool customCaching)
3941
{
4042
// Arrange
4143
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
@@ -79,10 +81,16 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
7981
return Task.FromResult(expectedResponse);
8082
}
8183
};
82-
using var outer = new DistributedCachingChatClient(testClient, _storage)
83-
{
84-
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
85-
};
84+
85+
int enableCachingInvocations = 0;
86+
using var outer = customCaching ?
87+
new CustomCachingChatClient(testClient, _storage, (m, o) =>
88+
{
89+
return ++enableCachingInvocations % 2 == 0;
90+
}) :
91+
new DistributedCachingChatClient(testClient, _storage);
92+
93+
outer.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
8694

8795
// Make the initial request and do a quick sanity check
8896
var result1 = await outer.GetResponseAsync("some input", options);
@@ -93,12 +101,28 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
93101
var result2 = await outer.GetResponseAsync("some input", options);
94102

95103
// Assert
96-
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
104+
if (customCaching)
105+
{
106+
Assert.Equal(enableCachingInvocations % 2 == 0 ? 2 : 1, innerCallCount);
107+
}
108+
else
109+
{
110+
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
111+
}
112+
97113
AssertResponsesEqual(expectedResponse, result2);
98114

99115
// Act/Assert 2: Cache misses do not return cached results
100116
await outer.GetResponseAsync("some modified input", options);
101-
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
117+
Assert.Equal(conversationIdSet || customCaching ? 3 : 2, innerCallCount);
118+
119+
Assert.Equal(customCaching ? 3 : 0, enableCachingInvocations);
120+
}
121+
122+
// Test-only subclass of DistributedCachingChatClient that forwards the EnableCaching decision to the delegate supplied by the test.
private sealed class CustomCachingChatClient(IChatClient innerClient, IDistributedCache storage, Func<IEnumerable<ChatMessage>, ChatOptions?, bool> enableCaching) :
123+
DistributedCachingChatClient(innerClient, storage)
124+
{
125+
protected override bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options) => enableCaching(messages, options);
102126
}
103127

104128
[Fact]

0 commit comments

Comments
 (0)