Skip to content

Commit 5705975

Browse files
authored
Merge pull request #251 from MacPaw/70-how-to-cancel-closure-openaichatsstreamquery-query
Implement cancellation by utilizing native URLSession functionality
2 parents c1a8a0b + fa7ce71 commit 5705975

33 files changed

+1457
-940
lines changed

Demo/DemoChat/Sources/ImageStore.swift

+31-8
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,50 @@
77

88
import Foundation
99
import OpenAI
10+
import Combine
1011

1112
public final class ImageStore: ObservableObject {
1213
public var openAIClient: OpenAIProtocol
1314

15+
@Published var imagesQueryInProgress = false
1416
@Published var images: [ImagesResult.Image] = []
1517

18+
private var subscription: AnyCancellable?
19+
1620
public init(
1721
openAIClient: OpenAIProtocol
1822
) {
1923
self.openAIClient = openAIClient
2024
}
2125

2226
@MainActor
23-
func images(query: ImagesQuery) async {
27+
func images(query: ImagesQuery) {
28+
imagesQueryInProgress = true
2429
images.removeAll()
25-
do {
26-
let response = try await openAIClient.images(query: query)
27-
images = response.data
28-
} catch {
29-
// TODO: Better error handling
30-
print(error.localizedDescription)
31-
}
30+
31+
subscription = openAIClient
32+
.images(query: query)
33+
.receive(on: DispatchQueue.main)
34+
.handleEvents(receiveCancel: {
35+
self.imagesQueryInProgress = false
36+
})
37+
.sink(receiveCompletion: { completion in
38+
self.imagesQueryInProgress = false
39+
40+
switch completion {
41+
case .finished:
42+
break
43+
case .failure(let failure):
44+
// TODO: Better error handling
45+
print(failure)
46+
}
47+
}, receiveValue: { imagesResult in
48+
self.images = imagesResult.data
49+
})
50+
}
51+
52+
func cancelImagesQuery() {
53+
subscription?.cancel()
54+
subscription = nil
3255
}
3356
}

Demo/DemoChat/Sources/UI/ChatView.swift

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ public struct ChatView: View {
1414

1515
@Environment(\.dateProviderValue) var dateProvider
1616
@Environment(\.idProviderValue) var idProvider
17+
18+
@State private var sendMessageTask: Task<Void, Never>?
1719

1820
public init(store: ChatStore, assistantStore: AssistantStore) {
1921
self.store = store
@@ -52,7 +54,7 @@ public struct ChatView: View {
5254
availableAssistants: assistantStore.availableAssistants, conversation: conversation,
5355
error: store.conversationErrors[conversation.id],
5456
sendMessage: { message, selectedModel in
55-
Task {
57+
self.sendMessageTask = Task {
5658
await store.sendMessage(
5759
Message(
5860
id: idProvider(),
@@ -68,6 +70,9 @@ public struct ChatView: View {
6870
)
6971
}
7072
}
73+
}.onDisappear {
74+
// It may not produce an ideal behavior, but it's here for demonstrating that cancellation works as expected
75+
sendMessageTask?.cancel()
7176
}
7277
}
7378
}

Demo/DemoChat/Sources/UI/Images/ImageCreationView.swift

+14-6
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,7 @@ public struct ImageCreationView: View {
4444
}
4545
Section {
4646
HStack {
47-
Button("Create Image" + (n == 1 ? "" : "s")) {
48-
Task {
49-
let query = ImagesQuery(prompt: prompt, n: n, size: size)
50-
await store.images(query: query)
51-
}
52-
}
47+
actionButton
5348
.foregroundColor(.accentColor)
5449
Spacer()
5550
}
@@ -72,4 +67,17 @@ public struct ImageCreationView: View {
7267
.listStyle(.insetGrouped)
7368
.navigationTitle("Create Image")
7469
}
70+
71+
private var actionButton: some View {
72+
if store.imagesQueryInProgress {
73+
Button("Cancel") {
74+
store.cancelImagesQuery()
75+
}
76+
} else {
77+
Button("Create Image" + (n == 1 ? "" : "s")) {
78+
let query = ImagesQuery(prompt: prompt, n: n, size: size)
79+
store.images(query: query)
80+
}
81+
}
82+
}
7583
}

Demo/DemoChat/Sources/UI/Misc/ListModelsView.swift

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@ public struct ListModelsView: View {
1818
.listStyle(.insetGrouped)
1919
.navigationTitle("Models")
2020
}
21-
.onAppear {
22-
Task {
23-
await store.getModels()
24-
}
21+
.task {
22+
await store.getModels()
2523
}
2624
}
2725
}

README.md

+34
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ This repository contains Swift community-maintained implementation over [OpenAI]
4949
- [Submit Tool Outputs for Run](#submit-tool-outputs-for-run)
5050
- [Files](#files)
5151
- [Upload File](#upload-file)
52+
- [Cancelling requests](#cancelling-requests)
5253
- [Example Project](#example-project)
5354
- [Contribution Guidelines](#contribution-guidelines)
5455
- [Links](#links)
@@ -1043,6 +1044,39 @@ openAI.files(query: query) { result in
10431044
}
10441045
```
10451046

1047+
### Cancelling requests
1048+
#### Closure based API
1049+
When you call any of the closure-based API methods, it returns discardable `CancellableRequest`. Hold a reference to it to be able to cancel the request later.
1050+
```swift
1051+
let cancellableRequest = object.chats(query: query, completion: { _ in })
1052+
cancellableReques
1053+
```
1054+
1055+
#### Swift Concurrency
1056+
For Swift Concurrency calls, you can simply cancel the calling task, and corresponding `URLSessionDataTask` would get cancelled automatically.
1057+
```swift
1058+
let task = Task {
1059+
do {
1060+
let chatResult = try await openAIClient.chats(query: .init(messages: [], model: "asd"))
1061+
} catch {
1062+
// Handle cancellation or error
1063+
}
1064+
}
1065+
1066+
task.cancel()
1067+
```
1068+
1069+
#### Combine
1070+
In Combine, use a default cancellation mechanism. Just discard the reference to a subscription, or call `cancel()` on it.
1071+
1072+
```swift
1073+
let subscription = openAIClient
1074+
.images(query: query)
1075+
.sink(receiveCompletion: { completion in }, receiveValue: { imagesResult in })
1076+
1077+
subscription.cancel()
1078+
```
1079+
10461080
## Example Project
10471081

10481082
You can find example iOS application in [Demo](/Demo) folder.
+218
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
//
2+
// OpenAI+OpenAIAsync.swift
3+
// OpenAI
4+
//
5+
// Created by Oleksii Nezhyborets on 31.01.2025.
6+
//
7+
8+
import Foundation
9+
10+
@available(iOS 13.0, macOS 10.15, tvOS 13.0, watchOS 6.0, *)
11+
extension OpenAI: OpenAIAsync {
12+
public func images(query: ImagesQuery) async throws -> ImagesResult {
13+
try await performRequestAsync(
14+
request: makeImagesRequest(query: query)
15+
)
16+
}
17+
18+
public func imageEdits(query: ImageEditsQuery) async throws -> ImagesResult {
19+
try await performRequestAsync(
20+
request: makeImageEditsRequest(query: query)
21+
)
22+
}
23+
24+
public func imageVariations(query: ImageVariationsQuery) async throws -> ImagesResult {
25+
try await performRequestAsync(
26+
request: makeImageVariationsRequest(query: query)
27+
)
28+
}
29+
30+
public func embeddings(query: EmbeddingsQuery) async throws -> EmbeddingsResult {
31+
try await performRequestAsync(
32+
request: makeEmbeddingsRequest(query: query)
33+
)
34+
}
35+
36+
public func chats(query: ChatQuery) async throws -> ChatResult {
37+
try await performRequestAsync(
38+
request: makeChatsRequest(query: query)
39+
)
40+
}
41+
42+
public func chatsStream(query: ChatQuery) -> AsyncThrowingStream<ChatStreamResult, Error> {
43+
return AsyncThrowingStream { continuation in
44+
let cancellableRequest = chatsStream(query: query) { result in
45+
continuation.yield(with: result)
46+
} completion: { error in
47+
continuation.finish(throwing: error)
48+
}
49+
50+
continuation.onTermination = { termination in
51+
switch termination {
52+
case .cancelled:
53+
cancellableRequest.cancelRequest()
54+
case .finished:
55+
break
56+
@unknown default:
57+
break
58+
}
59+
}
60+
}
61+
}
62+
63+
public func model(query: ModelQuery) async throws -> ModelResult {
64+
try await performRequestAsync(
65+
request: makeModelRequest(query: query)
66+
)
67+
}
68+
69+
public func models() async throws -> ModelsResult {
70+
try await performRequestAsync(
71+
request: makeModelsRequest()
72+
)
73+
}
74+
75+
public func moderations(query: ModerationsQuery) async throws -> ModerationsResult {
76+
try await performRequestAsync(
77+
request: makeModerationsRequest(query: query)
78+
)
79+
}
80+
81+
public func audioCreateSpeech(query: AudioSpeechQuery) async throws -> AudioSpeechResult {
82+
try await performRequestAsync(
83+
request: makeAudioCreateSpeechRequest(query: query)
84+
)
85+
}
86+
87+
public func audioTranscriptions(query: AudioTranscriptionQuery) async throws -> AudioTranscriptionResult {
88+
try await performRequestAsync(
89+
request: makeAudioTranscriptionsRequest(query: query)
90+
)
91+
}
92+
93+
public func audioTranslations(query: AudioTranslationQuery) async throws -> AudioTranslationResult {
94+
try await performRequestAsync(
95+
request: makeAudioTranslationsRequest(query: query)
96+
)
97+
}
98+
99+
public func assistants() async throws -> AssistantsResult {
100+
try await assistants(after: nil)
101+
}
102+
103+
public func assistants(after: String?) async throws -> AssistantsResult {
104+
try await performRequestAsync(
105+
request: makeAssistantsRequest(after)
106+
)
107+
}
108+
109+
public func assistantCreate(query: AssistantsQuery) async throws -> AssistantResult {
110+
try await performRequestAsync(
111+
request: makeAssistantCreateRequest(query)
112+
)
113+
}
114+
115+
public func assistantModify(query: AssistantsQuery, assistantId: String) async throws -> AssistantResult {
116+
try await performRequestAsync(
117+
request: makeAssistantModifyRequest(assistantId, query)
118+
)
119+
}
120+
121+
public func threads(query: ThreadsQuery) async throws -> ThreadsResult {
122+
try await performRequestAsync(
123+
request: makeThreadsRequest(query)
124+
)
125+
}
126+
127+
public func threadRun(query: ThreadRunQuery) async throws -> RunResult {
128+
try await performRequestAsync(
129+
request: makeThreadRunRequest(query)
130+
)
131+
}
132+
133+
public func runs(threadId: String,query: RunsQuery) async throws -> RunResult {
134+
try await performRequestAsync(
135+
request: makeRunsRequest(threadId, query)
136+
)
137+
}
138+
139+
public func runRetrieve(threadId: String, runId: String) async throws -> RunResult {
140+
try await performRequestAsync(
141+
request: makeRunRetrieveRequest(threadId, runId)
142+
)
143+
}
144+
145+
public func runRetrieveSteps(threadId: String, runId: String) async throws -> RunRetrieveStepsResult {
146+
try await runRetrieveSteps(threadId: threadId, runId: runId, before: nil)
147+
}
148+
149+
public func runRetrieveSteps(threadId: String, runId: String, before: String?) async throws -> RunRetrieveStepsResult {
150+
try await performRequestAsync(
151+
request: makeRunRetrieveStepsRequest(threadId, runId, before)
152+
)
153+
}
154+
155+
public func runSubmitToolOutputs(threadId: String, runId: String, query: RunToolOutputsQuery) async throws -> RunResult {
156+
try await performRequestAsync(
157+
request: makeRunSubmitToolOutputsRequest(threadId, runId, query)
158+
)
159+
}
160+
161+
public func threadsMessages(threadId: String) async throws -> ThreadsMessagesResult {
162+
try await performRequestAsync(
163+
request: makeThreadsMessagesRequest(threadId, before: nil)
164+
)
165+
}
166+
167+
public func threadsMessages(threadId: String, before: String?) async throws -> ThreadsMessagesResult {
168+
try await performRequestAsync(
169+
request: makeThreadsMessagesRequest(threadId, before: before)
170+
)
171+
}
172+
173+
public func threadsAddMessage(threadId: String, query: MessageQuery) async throws -> ThreadAddMessageResult {
174+
try await performRequestAsync(
175+
request: makeThreadsAddMessageRequest(threadId, query)
176+
)
177+
}
178+
179+
public func files(query: FilesQuery) async throws -> FilesResult {
180+
try await performRequestAsync(
181+
request: makeFilesRequest(query: query)
182+
)
183+
}
184+
185+
func performRequestAsync<ResultType: Codable>(request: any URLRequestBuildable) async throws -> ResultType {
186+
let urlRequest = try request.build(token: configuration.token,
187+
organizationIdentifier: configuration.organizationIdentifier,
188+
timeoutInterval: configuration.timeoutInterval)
189+
if #available(iOS 15.0, macOS 12.0, tvOS 15.0, watchOS 8.0, *) {
190+
let (data, _) = try await session.data(for: urlRequest, delegate: nil)
191+
let decoder = JSONDecoder()
192+
do {
193+
return try decoder.decode(ResultType.self, from: data)
194+
} catch {
195+
throw (try? decoder.decode(APIErrorResponse.self, from: data)) ?? error
196+
}
197+
} else {
198+
let dataTaskStore = URLSessionDataTaskStore()
199+
return try await withTaskCancellationHandler {
200+
return try await withCheckedThrowingContinuation { continuation in
201+
let dataTask = self.makeDataTask(forRequest: urlRequest) { (result: Result<ResultType, Error>) in
202+
continuation.resume(with: result)
203+
}
204+
205+
dataTask.resume()
206+
207+
Task {
208+
await dataTaskStore.setDataTask(dataTask)
209+
}
210+
}
211+
} onCancel: {
212+
Task {
213+
await dataTaskStore.getDataTask()?.cancel()
214+
}
215+
}
216+
}
217+
}
218+
}

0 commit comments

Comments
 (0)