@@ -10,6 +10,7 @@ import Combine
10
10
import OpenAI
11
11
import SwiftUI
12
12
13
+ @MainActor
13
14
public final class ChatStore : ObservableObject {
14
15
public var openAIClient : OpenAIProtocol
15
16
let idProvider : ( ) -> String
@@ -61,8 +62,7 @@ public final class ChatStore: ObservableObject {
61
62
func deleteConversation( _ conversationId: Conversation . ID ) {
62
63
conversations. removeAll ( where: { $0. id == conversationId } )
63
64
}
64
-
65
- @MainActor
65
+
66
66
func sendMessage(
67
67
_ message: Message ,
68
68
conversationId: Conversation . ID ,
@@ -78,7 +78,8 @@ public final class ChatStore: ObservableObject {
78
78
79
79
await completeChat (
80
80
conversationId: conversationId,
81
- model: model
81
+ model: model,
82
+ stream: true
82
83
)
83
84
// For assistant case we send chats to thread and then poll, polling will receive sent chat + new assistant messages.
84
85
case . assistant:
@@ -139,11 +140,11 @@ public final class ChatStore: ObservableObject {
139
140
}
140
141
}
141
142
}
142
-
143
- @MainActor
143
+
144
144
func completeChat(
145
145
conversationId: Conversation . ID ,
146
- model: Model
146
+ model: Model ,
147
+ stream: Bool
147
148
) async {
148
149
guard let conversation = conversations. first ( where: { $0. id == conversationId } ) else {
149
150
return
@@ -169,59 +170,88 @@ public final class ChatStore: ObservableObject {
169
170
) )
170
171
171
172
let functions = [ weatherFunction]
172
-
173
- let chatsStream : AsyncThrowingStream < ChatStreamResult , Error > = openAIClient. chatsStream (
174
- query: ChatQuery (
175
- messages: conversation. messages. map { message in
176
- ChatQuery . ChatCompletionMessageParam ( role: message. role, content: message. content) !
177
- } , model: model,
178
- tools: functions
179
- )
173
+
174
+ let chatQuery = ChatQuery (
175
+ messages: conversation. messages. map { message in
176
+ ChatQuery . ChatCompletionMessageParam ( role: message. role, content: message. content) !
177
+ } , model: model,
178
+ tools: functions
180
179
)
181
-
182
- var functionCalls = [ ( name: String, argument: String? ) ] ( )
183
- for try await partialChatResult in chatsStream {
184
- for choice in partialChatResult. choices {
185
- let existingMessages = conversations [ conversationIndex] . messages
186
- // Function calls are also streamed, so we need to accumulate.
187
- choice. delta. toolCalls? . forEach { toolCallDelta in
188
- if let functionCallDelta = toolCallDelta. function {
189
- if let nameDelta = functionCallDelta. name {
190
- functionCalls. append ( ( nameDelta, functionCallDelta. arguments) )
191
- }
180
+
181
+ if stream {
182
+ try await completeConversationStreaming (
183
+ conversationIndex: conversationIndex,
184
+ model: model,
185
+ query: chatQuery
186
+ )
187
+ } else {
188
+ try await completeConversation ( conversationIndex: conversationIndex, model: model, query: chatQuery)
189
+ }
190
+ } catch {
191
+ conversationErrors [ conversationId] = error
192
+ }
193
+ }
194
+
195
+ private func completeConversation( conversationIndex: Int , model: Model , query: ChatQuery ) async throws {
196
+ let chatResult : ChatResult = try await openAIClient. chats ( query: query)
197
+ chatResult. choices
198
+ . map {
199
+ Message (
200
+ id: chatResult. id,
201
+ role: $0. message. role,
202
+ content: $0. message. content? . string ?? " " ,
203
+ createdAt: Date ( timeIntervalSince1970: TimeInterval ( chatResult. created) )
204
+ )
205
+ } . forEach { message in
206
+ conversations [ conversationIndex] . messages. append ( message)
207
+ }
208
+ }
209
+
210
+ private func completeConversationStreaming( conversationIndex: Int , model: Model , query: ChatQuery ) async throws {
211
+ let chatsStream : AsyncThrowingStream < ChatStreamResult , Error > = openAIClient. chatsStream (
212
+ query: query
213
+ )
214
+
215
+ var functionCalls = [ ( name: String, argument: String? ) ] ( )
216
+ for try await partialChatResult in chatsStream {
217
+ for choice in partialChatResult. choices {
218
+ let existingMessages = conversations [ conversationIndex] . messages
219
+ // Function calls are also streamed, so we need to accumulate.
220
+ choice. delta. toolCalls? . forEach { toolCallDelta in
221
+ if let functionCallDelta = toolCallDelta. function {
222
+ if let nameDelta = functionCallDelta. name {
223
+ functionCalls. append ( ( nameDelta, functionCallDelta. arguments) )
192
224
}
193
225
}
194
- var messageText = choice . delta . content ?? " "
195
- if let finishReason = choice. finishReason ,
196
- finishReason == . toolCalls
197
- {
198
- functionCalls . forEach { ( name : String , argument : String ? ) in
199
- messageText += " Function call: name= \ ( name) arguments= \( argument ?? " " ) \n "
200
- }
226
+ }
227
+ var messageText = choice. delta . content ?? " "
228
+ if let finishReason = choice . finishReason ,
229
+ finishReason == . toolCalls
230
+ {
231
+ functionCalls . forEach { ( name: String , argument: String ? ) in
232
+ messageText += " Function call: name= \( name ) arguments= \( argument ?? " " ) \n "
201
233
}
202
- let message = Message (
203
- id: partialChatResult. id,
204
- role: choice. delta. role ?? . assistant,
205
- content: messageText,
206
- createdAt: Date ( timeIntervalSince1970: TimeInterval ( partialChatResult. created) )
234
+ }
235
+ let message = Message (
236
+ id: partialChatResult. id,
237
+ role: choice. delta. role ?? . assistant,
238
+ content: messageText,
239
+ createdAt: Date ( timeIntervalSince1970: TimeInterval ( partialChatResult. created) )
240
+ )
241
+ if let existingMessageIndex = existingMessages. firstIndex ( where: { $0. id == partialChatResult. id } ) {
242
+ // Meld into previous message
243
+ let previousMessage = existingMessages [ existingMessageIndex]
244
+ let combinedMessage = Message (
245
+ id: message. id, // id stays the same for different deltas
246
+ role: message. role,
247
+ content: previousMessage. content + message. content,
248
+ createdAt: message. createdAt
207
249
)
208
- if let existingMessageIndex = existingMessages. firstIndex ( where: { $0. id == partialChatResult. id } ) {
209
- // Meld into previous message
210
- let previousMessage = existingMessages [ existingMessageIndex]
211
- let combinedMessage = Message (
212
- id: message. id, // id stays the same for different deltas
213
- role: message. role,
214
- content: previousMessage. content + message. content,
215
- createdAt: message. createdAt
216
- )
217
- conversations [ conversationIndex] . messages [ existingMessageIndex] = combinedMessage
218
- } else {
219
- conversations [ conversationIndex] . messages. append ( message)
220
- }
250
+ conversations [ conversationIndex] . messages [ existingMessageIndex] = combinedMessage
251
+ } else {
252
+ conversations [ conversationIndex] . messages. append ( message)
221
253
}
222
254
}
223
- } catch {
224
- conversationErrors [ conversationId] = error
225
255
}
226
256
}
227
257
0 commit comments