Skip to content

Commit 3f61ada

Browse files
authored
Improve handling of RawRepresentation in OpenAI{Response}ChatClient (#6500)
1 parent 233aa3a commit 3f61ada

File tree

3 files changed

+88
-28
lines changed

3 files changed

+88
-28
lines changed

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,9 @@ private static List<ChatMessageContentPart> ToOpenAIChatContent(IList<AIContent>
261261

262262
case DataContent dataContent when dataContent.MediaType.StartsWith("application/pdf", StringComparison.OrdinalIgnoreCase):
263263
return ChatMessageContentPart.CreateFilePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType, $"{Guid.NewGuid():N}.pdf");
264+
265+
case AIContent when content.RawRepresentation is ChatMessageContentPart rawContentPart:
266+
return rawContentPart;
264267
}
265268

266269
return null;

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using OpenAI.Responses;
1616
using static Microsoft.Extensions.AI.OpenAIChatClient;
1717

18+
#pragma warning disable S907 // "goto" statement should not be used
1819
#pragma warning disable S1067 // Expressions should not be too complex
1920
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
2021
#pragma warning disable S3604 // Member initializer values should not be redundant
@@ -87,12 +88,13 @@ public async Task<ChatResponse> GetResponseAsync(
8788
// Convert and return the results.
8889
ChatResponse response = new()
8990
{
90-
ResponseId = openAIResponse.Id,
9191
ConversationId = openAIResponse.Id,
9292
CreatedAt = openAIResponse.CreatedAt,
9393
FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason),
9494
Messages = [new(ChatRole.Assistant, [])],
9595
ModelId = openAIResponse.Model,
96+
RawRepresentation = openAIResponse,
97+
ResponseId = openAIResponse.Id,
9698
Usage = ToUsageDetails(openAIResponse),
9799
};
98100

@@ -125,12 +127,20 @@ public async Task<ChatResponse> GetResponseAsync(
125127

126128
case FunctionCallResponseItem functionCall:
127129
response.FinishReason ??= ChatFinishReason.ToolCalls;
128-
message.Contents.Add(
129-
FunctionCallContent.CreateFromParsedArguments(
130-
functionCall.FunctionArguments.ToMemory(),
131-
functionCall.CallId,
132-
functionCall.FunctionName,
133-
static json => JsonSerializer.Deserialize(json.Span, ResponseClientJsonContext.Default.IDictionaryStringObject)!));
130+
var fcc = FunctionCallContent.CreateFromParsedArguments(
131+
functionCall.FunctionArguments.ToMemory(),
132+
functionCall.CallId,
133+
functionCall.FunctionName,
134+
static json => JsonSerializer.Deserialize(json.Span, ResponseClientJsonContext.Default.IDictionaryStringObject)!);
135+
fcc.RawRepresentation = outputItem;
136+
message.Contents.Add(fcc);
137+
break;
138+
139+
default:
140+
message.Contents.Add(new()
141+
{
142+
RawRepresentation = outputItem,
143+
});
134144
break;
135145
}
136146
}
@@ -170,20 +180,21 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
170180
createdAt = createdUpdate.Response.CreatedAt;
171181
responseId = createdUpdate.Response.Id;
172182
modelId = createdUpdate.Response.Model;
173-
break;
183+
goto default;
174184

175185
case StreamingResponseCompletedUpdate completedUpdate:
176186
yield return new()
177187
{
178-
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
179-
CreatedAt = createdAt,
180-
ResponseId = responseId,
181-
ConversationId = responseId,
182188
FinishReason =
183189
ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ??
184190
(functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop),
191+
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
192+
ConversationId = responseId,
193+
CreatedAt = createdAt,
185194
MessageId = lastMessageId,
186195
ModelId = modelId,
196+
RawRepresentation = streamingUpdate,
197+
ResponseId = responseId,
187198
Role = lastRole,
188199
};
189200
break;
@@ -200,23 +211,24 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
200211
break;
201212
}
202213

203-
break;
214+
goto default;
204215

205216
case StreamingResponseOutputItemDoneUpdate outputItemDoneUpdate:
206217
_ = outputIndexToMessages.Remove(outputItemDoneUpdate.OutputIndex);
207-
break;
218+
goto default;
208219

209220
case StreamingResponseOutputTextDeltaUpdate outputTextDeltaUpdate:
210221
_ = outputIndexToMessages.TryGetValue(outputTextDeltaUpdate.OutputIndex, out MessageResponseItem? messageItem);
211222
lastMessageId = messageItem?.Id;
212223
lastRole = ToChatRole(messageItem?.Role);
213224
yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta)
214225
{
226+
ConversationId = responseId,
215227
CreatedAt = createdAt,
216228
MessageId = lastMessageId,
217229
ModelId = modelId,
230+
RawRepresentation = streamingUpdate,
218231
ResponseId = responseId,
219-
ConversationId = responseId,
220232
};
221233
break;
222234

@@ -227,7 +239,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
227239
_ = (callInfo.Arguments ??= new()).Append(functionCallArgumentsDeltaUpdate.Delta);
228240
}
229241

230-
break;
242+
goto default;
231243
}
232244

233245
case StreamingResponseFunctionCallArgumentsDoneUpdate functionCallOutputDoneUpdate:
@@ -246,25 +258,23 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
246258
lastRole = ChatRole.Assistant;
247259
yield return new ChatResponseUpdate(lastRole, [fci])
248260
{
261+
ConversationId = responseId,
249262
CreatedAt = createdAt,
250263
MessageId = lastMessageId,
251264
ModelId = modelId,
265+
RawRepresentation = streamingUpdate,
252266
ResponseId = responseId,
253-
ConversationId = responseId,
254267
};
268+
269+
break;
255270
}
256271

257-
break;
272+
goto default;
258273
}
259274

260275
case StreamingResponseErrorUpdate errorUpdate:
261276
yield return new ChatResponseUpdate
262277
{
263-
CreatedAt = createdAt,
264-
MessageId = lastMessageId,
265-
ModelId = modelId,
266-
ResponseId = responseId,
267-
Role = lastRole,
268278
ConversationId = responseId,
269279
Contents =
270280
[
@@ -274,6 +284,12 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
274284
Details = errorUpdate.Param,
275285
}
276286
],
287+
CreatedAt = createdAt,
288+
MessageId = lastMessageId,
289+
ModelId = modelId,
290+
RawRepresentation = streamingUpdate,
291+
ResponseId = responseId,
292+
Role = lastRole,
277293
};
278294
break;
279295

@@ -283,12 +299,26 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
283299
CreatedAt = createdAt,
284300
MessageId = lastMessageId,
285301
ModelId = modelId,
302+
RawRepresentation = streamingUpdate,
286303
ResponseId = responseId,
287304
Role = lastRole,
288305
ConversationId = responseId,
289306
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
290307
};
291308
break;
309+
310+
default:
311+
yield return new ChatResponseUpdate
312+
{
313+
ConversationId = responseId,
314+
CreatedAt = createdAt,
315+
MessageId = lastMessageId,
316+
ModelId = modelId,
317+
RawRepresentation = streamingUpdate,
318+
ResponseId = responseId,
319+
Role = lastRole,
320+
};
321+
break;
292322
}
293323
}
294324
}
@@ -487,6 +517,10 @@ private static IEnumerable<ResponseItem> ToOpenAIResponseItems(
487517
callContent.Arguments,
488518
AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object?>)))));
489519
break;
520+
521+
case AIContent when item.RawRepresentation is ResponseItem rawRep:
522+
yield return rawRep;
523+
break;
490524
}
491525
}
492526

@@ -530,11 +564,25 @@ private static List<AIContent> ToAIContents(IEnumerable<ResponseContentPart> con
530564
switch (part.Kind)
531565
{
532566
case ResponseContentPartKind.OutputText:
533-
results.Add(new TextContent(part.Text));
567+
results.Add(new TextContent(part.Text)
568+
{
569+
RawRepresentation = part,
570+
});
534571
break;
535572

536573
case ResponseContentPartKind.Refusal:
537-
results.Add(new ErrorContent(part.Refusal) { ErrorCode = nameof(ResponseContentPartKind.Refusal) });
574+
results.Add(new ErrorContent(part.Refusal)
575+
{
576+
ErrorCode = nameof(ResponseContentPartKind.Refusal),
577+
RawRepresentation = part,
578+
});
579+
break;
580+
581+
default:
582+
results.Add(new()
583+
{
584+
RawRepresentation = part,
585+
});
538586
break;
539587
}
540588
}
@@ -570,6 +618,10 @@ private static List<ResponseContentPart> ToOpenAIResponsesContent(IList<AIConten
570618
case ErrorContent errorContent when errorContent.ErrorCode == nameof(ResponseContentPartKind.Refusal):
571619
parts.Add(ResponseContentPart.CreateRefusalPart(errorContent.Message));
572620
break;
621+
622+
case AIContent when content.RawRepresentation is ResponseContentPart rawRep:
623+
parts.Add(rawRep);
624+
break;
573625
}
574626
}
575627

test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientTests.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,19 +264,24 @@ public async Task BasicRequestResponse_Streaming()
264264
Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text)));
265265

266266
var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_741_892_091);
267-
Assert.Equal(10, updates.Count);
267+
Assert.Equal(17, updates.Count);
268+
268269
for (int i = 0; i < updates.Count; i++)
269270
{
270271
Assert.Equal("resp_67d329fbc87c81919f8952fe71dafc96029dabe3ee19bb77", updates[i].ResponseId);
271272
Assert.Equal("resp_67d329fbc87c81919f8952fe71dafc96029dabe3ee19bb77", updates[i].ConversationId);
272273
Assert.Equal(createdAt, updates[i].CreatedAt);
273274
Assert.Equal("gpt-4o-mini-2024-07-18", updates[i].ModelId);
274-
Assert.Equal(ChatRole.Assistant, updates[i].Role);
275275
Assert.Null(updates[i].AdditionalProperties);
276-
Assert.Equal(i == 10 ? 0 : 1, updates[i].Contents.Count);
276+
Assert.Equal((i >= 4 && i <= 12) || i == 16 ? 1 : 0, updates[i].Contents.Count);
277277
Assert.Equal(i < updates.Count - 1 ? null : ChatFinishReason.Stop, updates[i].FinishReason);
278278
}
279279

280+
for (int i = 4; i < updates.Count; i++)
281+
{
282+
Assert.Equal(ChatRole.Assistant, updates[i].Role);
283+
}
284+
280285
UsageContent usage = updates.SelectMany(u => u.Contents).OfType<UsageContent>().Single();
281286
Assert.Equal(26, usage.Details.InputTokenCount);
282287
Assert.Equal(10, usage.Details.OutputTokenCount);

0 commit comments

Comments
 (0)