Skip to content

Commit 190c355

Browse files
authored
Merge pull request #266 from MacPaw/bug/261-streaming-does-not-honor-http-status-codes
Bug: Fix error parsing in StreamInterpreter
2 parents 762d8ea + 3e64d91 commit 190c355

File tree

3 files changed

+100
-52
lines changed

3 files changed

+100
-52
lines changed

Demo/DemoChat/Sources/ChatStore.swift

+82-52
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Combine
1010
import OpenAI
1111
import SwiftUI
1212

13+
@MainActor
1314
public final class ChatStore: ObservableObject {
1415
public var openAIClient: OpenAIProtocol
1516
let idProvider: () -> String
@@ -61,8 +62,7 @@ public final class ChatStore: ObservableObject {
6162
func deleteConversation(_ conversationId: Conversation.ID) {
6263
conversations.removeAll(where: { $0.id == conversationId })
6364
}
64-
65-
@MainActor
65+
6666
func sendMessage(
6767
_ message: Message,
6868
conversationId: Conversation.ID,
@@ -78,7 +78,8 @@ public final class ChatStore: ObservableObject {
7878

7979
await completeChat(
8080
conversationId: conversationId,
81-
model: model
81+
model: model,
82+
stream: true
8283
)
8384
// For assistant case we send chats to thread and then poll, polling will receive sent chat + new assistant messages.
8485
case .assistant:
@@ -139,11 +140,11 @@ public final class ChatStore: ObservableObject {
139140
}
140141
}
141142
}
142-
143-
@MainActor
143+
144144
func completeChat(
145145
conversationId: Conversation.ID,
146-
model: Model
146+
model: Model,
147+
stream: Bool
147148
) async {
148149
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
149150
return
@@ -169,59 +170,88 @@ public final class ChatStore: ObservableObject {
169170
))
170171

171172
let functions = [weatherFunction]
172-
173-
let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
174-
query: ChatQuery(
175-
messages: conversation.messages.map { message in
176-
ChatQuery.ChatCompletionMessageParam(role: message.role, content: message.content)!
177-
}, model: model,
178-
tools: functions
179-
)
173+
174+
let chatQuery = ChatQuery(
175+
messages: conversation.messages.map { message in
176+
ChatQuery.ChatCompletionMessageParam(role: message.role, content: message.content)!
177+
}, model: model,
178+
tools: functions
180179
)
181-
182-
var functionCalls = [(name: String, argument: String?)]()
183-
for try await partialChatResult in chatsStream {
184-
for choice in partialChatResult.choices {
185-
let existingMessages = conversations[conversationIndex].messages
186-
// Function calls are also streamed, so we need to accumulate.
187-
choice.delta.toolCalls?.forEach { toolCallDelta in
188-
if let functionCallDelta = toolCallDelta.function {
189-
if let nameDelta = functionCallDelta.name {
190-
functionCalls.append((nameDelta, functionCallDelta.arguments))
191-
}
180+
181+
if stream {
182+
try await completeConversationStreaming(
183+
conversationIndex: conversationIndex,
184+
model: model,
185+
query: chatQuery
186+
)
187+
} else {
188+
try await completeConversation(conversationIndex: conversationIndex, model: model, query: chatQuery)
189+
}
190+
} catch {
191+
conversationErrors[conversationId] = error
192+
}
193+
}
194+
195+
private func completeConversation(conversationIndex: Int, model: Model, query: ChatQuery) async throws {
196+
let chatResult: ChatResult = try await openAIClient.chats(query: query)
197+
chatResult.choices
198+
.map {
199+
Message(
200+
id: chatResult.id,
201+
role: $0.message.role,
202+
content: $0.message.content?.string ?? "",
203+
createdAt: Date(timeIntervalSince1970: TimeInterval(chatResult.created))
204+
)
205+
}.forEach { message in
206+
conversations[conversationIndex].messages.append(message)
207+
}
208+
}
209+
210+
private func completeConversationStreaming(conversationIndex: Int, model: Model, query: ChatQuery) async throws {
211+
let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
212+
query: query
213+
)
214+
215+
var functionCalls = [(name: String, argument: String?)]()
216+
for try await partialChatResult in chatsStream {
217+
for choice in partialChatResult.choices {
218+
let existingMessages = conversations[conversationIndex].messages
219+
// Function calls are also streamed, so we need to accumulate.
220+
choice.delta.toolCalls?.forEach { toolCallDelta in
221+
if let functionCallDelta = toolCallDelta.function {
222+
if let nameDelta = functionCallDelta.name {
223+
functionCalls.append((nameDelta, functionCallDelta.arguments))
192224
}
193225
}
194-
var messageText = choice.delta.content ?? ""
195-
if let finishReason = choice.finishReason,
196-
finishReason == .toolCalls
197-
{
198-
functionCalls.forEach { (name: String, argument: String?) in
199-
messageText += "Function call: name=\(name) arguments=\(argument ?? "")\n"
200-
}
226+
}
227+
var messageText = choice.delta.content ?? ""
228+
if let finishReason = choice.finishReason,
229+
finishReason == .toolCalls
230+
{
231+
functionCalls.forEach { (name: String, argument: String?) in
232+
messageText += "Function call: name=\(name) arguments=\(argument ?? "")\n"
201233
}
202-
let message = Message(
203-
id: partialChatResult.id,
204-
role: choice.delta.role ?? .assistant,
205-
content: messageText,
206-
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
234+
}
235+
let message = Message(
236+
id: partialChatResult.id,
237+
role: choice.delta.role ?? .assistant,
238+
content: messageText,
239+
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
240+
)
241+
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
242+
// Meld into previous message
243+
let previousMessage = existingMessages[existingMessageIndex]
244+
let combinedMessage = Message(
245+
id: message.id, // id stays the same for different deltas
246+
role: message.role,
247+
content: previousMessage.content + message.content,
248+
createdAt: message.createdAt
207249
)
208-
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
209-
// Meld into previous message
210-
let previousMessage = existingMessages[existingMessageIndex]
211-
let combinedMessage = Message(
212-
id: message.id, // id stays the same for different deltas
213-
role: message.role,
214-
content: previousMessage.content + message.content,
215-
createdAt: message.createdAt
216-
)
217-
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
218-
} else {
219-
conversations[conversationIndex].messages.append(message)
220-
}
250+
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
251+
} else {
252+
conversations[conversationIndex].messages.append(message)
221253
}
222254
}
223-
} catch {
224-
conversationErrors[conversationId] = error
225255
}
226256
}
227257

Sources/OpenAI/Private/StreamInterpreter.swift

+5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ class StreamInterpreter<ResultType: Codable> {
1616
var onEventDispatched: ((ResultType) -> Void)?
1717

1818
func processData(_ data: Data) throws {
19+
let decoder = JSONDecoder()
20+
if let decoded = try? decoder.decode(APIErrorResponse.self, from: data) {
21+
throw decoded
22+
}
23+
1924
guard let stringContent = String(data: data, encoding: .utf8) else {
2025
throw StreamingError.unknownContent
2126
}

Tests/OpenAITests/StreamInterpreterTests.swift

+13
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ struct StreamInterpreterTests {
3232
#expect(chatStreamResults.count == 1)
3333
}
3434

35+
@Test func parseApiError() throws {
36+
do {
37+
try interpreter.processData(chatCompletionError())
38+
} catch {
39+
#expect(error is APIErrorResponse)
40+
}
41+
}
42+
3543
// Chunk with 3 objects. I captured it from a real response. It's a very short response that contains just "Hi"
3644
private func chatCompletionChunk() -> Data {
3745
"data: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-AwnboO5ZnaUyii9xxC5ZVmM5vGark\",\"object\":\"chat.completion.chunk\",\"created\":1738577084,\"model\":\"gpt-4-0613\",\"service_tier\":\"default\",\"system_fingerprint\":null,\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}]}\n\n".data(using: .utf8)!
@@ -44,4 +52,9 @@ struct StreamInterpreterTests {
4452
private func chatCompletionChunkTermination() -> Data {
4553
"data: [DONE]\n\n".data(using: .utf8)!
4654
}
55+
56+
// Copied from an actual response that was an input to the interpreter
57+
private func chatCompletionError() -> Data {
58+
"{\n \"error\": {\n \"message\": \"The model `o3-mini` does not exist or you do not have access to it.\",\n \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": \"model_not_found\"\n }\n}\n".data(using: .utf8)!
59+
}
4760
}

0 commit comments

Comments
 (0)