Skip to content

Commit 1a7a5ba

Browse files
authored
Merge pull request #191 from batanus/feat/add-ssl-delegate-handler
Add SSLDelegateProtocol to StreamingSession
2 parents 46eacc7 + 8de021d commit 1a7a5ba

11 files changed

+72
-26
lines changed

Sources/OpenAI/OpenAI.swift

+19-12
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,27 @@ final public class OpenAI: @unchecked Sendable {
6767
public let configuration: Configuration
6868

6969
public convenience init(apiToken: String) {
70-
self.init(configuration: Configuration(token: apiToken), session: URLSession.shared)
70+
self.init(configuration: Configuration(token: apiToken), session: URLSession.shared, sslStreamingDelegate: nil)
7171
}
7272

7373
public convenience init(configuration: Configuration) {
74-
self.init(configuration: configuration, session: URLSession.shared)
74+
self.init(configuration: configuration, session: URLSession.shared, sslStreamingDelegate: nil)
75+
}
76+
77+
public convenience init(configuration: Configuration, session: URLSession = URLSession.shared, sslStreamingDelegate: SSLDelegateProtocol? = nil) {
78+
let streamingSessionFactory = ImplicitURLSessionStreamingSessionFactory(sslDelegate: sslStreamingDelegate)
79+
80+
self.init(
81+
configuration: configuration,
82+
session: session,
83+
streamingSessionFactory: streamingSessionFactory
84+
)
7585
}
7686

7787
init(
7888
configuration: Configuration,
7989
session: URLSessionProtocol,
80-
streamingSessionFactory: StreamingSessionFactory = ImplicitURLSessionStreamingSessionFactory(),
90+
streamingSessionFactory: StreamingSessionFactory,
8191
cancellablesFactory: CancellablesFactory = DefaultCancellablesFactory(),
8292
executionSerializer: ExecutionSerializer = GCDQueueAsyncExecutionSerializer(queue: .userInitiated)
8393
) {
@@ -87,13 +97,6 @@ final public class OpenAI: @unchecked Sendable {
8797
self.cancellablesFactory = cancellablesFactory
8898
self.executionSerializer = executionSerializer
8999
}
90-
91-
public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) {
92-
self.init(
93-
configuration: configuration,
94-
session: session as URLSessionProtocol
95-
)
96-
}
97100

98101
public func threadsAddMessage(
99102
threadId: String,
@@ -250,7 +253,7 @@ final public class OpenAI: @unchecked Sendable {
250253
performSpeechRequest(request: makeAudioCreateSpeechRequest(query: query), completion: completion)
251254
}
252255

253-
public func audioCreateSpeechStream(query: AudioSpeechQuery, onResult: @escaping (Result<AudioSpeechResult, Error>) -> Void, completion: ((Error?) -> Void)?) -> CancellableRequest {
256+
public func audioCreateSpeechStream(query: AudioSpeechQuery, onResult: @escaping @Sendable (Result<AudioSpeechResult, Error>) -> Void, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest {
254257
performSpeechStreamingRequest(
255258
request: JSONRequest<AudioSpeechResult>(body: query, url: buildURL(path: .audioSpeech)),
256259
onResult: onResult,
@@ -318,7 +321,11 @@ extension OpenAI {
318321
}
319322
}
320323

321-
func performSpeechStreamingRequest(request: any URLRequestBuildable, onResult: @escaping (Result<AudioSpeechResult, Error>) -> Void, completion: ((Error?) -> Void)?) -> CancellableRequest {
324+
func performSpeechStreamingRequest(
325+
request: any URLRequestBuildable,
326+
onResult: @escaping @Sendable (Result<AudioSpeechResult, Error>) -> Void,
327+
completion: (@Sendable (Error?) -> Void)?
328+
) -> CancellableRequest {
322329
do {
323330
let urlRequest = try request.build(configuration: configuration)
324331

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import Foundation
2+
#if canImport(FoundationNetworking)
3+
import FoundationNetworking
4+
#endif
5+
6+
public protocol SSLDelegateProtocol: Sendable {
7+
func urlSession(
8+
_ session: URLSession,
9+
didReceive challenge: URLAuthenticationChallenge,
10+
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
11+
)
12+
}

Sources/OpenAI/Private/Streaming/ServerSentEventsStreamingSessionFactory.swift

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ protocol StreamingSessionFactory {
2828
}
2929

3030
struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory {
31+
let sslDelegate: SSLDelegateProtocol?
32+
3133
func makeServerSentEventsStreamingSession<ResultType>(
3234
urlRequest: URLRequest,
3335
onReceiveContent: @Sendable @escaping (StreamingSession<ServerSentEventsStreamInterpreter<ResultType>>, ResultType) -> Void,
@@ -37,6 +39,7 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory {
3739
.init(
3840
urlRequest: urlRequest,
3941
interpreter: .init(),
42+
sslDelegate: sslDelegate,
4043
onReceiveContent: onReceiveContent,
4144
onProcessingError: onProcessingError,
4245
onComplete: onComplete
@@ -52,6 +55,7 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory {
5255
.init(
5356
urlRequest: urlRequest,
5457
interpreter: .init(),
58+
sslDelegate: sslDelegate,
5559
onReceiveContent: onReceiveContent,
5660
onProcessingError: onProcessingError,
5761
onComplete: onComplete

Sources/OpenAI/Private/Streaming/StreamingSession.swift

+12
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,21 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
2020
private let onProcessingError: (@Sendable (StreamingSession, Error) -> Void)?
2121
private let onComplete: (@Sendable (StreamingSession, Error?) -> Void)?
2222
private let interpreter: Interpreter
23+
private let sslDelegate: SSLDelegateProtocol?
2324

2425
init(
2526
urlSessionFactory: URLSessionFactory = FoundationURLSessionFactory(),
2627
urlRequest: URLRequest,
2728
interpreter: Interpreter,
29+
sslDelegate: SSLDelegateProtocol?,
2830
onReceiveContent: @escaping @Sendable (StreamingSession, ResultType) -> Void,
2931
onProcessingError: @escaping @Sendable (StreamingSession, Error) -> Void,
3032
onComplete: @escaping @Sendable (StreamingSession, Error?) -> Void
3133
) {
3234
self.urlSessionFactory = urlSessionFactory
3335
self.urlRequest = urlRequest
3436
self.interpreter = interpreter
37+
self.sslDelegate = sslDelegate
3538
self.onReceiveContent = onReceiveContent
3639
self.onProcessingError = onProcessingError
3740
self.onComplete = onComplete
@@ -67,6 +70,15 @@ final class StreamingSession<Interpreter: StreamInterpreter>: NSObject, Identifi
6770
completionHandler(.allow)
6871
}
6972

73+
func urlSession(
74+
_ session: URLSession,
75+
didReceive challenge: URLAuthenticationChallenge,
76+
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
77+
) {
78+
guard let sslDelegate else { return completionHandler(.performDefaultHandling, nil) }
79+
sslDelegate.urlSession(session, didReceive: challenge, completionHandler: completionHandler)
80+
}
81+
7082
private func subscribeToParser() {
7183
interpreter.setCallbackClosures { [weak self] content in
7284
guard let self else { return }

Sources/OpenAI/Private/URLSessionDataDelegateForwarder.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Foundation
1111
import FoundationNetworking
1212
#endif
1313

14-
class URLSessionDataDelegateForwarder: NSObject, URLSessionDataDelegate {
14+
final class URLSessionDataDelegateForwarder: NSObject, URLSessionDataDelegate {
1515
let target: URLSessionDataDelegateProtocol
1616

1717
init(target: URLSessionDataDelegateProtocol) {

Sources/OpenAI/Private/URLSessionDelegateProtocol.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import Foundation
99

10-
protocol URLSessionDelegateProtocol {
10+
protocol URLSessionDelegateProtocol: Sendable { // Sendable to make a better match with URLSessionDelegate, it's sendable too
1111
func urlSession(_ session: URLSessionProtocol, task: URLSessionTaskProtocol, didCompleteWithError error: Error?)
1212
}
1313

Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ public protocol OpenAIProtocol: OpenAIModern {
209209
- onResult: A closure which receives the result when the API request finishes. The closure's parameter, `Result<AudioSpeechResult, Error>`, will contain either the `AudioSpeechResult` object with the generated Audio chunk, or an error if the request failed.
210210
- completion: A closure that is being called when all chunks are delivered or uncrecoverable error occured
211211
*/
212-
func audioCreateSpeechStream(query: AudioSpeechQuery, onResult: @escaping (Result<AudioSpeechResult, Error>) -> Void, completion: ((Error?) -> Void)?) -> CancellableRequest
212+
func audioCreateSpeechStream(query: AudioSpeechQuery, onResult: @escaping @Sendable (Result<AudioSpeechResult, Error>) -> Void, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest
213213

214214
/**
215215
Transcribes audio data using OpenAI's audio transcription API and completes the operation asynchronously.

Tests/OpenAITests/Mocks/MockStreamingSessionFactory.swift

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class MockStreamingSessionFactory: StreamingSessionFactory {
2525
urlSessionFactory: urlSessionFactory,
2626
urlRequest: urlRequest,
2727
interpreter: .init(executionSerializer: NoDispatchExecutionSerializer()),
28+
sslDelegate: nil,
2829
onReceiveContent: onReceiveContent,
2930
onProcessingError: onProcessingError,
3031
onComplete: onComplete
@@ -41,6 +42,7 @@ class MockStreamingSessionFactory: StreamingSessionFactory {
4142
urlSessionFactory: urlSessionFactory,
4243
urlRequest: urlRequest,
4344
interpreter: .init(),
45+
sslDelegate: nil,
4446
onReceiveContent: onReceiveContent,
4547
onProcessingError: onProcessingError,
4648
onComplete: onComplete

Tests/OpenAITests/OpenAITests.swift

+10-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class OpenAITests: XCTestCase {
1616

1717
override func setUp() async throws {
1818
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
19-
self.openAI = OpenAI(configuration: configuration, session: self.urlSession)
19+
self.openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
2020
}
2121

2222
func testImages() async throws {
@@ -440,21 +440,21 @@ class OpenAITests: XCTestCase {
440440

441441
func testDefaultHostURLBuilt() {
442442
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
443-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
443+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
444444
let chatsURL = openAI.buildURL(path: .chats)
445445
XCTAssertEqual(chatsURL, URL(string: "https://api.openai.com:443/v1/chat/completions"))
446446
}
447447

448448
func testDefaultHostURLBuiltWithCustomBasePath() {
449449
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", basePath: "/api/v9527", timeoutInterval: 14)
450-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
450+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
451451
let chatsURL = openAI.buildURL(path: .chats)
452452
XCTAssertEqual(chatsURL, URL(string: "https://api.openai.com:443/api/v9527/chat/completions"))
453453
}
454454

455455
func testCustomURLBuiltWithPredefinedPath() {
456456
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
457-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
457+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
458458
let chatsURL = openAI.buildURL(path: .chats)
459459
XCTAssertEqual(chatsURL, URL(string: "https://my.host.com:443/v1/chat/completions"))
460460
}
@@ -466,7 +466,7 @@ class OpenAITests: XCTestCase {
466466
host: "bizbaz.com",
467467
timeoutInterval: 14
468468
)
469-
let openAI = OpenAI(configuration: configuration, session: URLSessionMock())
469+
let openAI = OpenAI(configuration: configuration, session: URLSessionMock(), streamingSessionFactory: MockStreamingSessionFactory())
470470
XCTAssertEqual(openAI.buildURL(path: "foo"), URL(string: "https://bizbaz.com:443/v1/foo"))
471471
}
472472

@@ -478,7 +478,7 @@ class OpenAITests: XCTestCase {
478478
basePath: "/openai",
479479
timeoutInterval: 14
480480
)
481-
let openAI = OpenAI(configuration: configuration, session: URLSessionMock())
481+
let openAI = OpenAI(configuration: configuration, session: URLSessionMock(), streamingSessionFactory: MockStreamingSessionFactory())
482482
XCTAssertEqual(openAI.buildURL(path: "foo"), URL(string:"https://bizbaz.com:443/openai/foo"))
483483
}
484484

@@ -490,7 +490,7 @@ class OpenAITests: XCTestCase {
490490
basePath: "/openai/",
491491
timeoutInterval: 14
492492
)
493-
let openAI = OpenAI(configuration: configuration, session: URLSessionMock())
493+
let openAI = OpenAI(configuration: configuration, session: URLSessionMock(), streamingSessionFactory: MockStreamingSessionFactory())
494494
XCTAssertEqual(openAI.buildURL(path: "/foo"), URL(string: "https://bizbaz.com:443/openai/foo"))
495495
}
496496

@@ -709,21 +709,21 @@ class OpenAITests: XCTestCase {
709709

710710
func testCustomRunsURLBuilt() {
711711
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
712-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
712+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
713713
let completionsURL = openAI.buildRunsURL(path: APIPath.Assistants.runs.stringValue, threadId: "thread_4321")
714714
XCTAssertEqual(completionsURL, URL(string: "https://my.host.com:443/v1/threads/thread_4321/runs"))
715715
}
716716

717717
func testCustomRunsRetrieveURLBuilt() {
718718
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
719-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
719+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
720720
let completionsURL = openAI.buildRunRetrieveURL(path: APIPath.Assistants.runRetrieve.stringValue, threadId: "thread_4321", runId: "run_1234")
721721
XCTAssertEqual(completionsURL, URL(string: "https://my.host.com:443/v1/threads/thread_4321/runs/run_1234"))
722722
}
723723

724724
func testCustomRunRetrieveStepsURLBuilt() {
725725
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
726-
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
726+
let openAI = OpenAI(configuration: configuration, session: self.urlSession, streamingSessionFactory: MockStreamingSessionFactory())
727727
let completionsURL = openAI.buildRunRetrieveURL(path: APIPath.Assistants.runRetrieveSteps.stringValue, threadId: "thread_4321", runId: "run_1234")
728728
XCTAssertEqual(completionsURL, URL(string: "https://my.host.com:443/v1/threads/thread_4321/runs/run_1234/steps"))
729729
}

Tests/OpenAITests/OpenAITestsCombine.swift

+9-1
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@ import Combine
1616

1717
private var openAI: OpenAIProtocol!
1818
private let urlSession: URLSessionMockCombine = URLSessionMockCombine()
19+
private let streamingSessionFactory = MockStreamingSessionFactory()
1920
private let cancellablesFactory = MockCancellablesFactory()
2021

2122
override func setUp() async throws {
2223
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
23-
self.openAI = OpenAI(configuration: configuration, session: self.urlSession, cancellablesFactory: cancellablesFactory)
24+
self.openAI = OpenAI(
25+
configuration: configuration,
26+
session: self.urlSession,
27+
streamingSessionFactory: streamingSessionFactory,
28+
cancellablesFactory: cancellablesFactory
29+
)
2430
}
2531

2632
func testChats() throws {
@@ -252,13 +258,15 @@ extension OpenAITestsCombine {
252258
func stub(error: URLError) {
253259
let task = DataTaskMock.failed(with: error)
254260
self.urlSession.dataTask = task
261+
self.streamingSessionFactory.urlSessionFactory.urlSession.dataTask = task
255262
}
256263

257264
func stub(result: Codable) throws {
258265
let encoder = JSONEncoder()
259266
let data = try encoder.encode(result)
260267
let task = DataTaskMock.successful(with: data)
261268
self.urlSession.dataTask = task
269+
self.streamingSessionFactory.urlSessionFactory.urlSession.dataTask = task
262270
}
263271
}
264272

Tests/OpenAITests/StreamingSessionTests.swift

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ final class StreamingSessionTests: XCTestCase {
1717
urlSessionFactory: MockURLSessionFactory(),
1818
urlRequest: .init(url: .init(string: "/")!),
1919
interpreter: streamInterpreter,
20+
sslDelegate: nil,
2021
onReceiveContent: { _, _ in
2122
Task {
2223
await MainActor.run {

0 commit comments

Comments
 (0)