Skip to content

Refactor ONNX runtime execution logic and enhance tests #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion duckdb
187 changes: 97 additions & 90 deletions src/onnx_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,123 +9,130 @@
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>

// OpenSSL linked through vcpkg
#include <onnxruntime_cxx_api.h>

Check failure on line 12 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md-release, x64-windows-static-md-release)

Cannot open include file: 'onnxruntime_cxx_api.h': No such file or directory

Check failure on line 12 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md-release, x64-windows-static-md-release)

Cannot open include file: 'onnxruntime_cxx_api.h': No such file or directory

Check failure on line 12 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md-release, x64-windows-static-md-release)

Cannot open include file: 'onnxruntime_cxx_api.h': No such file or directory

Check failure on line 12 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md-release, x64-windows-static-md-release)

Cannot open include file: 'onnxruntime_cxx_api.h': No such file or directory
#include <openssl/opensslv.h>

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

Check failure on line 13 in src/onnx_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten, x64-linux)

'openssl/opensslv.h' file not found

namespace duckdb {

struct TensorOutput {
struct Tensor {
std::vector<int64_t> shape;
std::vector<float> values;
};

TensorOutput ConvertToTensorOutput(Ort::Value &tensor) {

try {
TensorOutput result;
auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
result.shape = tensor_info.GetShape();

auto *data = tensor.GetTensorMutableData<float>();
size_t total_size = tensor_info.GetElementCount();
result.values.assign(data, data + total_size);

return result;

} catch (const Ort::Exception &e) {
throw std::runtime_error("Convert tensor ONNX Runtime error:" +
std::string(e.what()));
void run_onnx_model_and_extract_results(const string &path,
vector<Tensor> input_tensors,
Value &struct_val_list) {
Ort::Env ort_env;
Ort::Session session{ort_env, (path.c_str()), Ort::SessionOptions{nullptr}};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

/// Currently only a single input tensor is supported; any additional tensors are ignored.
Tensor duckdb_input_tensor = input_tensors[0];

Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info, duckdb_input_tensor.values.data(),
duckdb_input_tensor.values.size(), duckdb_input_tensor.shape.data(),
duckdb_input_tensor.shape.size());

Ort::TypeInfo type_info = session.GetOutputTypeInfo(0);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> output_shape = tensor_info.GetShape();
size_t output_size = 1;
for (int64_t dim : output_shape) {
if (dim < 0) {
dim = 1;
}
output_size *= dim;
}
}

inline void OnnxScalarFun(DataChunk &args, ExpressionState &state,
Vector &result) {
auto &str_vector = args.data[0];
auto &struct_vector = args.data[1];

vector<unique_ptr<Vector>> &source_child =
StructVector::GetEntries(struct_vector);
std::vector<float> output_data(output_size);
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(
memory_info, output_data.data(), output_size, output_shape.data(),
output_shape.size());

const auto path = str_vector.GetValue(0).ToString();

Vector &shape_vector = *source_child.front();
Vector &value_vector = *source_child.back();
const char *input_names[] = {"X"};
const char *output_names[] = {"Y"};

idx_t shape_list_size = ListVector::GetListSize(shape_vector);
idx_t value_list_size = ListVector::GetListSize(value_vector);
session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1,
output_names, &output_tensor, 1);

std::vector<int64_t> input_shape;
input_shape.reserve(shape_list_size);
for (int idx = 0; idx < shape_list_size; idx++) {
input_shape.push_back(
ListVector::GetEntry(shape_vector).GetValue(idx).GetValue<int64_t>());
std::vector<Value> shape_vl_list;
shape_vl_list.reserve(output_shape.size());
for (int64_t dim : output_shape) {
shape_vl_list.emplace_back(dim);
}

std::vector<float> input_data;
input_data.reserve(value_list_size);
for (int idx = 0; idx < value_list_size; idx++) {
input_data.push_back(
ListVector::GetEntry(value_vector).GetValue(idx).GetValue<float>());
auto shape_val_list = Value::LIST(std::move(shape_vl_list));
std::vector<Value> value_vl_list;
value_vl_list.reserve(output_data.size());
for (float val : output_data) {
value_vl_list.emplace_back(val);
}
auto value_val_list = Value::LIST(std::move(value_vl_list));

try {
Ort::Env ort_env;
Ort::Session session{ort_env, (path.c_str()), Ort::SessionOptions{nullptr}};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info, input_data.data(), input_data.size(), input_shape.data(),
input_shape.size());

Ort::TypeInfo type_info = session.GetOutputTypeInfo(0);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

std::vector<int64_t> output_shape = tensor_info.GetShape();

size_t output_size = 1;
for (int64_t dim : output_shape) {
if (dim < 0) {
dim = input_shape[0];
}
output_size *= dim;
}

std::vector<float> output_data(output_size);
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(
memory_info, output_data.data(), output_size, output_shape.data(),
output_shape.size());
child_list_t<Value> struct_vl_list;
struct_vl_list.push_back(make_pair("shape", shape_val_list));
struct_vl_list.push_back(make_pair("value", value_val_list));
struct_val_list = Value::STRUCT(std::move(struct_vl_list));
}

const char *input_names[] = {"X"};
const char *output_names[] = {"Y"};
inline void OnnxScalarFun(DataChunk &args, ExpressionState &state,
Vector &result) {
auto count = args.size();

session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1,
output_names, &output_tensor, 1);
auto &str_vector = args.data[0];
const auto path = str_vector.GetValue(0).ToString();

std::vector<Value> shape_vl_list;
shape_vl_list.reserve(output_shape.size());
for (int64_t dim : output_shape) {
shape_vl_list.emplace_back(dim);
auto &struct_vector = args.data[1]; // { shape: int[], value: float[]}

auto &struct_child = StructVector::GetEntries(struct_vector);

///
Vector &shape_list_vec_ref = *struct_child[0];
auto *shape_list_data =
reinterpret_cast<list_entry_t *>(FlatVector::GetData(shape_list_vec_ref));
Vector &shape_list_child_vec = ListVector::GetEntry(shape_list_vec_ref);
auto *shape_child_data = reinterpret_cast<int32_t *>(
duckdb::FlatVector::GetData(shape_list_child_vec));

///
Vector &value_list_vec_ref = *struct_child[1];
auto *value_list_data =
reinterpret_cast<list_entry_t *>(FlatVector::GetData(value_list_vec_ref));
Vector &value_list_child_vec = ListVector::GetEntry(value_list_vec_ref);

/// Run the model once for each row of the input chunk.
for (idx_t row = 0; row < count; row++) {

vector<int64_t> shape_std_vec = std::vector<int64_t>();
shape_std_vec.reserve(count);
list_entry_t list = shape_list_data[row];
for (idx_t child_idx = list.offset; child_idx < list.offset + list.length;
child_idx++) {
shape_std_vec.push_back(shape_child_data[child_idx]);
}

auto shape_val_list = Value::LIST(std::move(shape_vl_list));
std::vector<Value> value_vl_list;
value_vl_list.reserve(output_data.size());
for (float val : output_data) {
value_vl_list.emplace_back(val);
vector<float> value_std_vec = std::vector<float>();
value_std_vec.reserve(count);
list = value_list_data[row];
for (idx_t child_idx = list.offset; child_idx < list.offset + list.length;
child_idx++) {
value_std_vec.push_back(
value_list_child_vec.GetValue(child_idx).GetValue<float>());
}
auto value_val_list = Value::LIST(std::move(value_vl_list));

child_list_t<Value> struct_vl_list;
struct_vl_list.push_back(make_pair("shape", shape_val_list));
struct_vl_list.push_back(make_pair("value", value_val_list));
auto max_struct_val_list = Value::STRUCT(std::move(struct_vl_list));
vector<Tensor> onnx_inputs;
onnx_inputs.emplace_back(Tensor{shape_std_vec, value_std_vec});

result.SetValue(0, max_struct_val_list);
Value struct_val_list;
try {
run_onnx_model_and_extract_results(path, onnx_inputs, struct_val_list);
} catch (...) {
throw std::runtime_error("Onnxruntime error");
}
result.SetValue(row, struct_val_list);
result.SetVectorType(VectorType::CONSTANT_VECTOR);

} catch (...) {
throw std::runtime_error("Onnxruntime error");
}
}

Expand Down
15 changes: 15 additions & 0 deletions test/sql/onnx.test
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,19 @@ require onnx
query I
SELECT onnx('test/sql/mul_1.onnx',{'shape':[3,2],'value': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]}) AS s;
----
{'shape': [3, 2], 'value': [1.0, 4.0, 9.0, 16.0, 25.0, 36.0]}

statement ok
CREATE TABLE t1 (c1 INT[], c2 FLOAT[]);

statement ok
INSERT INTO t1 values ([3,2],[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
INSERT INTO t1 values ([3,2],[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
INSERT INTO t1 values ([3,2],[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

query I
SELECT onnx('test/sql/mul_1.onnx',{'shape':c1,'value':c2}) FROM t1;
----
{'shape': [3, 2], 'value': [1.0, 4.0, 9.0, 16.0, 25.0, 36.0]}
{'shape': [3, 2], 'value': [1.0, 4.0, 9.0, 16.0, 25.0, 36.0]}
{'shape': [3, 2], 'value': [1.0, 4.0, 9.0, 16.0, 25.0, 36.0]}
Loading