Skip to content

Implement cancellation by utilizing native URLSession functionality #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions Demo/DemoChat/Sources/ImageStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,50 @@

import Foundation
import OpenAI
import Combine

public final class ImageStore: ObservableObject {
public var openAIClient: OpenAIProtocol

@Published var imagesQueryInProgress = false
@Published var images: [ImagesResult.Image] = []

private var subscription: AnyCancellable?

public init(
openAIClient: OpenAIProtocol
) {
self.openAIClient = openAIClient
}

@MainActor
func images(query: ImagesQuery) async {
func images(query: ImagesQuery) {
imagesQueryInProgress = true
images.removeAll()
do {
let response = try await openAIClient.images(query: query)
images = response.data
} catch {
// TODO: Better error handling
print(error.localizedDescription)
}

subscription = openAIClient
.images(query: query)
.receive(on: DispatchQueue.main)
.handleEvents(receiveCancel: {
self.imagesQueryInProgress = false
})
.sink(receiveCompletion: { completion in
self.imagesQueryInProgress = false

switch completion {
case .finished:
break
case .failure(let failure):
// TODO: Better error handling
print(failure)
}
}, receiveValue: { imagesResult in
self.images = imagesResult.data
})
}

func cancelImagesQuery() {
subscription?.cancel()
subscription = nil
}
}
7 changes: 6 additions & 1 deletion Demo/DemoChat/Sources/UI/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ public struct ChatView: View {

@Environment(\.dateProviderValue) var dateProvider
@Environment(\.idProviderValue) var idProvider

@State private var sendMessageTask: Task<Void, Never>?

public init(store: ChatStore, assistantStore: AssistantStore) {
self.store = store
Expand Down Expand Up @@ -52,7 +54,7 @@ public struct ChatView: View {
availableAssistants: assistantStore.availableAssistants, conversation: conversation,
error: store.conversationErrors[conversation.id],
sendMessage: { message, selectedModel in
Task {
self.sendMessageTask = Task {
await store.sendMessage(
Message(
id: idProvider(),
Expand All @@ -68,6 +70,9 @@ public struct ChatView: View {
)
}
}
}.onDisappear {
// It may not produce an ideal behavior, but it's here for demonstrating that cancellation works as expected
sendMessageTask?.cancel()
}
}
}
20 changes: 14 additions & 6 deletions Demo/DemoChat/Sources/UI/Images/ImageCreationView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ public struct ImageCreationView: View {
}
Section {
HStack {
Button("Create Image" + (n == 1 ? "" : "s")) {
Task {
let query = ImagesQuery(prompt: prompt, n: n, size: size)
await store.images(query: query)
}
}
actionButton
.foregroundColor(.accentColor)
Spacer()
}
Expand All @@ -72,4 +67,17 @@ public struct ImageCreationView: View {
.listStyle(.insetGrouped)
.navigationTitle("Create Image")
}

private var actionButton: some View {
if store.imagesQueryInProgress {
Button("Cancel") {
store.cancelImagesQuery()
}
} else {
Button("Create Image" + (n == 1 ? "" : "s")) {
let query = ImagesQuery(prompt: prompt, n: n, size: size)
store.images(query: query)
}
}
}
}
6 changes: 2 additions & 4 deletions Demo/DemoChat/Sources/UI/Misc/ListModelsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ public struct ListModelsView: View {
.listStyle(.insetGrouped)
.navigationTitle("Models")
}
.onAppear {
Task {
await store.getModels()
}
.task {
await store.getModels()
}
}
}
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ This repository contains Swift community-maintained implementation over [OpenAI]
- [Submit Tool Outputs for Run](#submit-tool-outputs-for-run)
- [Files](#files)
- [Upload File](#upload-file)
- [Cancelling requests](#cancelling-requests)
- [Example Project](#example-project)
- [Contribution Guidelines](#contribution-guidelines)
- [Links](#links)
Expand Down Expand Up @@ -1043,6 +1044,39 @@ openAI.files(query: query) { result in
}
```

### Cancelling requests
#### Closure based API
When you call any of the closure-based API methods, it returns discardable `CancellableRequest`. Hold a reference to it to be able to cancel the request later.
```swift
let cancellableRequest = object.chats(query: query, completion: { _ in })
cancellableReques
```

#### Swift Concurrency
For Swift Concurrency calls, you can simply cancel the calling task, and corresponding `URLSessionDataTask` would get cancelled automatically.
```swift
let task = Task {
do {
let chatResult = try await openAIClient.chats(query: .init(messages: [], model: "asd"))
} catch {
// Handle cancellation or error
}
}

task.cancel()
```

#### Combine
In Combine, use a default cancellation mechanism. Just discard the reference to a subscription, or call `cancel()` on it.

```swift
let subscription = openAIClient
.images(query: query)
.sink(receiveCompletion: { completion in }, receiveValue: { imagesResult in })

subscription.cancel()
```

## Example Project

You can find example iOS application in [Demo](/Demo) folder.
Expand Down
218 changes: 218 additions & 0 deletions Sources/OpenAI/OpenAI+OpenAIAsync.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
//
// OpenAI+OpenAIAsync.swift
// OpenAI
//
// Created by Oleksii Nezhyborets on 31.01.2025.
//

import Foundation

@available(iOS 13.0, macOS 10.15, tvOS 13.0, watchOS 6.0, *)
extension OpenAI: OpenAIAsync {
public func images(query: ImagesQuery) async throws -> ImagesResult {
try await performRequestAsync(
request: makeImagesRequest(query: query)
)
}

public func imageEdits(query: ImageEditsQuery) async throws -> ImagesResult {
try await performRequestAsync(
request: makeImageEditsRequest(query: query)
)
}

public func imageVariations(query: ImageVariationsQuery) async throws -> ImagesResult {
try await performRequestAsync(
request: makeImageVariationsRequest(query: query)
)
}

public func embeddings(query: EmbeddingsQuery) async throws -> EmbeddingsResult {
try await performRequestAsync(
request: makeEmbeddingsRequest(query: query)
)
}

public func chats(query: ChatQuery) async throws -> ChatResult {
try await performRequestAsync(
request: makeChatsRequest(query: query)
)
}

public func chatsStream(query: ChatQuery) -> AsyncThrowingStream<ChatStreamResult, Error> {
return AsyncThrowingStream { continuation in
let cancellableRequest = chatsStream(query: query) { result in
continuation.yield(with: result)
} completion: { error in
continuation.finish(throwing: error)
}

continuation.onTermination = { termination in
switch termination {
case .cancelled:
cancellableRequest.cancelRequest()
case .finished:
break
@unknown default:
break
}
}
}
}

public func model(query: ModelQuery) async throws -> ModelResult {
try await performRequestAsync(
request: makeModelRequest(query: query)
)
}

public func models() async throws -> ModelsResult {
try await performRequestAsync(
request: makeModelsRequest()
)
}

public func moderations(query: ModerationsQuery) async throws -> ModerationsResult {
try await performRequestAsync(
request: makeModerationsRequest(query: query)
)
}

public func audioCreateSpeech(query: AudioSpeechQuery) async throws -> AudioSpeechResult {
try await performRequestAsync(
request: makeAudioCreateSpeechRequest(query: query)
)
}

public func audioTranscriptions(query: AudioTranscriptionQuery) async throws -> AudioTranscriptionResult {
try await performRequestAsync(
request: makeAudioTranscriptionsRequest(query: query)
)
}

public func audioTranslations(query: AudioTranslationQuery) async throws -> AudioTranslationResult {
try await performRequestAsync(
request: makeAudioTranslationsRequest(query: query)
)
}

public func assistants() async throws -> AssistantsResult {
try await assistants(after: nil)
}

public func assistants(after: String?) async throws -> AssistantsResult {
try await performRequestAsync(
request: makeAssistantsRequest(after)
)
}

public func assistantCreate(query: AssistantsQuery) async throws -> AssistantResult {
try await performRequestAsync(
request: makeAssistantCreateRequest(query)
)
}

public func assistantModify(query: AssistantsQuery, assistantId: String) async throws -> AssistantResult {
try await performRequestAsync(
request: makeAssistantModifyRequest(assistantId, query)
)
}

public func threads(query: ThreadsQuery) async throws -> ThreadsResult {
try await performRequestAsync(
request: makeThreadsRequest(query)
)
}

public func threadRun(query: ThreadRunQuery) async throws -> RunResult {
try await performRequestAsync(
request: makeThreadRunRequest(query)
)
}

public func runs(threadId: String,query: RunsQuery) async throws -> RunResult {
try await performRequestAsync(
request: makeRunsRequest(threadId, query)
)
}

public func runRetrieve(threadId: String, runId: String) async throws -> RunResult {
try await performRequestAsync(
request: makeRunRetrieveRequest(threadId, runId)
)
}

public func runRetrieveSteps(threadId: String, runId: String) async throws -> RunRetrieveStepsResult {
try await runRetrieveSteps(threadId: threadId, runId: runId, before: nil)
}

public func runRetrieveSteps(threadId: String, runId: String, before: String?) async throws -> RunRetrieveStepsResult {
try await performRequestAsync(
request: makeRunRetrieveStepsRequest(threadId, runId, before)
)
}

public func runSubmitToolOutputs(threadId: String, runId: String, query: RunToolOutputsQuery) async throws -> RunResult {
try await performRequestAsync(
request: makeRunSubmitToolOutputsRequest(threadId, runId, query)
)
}

public func threadsMessages(threadId: String) async throws -> ThreadsMessagesResult {
try await performRequestAsync(
request: makeThreadsMessagesRequest(threadId, before: nil)
)
}

public func threadsMessages(threadId: String, before: String?) async throws -> ThreadsMessagesResult {
try await performRequestAsync(
request: makeThreadsMessagesRequest(threadId, before: before)
)
}

public func threadsAddMessage(threadId: String, query: MessageQuery) async throws -> ThreadAddMessageResult {
try await performRequestAsync(
request: makeThreadsAddMessageRequest(threadId, query)
)
}

public func files(query: FilesQuery) async throws -> FilesResult {
try await performRequestAsync(
request: makeFilesRequest(query: query)
)
}

func performRequestAsync<ResultType: Codable>(request: any URLRequestBuildable) async throws -> ResultType {
let urlRequest = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
if #available(iOS 15.0, macOS 12.0, tvOS 15.0, watchOS 8.0, *) {
let (data, _) = try await session.data(for: urlRequest, delegate: nil)
let decoder = JSONDecoder()
do {
return try decoder.decode(ResultType.self, from: data)
} catch {
throw (try? decoder.decode(APIErrorResponse.self, from: data)) ?? error
}
} else {
let dataTaskStore = URLSessionDataTaskStore()
return try await withTaskCancellationHandler {
return try await withCheckedThrowingContinuation { continuation in
let dataTask = self.makeDataTask(forRequest: urlRequest) { (result: Result<ResultType, Error>) in
continuation.resume(with: result)
}

dataTask.resume()

Task {
await dataTaskStore.setDataTask(dataTask)
}
}
} onCancel: {
Task {
await dataTaskStore.getDataTask()?.cancel()
}
}
}
}
}
Loading
Loading