Skip to content

Commit b1c34ae

Browse files
authored
Merge pull request #314 from dounan/feat/add-dynamic-json-schema
2 parents 2d2c06f + ff0fe4e commit b1c34ae

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed

Sources/OpenAI/Public/Models/ChatQuery.swift

+55-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//
22
// ChatQuery.swift
3-
//
3+
//
44
//
55
// Created by Sergii Kryvoblotskyi on 02/04/2023.
66
//
@@ -819,10 +819,12 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable {
819819
case text
820820
case jsonObject
821821
case jsonSchema(name: String, type: StructuredOutput.Type)
822+
case dynamicJsonSchema(DynamicJSONSchema)
822823

823824
enum CodingKeys: String, CodingKey {
824825
case type
825826
case jsonSchema = "json_schema"
827+
case dynamicJsonSchema
826828
}
827829

828830
public func encode(to encoder: any Encoder) throws {
@@ -836,6 +838,9 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable {
836838
try container.encode("json_schema", forKey: .type)
837839
let schema = JSONSchema(name: name, schema: type.example)
838840
try container.encode(schema, forKey: .jsonSchema)
841+
case .dynamicJsonSchema(let dynamicJSONSchema):
842+
try container.encode("json_schema", forKey: .type)
843+
try container.encode(dynamicJSONSchema, forKey: .jsonSchema)
839844
}
840845
}
841846

@@ -845,6 +850,8 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable {
845850
case (.jsonObject, .jsonObject): return true
846851
case (.jsonSchema(let lhsName, let lhsType), .jsonSchema(let rhsName, let rhsType)):
847852
return lhsName == rhsName && lhsType == rhsType
853+
case (.dynamicJsonSchema(let lhsSchema), .dynamicJsonSchema(let rhsSchema)):
854+
return lhsSchema == rhsSchema
848855
default:
849856
return false
850857
}
@@ -1072,6 +1079,53 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable {
10721079
}
10731080
}
10741081
}
1082+
1083+
public struct DynamicJSONSchema: Encodable, Sendable, Equatable {
1084+
let name: String
1085+
let description: String?
1086+
let schema: Encodable & Sendable
1087+
let strict: Bool?
1088+
1089+
enum CodingKeys: String, CodingKey {
1090+
case name
1091+
case description
1092+
case schema
1093+
case strict
1094+
}
1095+
1096+
public init(
1097+
name: String,
1098+
description: String? = nil,
1099+
schema: Encodable & Sendable,
1100+
strict: Bool? = nil
1101+
) {
1102+
self.name = name
1103+
self.description = description
1104+
self.schema = schema
1105+
self.strict = strict
1106+
}
1107+
1108+
public func encode(to encoder: any Encoder) throws {
1109+
var container = encoder.container(keyedBy: CodingKeys.self)
1110+
try container.encode(name, forKey: .name)
1111+
if let description {
1112+
try container.encode(description, forKey: .description)
1113+
}
1114+
try container.encode(schema, forKey: .schema)
1115+
if let strict {
1116+
try container.encode(strict, forKey: .strict)
1117+
}
1118+
}
1119+
1120+
public static func == (lhs: DynamicJSONSchema, rhs: DynamicJSONSchema) -> Bool {
1121+
guard lhs.name == rhs.name else { return false }
1122+
guard lhs.description == rhs.description else { return false }
1123+
guard lhs.strict == rhs.strict else { return false }
1124+
let lhsData = try? JSONEncoder().encode(lhs.schema)
1125+
let rhsData = try? JSONEncoder().encode(rhs.schema)
1126+
return lhsData == rhsData
1127+
}
1128+
}
10751129

10761130
public enum ChatCompletionFunctionCallOptionParam: Codable, Equatable, Sendable {
10771131
case none

Tests/OpenAITests/OpenAITests.swift

+65
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,71 @@ class OpenAITests: XCTestCase {
128128
let result = try await openAI.chats(query: query)
129129
XCTAssertEqual(result, chatResult)
130130
}
131+
132+
func testChatQueryWithDynamicStructuredOutput() async throws {
133+
134+
let chatResult = ChatResult(
135+
id: "id-12312", created: 100, model: .gpt3_5Turbo, object: "foo", serviceTier: nil, systemFingerprint: "fing",
136+
choices: [],
137+
usage: .init(completionTokens: 200, promptTokens: 100, totalTokens: 300),
138+
citations: nil
139+
)
140+
try self.stub(result: chatResult)
141+
142+
struct AnyEncodable: Encodable {
143+
144+
private let _encode: (Encoder) throws -> Void
145+
public init<T: Encodable>(_ wrapped: T) {
146+
_encode = wrapped.encode
147+
}
148+
149+
func encode(to encoder: Encoder) throws {
150+
try _encode(encoder)
151+
}
152+
}
153+
154+
let schema = [
155+
"type": AnyEncodable("object"),
156+
"properties": AnyEncodable([
157+
"title": AnyEncodable([
158+
"type": "string"
159+
]),
160+
"director": AnyEncodable([
161+
"type": "string"
162+
]),
163+
"release": AnyEncodable([
164+
"type": "string"
165+
]),
166+
"genres": AnyEncodable([
167+
"type": AnyEncodable("array"),
168+
"items": AnyEncodable([
169+
"type": AnyEncodable("string"),
170+
"enum": AnyEncodable(["action", "drama", "comedy", "scifi"])
171+
])
172+
]),
173+
"cast": AnyEncodable([
174+
"type": AnyEncodable("array"),
175+
"items": AnyEncodable([
176+
"type": "string"
177+
])
178+
])
179+
]),
180+
"additionalProperties": AnyEncodable(false)
181+
]
182+
let query = ChatQuery(
183+
messages: [.system(.init(content: "Return a structured response."))],
184+
model: .gpt4_o,
185+
responseFormat: .dynamicJsonSchema(
186+
.init(
187+
name: "movie-info",
188+
schema: schema
189+
)
190+
)
191+
)
192+
193+
let result = try await openAI.chats(query: query)
194+
XCTAssertEqual(result, chatResult)
195+
}
131196

132197
func testChatsFunction() async throws {
133198
let query = ChatQuery(messages: [

0 commit comments

Comments
 (0)