Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit f94527f

Browse files
authored
feat: add openai assistant (#1826)
1 parent 3456c7b commit f94527f

30 files changed

+3289
-199
lines changed

docs/static/openapi/cortex.json

Lines changed: 497 additions & 45 deletions
Large diffs are not rendered by default.

engine/common/assistant.h

Lines changed: 267 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
#pragma once
22

33
#include <string>
4+
#include "common/assistant_code_interpreter_tool.h"
5+
#include "common/assistant_file_search_tool.h"
6+
#include "common/assistant_function_tool.h"
47
#include "common/assistant_tool.h"
5-
#include "common/thread_tool_resources.h"
8+
#include "common/tool_resources.h"
69
#include "common/variant_map.h"
10+
#include "utils/logging_utils.h"
711
#include "utils/result.hpp"
812

913
namespace OpenAi {
@@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable {
7579
}
7680
};
7781

78-
struct Assistant {
82+
struct Assistant : JsonSerializable {
83+
Assistant() = default;
84+
85+
~Assistant() = default;
86+
87+
Assistant(const Assistant&) = delete;
88+
89+
Assistant& operator=(const Assistant&) = delete;
90+
91+
Assistant(Assistant&& other) noexcept
92+
: id{std::move(other.id)},
93+
object{std::move(other.object)},
94+
created_at{other.created_at},
95+
name{std::move(other.name)},
96+
description{std::move(other.description)},
97+
model(std::move(other.model)),
98+
instructions(std::move(other.instructions)),
99+
tools(std::move(other.tools)),
100+
tool_resources(std::move(other.tool_resources)),
101+
metadata(std::move(other.metadata)),
102+
temperature{std::move(other.temperature)},
103+
top_p{std::move(other.top_p)},
104+
response_format{std::move(other.response_format)} {}
105+
106+
Assistant& operator=(Assistant&& other) noexcept {
107+
if (this != &other) {
108+
id = std::move(other.id);
109+
object = std::move(other.object);
110+
created_at = other.created_at;
111+
name = std::move(other.name);
112+
description = std::move(other.description);
113+
model = std::move(other.model);
114+
instructions = std::move(other.instructions);
115+
tools = std::move(other.tools);
116+
tool_resources = std::move(other.tool_resources);
117+
metadata = std::move(other.metadata);
118+
temperature = std::move(other.temperature);
119+
top_p = std::move(other.top_p);
120+
response_format = std::move(other.response_format);
121+
}
122+
return *this;
123+
}
124+
79125
/**
80126
* The identifier, which can be referenced in API endpoints.
81127
*/
@@ -126,8 +172,7 @@ struct Assistant {
126172
* requires a list of file IDs, while the file_search tool requires a list
127173
* of vector store IDs.
128174
*/
129-
std::optional<std::variant<ThreadCodeInterpreter, ThreadFileSearch>>
130-
tool_resources;
175+
std::unique_ptr<OpenAi::ToolResources> tool_resources;
131176

132177
/**
133178
* Set of 16 key-value pairs that can be attached to an object. This can be
@@ -153,5 +198,223 @@ struct Assistant {
153198
* We generally recommend altering this or temperature but not both.
154199
*/
155200
std::optional<float> top_p;
201+
202+
std::variant<std::string, Json::Value> response_format;
203+
204+
cpp::result<Json::Value, std::string> ToJson() override {
205+
try {
206+
Json::Value root;
207+
208+
root["id"] = std::move(id);
209+
root["object"] = "assistant";
210+
root["created_at"] = created_at;
211+
if (name.has_value()) {
212+
root["name"] = name.value();
213+
}
214+
if (description.has_value()) {
215+
root["description"] = description.value();
216+
}
217+
root["model"] = model;
218+
if (instructions.has_value()) {
219+
root["instructions"] = instructions.value();
220+
}
221+
222+
Json::Value tools_jarr{Json::arrayValue};
223+
for (auto& tool_ptr : tools) {
224+
if (auto it = tool_ptr->ToJson(); it.has_value()) {
225+
tools_jarr.append(it.value());
226+
} else {
227+
CTL_WRN("Failed to convert content to json: " + it.error());
228+
}
229+
}
230+
root["tools"] = tools_jarr;
231+
if (tool_resources) {
232+
Json::Value tool_resources_json{Json::objectValue};
233+
234+
if (auto* code_interpreter =
235+
dynamic_cast<OpenAi::CodeInterpreter*>(tool_resources.get())) {
236+
auto result = code_interpreter->ToJson();
237+
if (result.has_value()) {
238+
tool_resources_json["code_interpreter"] = result.value();
239+
} else {
240+
CTL_WRN("Failed to convert code_interpreter to json: " +
241+
result.error());
242+
}
243+
} else if (auto* file_search = dynamic_cast<OpenAi::FileSearch*>(
244+
tool_resources.get())) {
245+
auto result = file_search->ToJson();
246+
if (result.has_value()) {
247+
tool_resources_json["file_search"] = result.value();
248+
} else {
249+
CTL_WRN("Failed to convert file_search to json: " + result.error());
250+
}
251+
}
252+
253+
// Only add tool_resources to root if we successfully serialized some resources
254+
if (!tool_resources_json.empty()) {
255+
root["tool_resources"] = tool_resources_json;
256+
}
257+
}
258+
Json::Value metadata_json{Json::objectValue};
259+
for (const auto& [key, value] : metadata) {
260+
if (std::holds_alternative<bool>(value)) {
261+
metadata_json[key] = std::get<bool>(value);
262+
} else if (std::holds_alternative<uint64_t>(value)) {
263+
metadata_json[key] = std::get<uint64_t>(value);
264+
} else if (std::holds_alternative<double>(value)) {
265+
metadata_json[key] = std::get<double>(value);
266+
} else {
267+
metadata_json[key] = std::get<std::string>(value);
268+
}
269+
}
270+
root["metadata"] = metadata_json;
271+
272+
if (temperature.has_value()) {
273+
root["temperature"] = temperature.value();
274+
}
275+
if (top_p.has_value()) {
276+
root["top_p"] = top_p.value();
277+
}
278+
return root;
279+
} catch (const std::exception& e) {
280+
return cpp::fail("ToJson failed: " + std::string(e.what()));
281+
}
282+
}
283+
284+
static cpp::result<Assistant, std::string> FromJson(Json::Value&& json) {
285+
try {
286+
Assistant assistant;
287+
288+
// Parse required fields
289+
if (!json.isMember("id") || !json["id"].isString()) {
290+
return cpp::fail("Missing or invalid 'id' field");
291+
}
292+
assistant.id = json["id"].asString();
293+
294+
if (!json.isMember("object") || !json["object"].isString() ||
295+
json["object"].asString() != "assistant") {
296+
return cpp::fail("Missing or invalid 'object' field");
297+
}
298+
299+
if (!json.isMember("created_at") || !json["created_at"].isUInt64()) {
300+
return cpp::fail("Missing or invalid 'created_at' field");
301+
}
302+
assistant.created_at = json["created_at"].asUInt64();
303+
304+
if (!json.isMember("model") || !json["model"].isString()) {
305+
return cpp::fail("Missing or invalid 'model' field");
306+
}
307+
assistant.model = json["model"].asString();
308+
309+
// Parse optional fields
310+
if (json.isMember("name") && json["name"].isString()) {
311+
assistant.name = json["name"].asString();
312+
}
313+
314+
if (json.isMember("description") && json["description"].isString()) {
315+
assistant.description = json["description"].asString();
316+
}
317+
318+
if (json.isMember("instructions") && json["instructions"].isString()) {
319+
assistant.instructions = json["instructions"].asString();
320+
}
321+
322+
// Parse tools array
323+
if (json.isMember("tools") && json["tools"].isArray()) {
324+
auto tools_array = json["tools"];
325+
for (const auto& tool : tools_array) {
326+
if (!tool.isMember("type") || !tool["type"].isString()) {
327+
CTL_WRN("Tool missing type field or invalid type");
328+
continue;
329+
}
330+
331+
std::string tool_type = tool["type"].asString();
332+
if (tool_type == "file_search") {
333+
auto result = AssistantFileSearchTool::FromJson(tool);
334+
if (result.has_value()) {
335+
assistant.tools.push_back(
336+
std::make_unique<AssistantFileSearchTool>(
337+
std::move(result.value())));
338+
} else {
339+
CTL_WRN("Failed to parse file_search tool: " + result.error());
340+
}
341+
} else if (tool_type == "code_interpreter") {
342+
auto result = AssistantCodeInterpreterTool::FromJson();
343+
if (result.has_value()) {
344+
assistant.tools.push_back(
345+
std::make_unique<AssistantCodeInterpreterTool>(
346+
std::move(result.value())));
347+
} else {
348+
CTL_WRN("Failed to parse code_interpreter tool: " +
349+
result.error());
350+
}
351+
} else if (tool_type == "function") {
352+
auto result = AssistantFunctionTool::FromJson(tool);
353+
if (result.has_value()) {
354+
assistant.tools.push_back(std::make_unique<AssistantFunctionTool>(
355+
std::move(result.value())));
356+
} else {
357+
CTL_WRN("Failed to parse function tool: " + result.error());
358+
}
359+
} else {
360+
CTL_WRN("Unknown tool type: " + tool_type);
361+
}
362+
}
363+
}
364+
365+
if (json.isMember("tool_resources") &&
366+
json["tool_resources"].isObject()) {
367+
const auto& tool_resources_json = json["tool_resources"];
368+
369+
// Parse code interpreter resources
370+
if (tool_resources_json.isMember("code_interpreter")) {
371+
auto result = OpenAi::CodeInterpreter::FromJson(
372+
tool_resources_json["code_interpreter"]);
373+
if (result.has_value()) {
374+
assistant.tool_resources =
375+
std::make_unique<OpenAi::CodeInterpreter>(
376+
std::move(result.value()));
377+
} else {
378+
CTL_WRN("Failed to parse code_interpreter resources: " +
379+
result.error());
380+
}
381+
}
382+
383+
// Parse file search resources
384+
if (tool_resources_json.isMember("file_search")) {
385+
auto result =
386+
OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]);
387+
if (result.has_value()) {
388+
assistant.tool_resources =
389+
std::make_unique<OpenAi::FileSearch>(std::move(result.value()));
390+
} else {
391+
CTL_WRN("Failed to parse file_search resources: " + result.error());
392+
}
393+
}
394+
}
395+
396+
// Parse metadata
397+
if (json.isMember("metadata") && json["metadata"].isObject()) {
398+
auto res = Cortex::ConvertJsonValueToMap(json["metadata"]);
399+
if (res.has_value()) {
400+
assistant.metadata = res.value();
401+
} else {
402+
CTL_WRN("Failed to convert metadata to map: " + res.error());
403+
}
404+
}
405+
406+
if (json.isMember("temperature") && json["temperature"].isDouble()) {
407+
assistant.temperature = json["temperature"].asFloat();
408+
}
409+
410+
if (json.isMember("top_p") && json["top_p"].isDouble()) {
411+
assistant.top_p = json["top_p"].asFloat();
412+
}
413+
414+
return assistant;
415+
} catch (const std::exception& e) {
416+
return cpp::fail("FromJson failed: " + std::string(e.what()));
417+
}
418+
}
156419
};
157420
} // namespace OpenAi
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include "common/assistant_tool.h"
4+
5+
namespace OpenAi {
6+
struct AssistantCodeInterpreterTool : public AssistantTool {
7+
AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {}
8+
9+
AssistantCodeInterpreterTool(const AssistantCodeInterpreterTool&) = delete;
10+
11+
AssistantCodeInterpreterTool& operator=(const AssistantCodeInterpreterTool&) =
12+
delete;
13+
14+
AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default;
15+
16+
AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) =
17+
default;
18+
19+
~AssistantCodeInterpreterTool() = default;
20+
21+
static cpp::result<AssistantCodeInterpreterTool, std::string> FromJson() {
22+
AssistantCodeInterpreterTool tool;
23+
return std::move(tool);
24+
}
25+
26+
cpp::result<Json::Value, std::string> ToJson() override {
27+
Json::Value json;
28+
json["type"] = type;
29+
return json;
30+
}
31+
};
32+
} // namespace OpenAi

0 commit comments

Comments
 (0)