Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MNIST inference unit test and enhance ONNX model handling #5

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Set extension name here
set(TARGET_NAME onnx)
add_subdirectory(unit_test)

# DuckDB's extension distribution supports vcpkg. As such, dependencies can be added in ./vcpkg.json and then
# used in cmake with find_package. Feel free to remove or replace with other dependencies.
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 155 files
2 changes: 1 addition & 1 deletion extension-ci-tools
35 changes: 29 additions & 6 deletions src/onnx_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#define DUCKDB_EXTENSION_MAIN

#include <filesystem>
#include "onnx_extension.hpp"
#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
Expand All @@ -21,6 +22,11 @@ struct Tensor {
void run_onnx_model_and_extract_results(const string &path,
vector<Tensor> input_tensors,
Value &struct_val_list) {
// check path exist
if (!std::__fs::filesystem::exists(path)) {
throw std::runtime_error(
std::string("ONNX model file not found: ") + path);
};
Ort::Env ort_env;
Ort::Session session{ort_env, (path.c_str()), Ort::SessionOptions{nullptr}};
auto memory_info =
Expand Down Expand Up @@ -50,11 +56,23 @@ void run_onnx_model_and_extract_results(const string &path,
memory_info, output_data.data(), output_size, output_shape.data(),
output_shape.size());

const char *input_names[] = {"X"};
const char *output_names[] = {"Y"};

session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1,
output_names, &output_tensor, 1);
Ort::AllocatorWithDefaultOptions allocator;
const std::string input_name =
session.GetInputNameAllocated(0, allocator).get();
const std::string output_name =
session.GetOutputNameAllocated(0, allocator).get();

const char *input_names[] = {input_name.c_str()};
const char *output_names[] = {output_name.c_str()};

try {
session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1,
output_names, &output_tensor, 1);
} catch (const Ort::Exception &e) {
std::string persistent_message = e.what();
throw std::runtime_error(std::string("ONNXRuntime error: ") +
persistent_message);
}

std::vector<Value> shape_vl_list;
shape_vl_list.reserve(output_shape.size());
Expand Down Expand Up @@ -138,9 +156,14 @@ inline void OnnxScalarFun(DataChunk &args, ExpressionState &state,
Value struct_val_list;
try {
run_onnx_model_and_extract_results(path, onnx_inputs, struct_val_list);
} catch (Ort::Exception &e) {
throw std::runtime_error(std::string("ONNXRuntime error: ") + e.what());
} catch (std::exception &e) {
throw std::runtime_error(std::string("General error: ") + e.what());
} catch (...) {
throw std::runtime_error("Onnxruntime error");
throw std::runtime_error("Unknown error occurred.");
}

result.SetValue(row, struct_val_list);
result.SetVectorType(VectorType::CONSTANT_VECTOR);
}
Expand Down
31 changes: 31 additions & 0 deletions unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.10)

project(MNIST_Inference)

# The test fixtures (images and ONNX models) live in the `mnist` directory next
# to this CMakeLists.txt. Resolving them relative to CMAKE_CURRENT_SOURCE_DIR
# keeps the path correct no matter which project is the top-level one.
set(INPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mnist")
set(OUTPUT_DIR "${CMAKE_BINARY_DIR}/extension/onnx/unit_test/mnist")
message(STATUS "Input Directory: ${INPUT_DIR}")
message(STATUS "Output Directory: ${OUTPUT_DIR}")
file(MAKE_DIRECTORY "${OUTPUT_DIR}")

set(CMAKE_CXX_STANDARD 11)

# find_package must run before OpenCV_INCLUDE_DIRS / OpenCV_LIBS are read;
# those variables are empty until the package has been located.
find_package(OpenCV REQUIRED)

add_executable(mnist_inference mnist_inference.cpp)
target_include_directories(mnist_inference PRIVATE ${OpenCV_INCLUDE_DIRS})

# Copy the fixtures next to the build output after every build of the test.
add_custom_command(TARGET mnist_inference POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_directory "${INPUT_DIR}" "${OUTPUT_DIR}"
    COMMENT "Copying files from ${INPUT_DIR} to ${OUTPUT_DIR}"
)

set_target_properties(mnist_inference PROPERTIES
    BUILD_RPATH ${DUCKDB_LIB_DIR}
    INSTALL_RPATH ${DUCKDB_LIB_DIR}
)


target_link_libraries(mnist_inference PRIVATE duckdb ${OpenCV_LIBS})


Binary file added unit_test/mnist/images/0_108.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_120.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_132.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_141.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_15.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_17.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_18.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_21.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_26.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_28.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_57.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_71.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_80.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/0_93.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_103.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_114.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_119.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_14.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_140.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_150.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_158.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added unit_test/mnist/images/1_161.png
Binary file added unit_test/mnist/images/1_164.png
Binary file added unit_test/mnist/images/1_173.png
Binary file added unit_test/mnist/images/1_177.png
Binary file added unit_test/mnist/images/1_178.png
Binary file added unit_test/mnist/images/1_186.png
Binary file added unit_test/mnist/images/1_189.png
Binary file added unit_test/mnist/images/1_19.png
Binary file added unit_test/mnist/images/1_192.png
Binary file added unit_test/mnist/images/1_32.png
Binary file added unit_test/mnist/images/1_4.png
Binary file added unit_test/mnist/images/1_48.png
Binary file added unit_test/mnist/images/1_5.png
Binary file added unit_test/mnist/images/1_73.png
Binary file added unit_test/mnist/images/1_75.png
Binary file added unit_test/mnist/images/1_77.png
Binary file added unit_test/mnist/images/1_79.png
Binary file added unit_test/mnist/images/1_8.png
Binary file added unit_test/mnist/images/1_92.png
Binary file added unit_test/mnist/images/1_95.png
Binary file added unit_test/mnist/images/2_0.png
Binary file added unit_test/mnist/images/2_101.png
Binary file added unit_test/mnist/images/2_106.png
Binary file added unit_test/mnist/images/2_126.png
Binary file added unit_test/mnist/images/2_146.png
Binary file added unit_test/mnist/images/2_149.png
Binary file added unit_test/mnist/images/2_162.png
Binary file added unit_test/mnist/images/2_171.png
Binary file added unit_test/mnist/images/2_185.png
Binary file added unit_test/mnist/images/2_188.png
Binary file added unit_test/mnist/images/2_191.png
Binary file added unit_test/mnist/images/2_22.png
Binary file added unit_test/mnist/images/2_27.png
Binary file added unit_test/mnist/images/2_30.png
Binary file added unit_test/mnist/images/2_39.png
Binary file added unit_test/mnist/images/2_55.png
Binary file added unit_test/mnist/images/2_58.png
Binary file added unit_test/mnist/images/2_63.png
Binary file added unit_test/mnist/images/2_66.png
Binary file added unit_test/mnist/images/2_82.png
Binary file added unit_test/mnist/images/2_89.png
Binary file added unit_test/mnist/images/3_105.png
Binary file added unit_test/mnist/images/3_124.png
Binary file added unit_test/mnist/images/3_129.png
Binary file added unit_test/mnist/images/3_133.png
Binary file added unit_test/mnist/images/3_135.png
Binary file added unit_test/mnist/images/3_138.png
Binary file added unit_test/mnist/images/3_139.png
Binary file added unit_test/mnist/images/3_145.png
Binary file added unit_test/mnist/images/3_151.png
Binary file added unit_test/mnist/images/3_152.png
Binary file added unit_test/mnist/images/3_169.png
Binary file added unit_test/mnist/images/3_174.png
Binary file added unit_test/mnist/images/3_31.png
Binary file added unit_test/mnist/images/3_40.png
Binary file added unit_test/mnist/images/3_45.png
Binary file added unit_test/mnist/images/3_47.png
Binary file added unit_test/mnist/images/3_49.png
Binary file added unit_test/mnist/images/3_50.png
Binary file added unit_test/mnist/images/3_51.png
Binary file added unit_test/mnist/images/3_59.png
Binary file added unit_test/mnist/images/3_67.png
Binary file added unit_test/mnist/images/3_72.png
Binary file added unit_test/mnist/images/3_74.png
Binary file added unit_test/mnist/images/3_94.png
Binary file added unit_test/mnist/images/3_98.png
Binary file added unit_test/mnist/images/4_109.png
Binary file added unit_test/mnist/images/4_118.png
Binary file added unit_test/mnist/images/4_127.png
Binary file added unit_test/mnist/images/4_137.png
Binary file added unit_test/mnist/images/4_165.png
Binary file added unit_test/mnist/images/4_182.png
Binary file added unit_test/mnist/images/4_194.png
Binary file added unit_test/mnist/images/4_29.png
Binary file added unit_test/mnist/images/4_81.png
Binary file added unit_test/mnist/images/4_86.png
Binary file added unit_test/mnist/images/4_99.png
Binary file added unit_test/mnist/images/5_1.png
Binary file added unit_test/mnist/images/5_113.png
Binary file added unit_test/mnist/images/5_116.png
Binary file added unit_test/mnist/images/5_128.png
Binary file added unit_test/mnist/images/5_134.png
Binary file added unit_test/mnist/images/5_144.png
Binary file added unit_test/mnist/images/5_148.png
Binary file added unit_test/mnist/images/5_16.png
Binary file added unit_test/mnist/images/5_166.png
Binary file added unit_test/mnist/images/5_167.png
Binary file added unit_test/mnist/images/5_172.png
Binary file added unit_test/mnist/images/5_176.png
Binary file added unit_test/mnist/images/5_179.png
Binary file added unit_test/mnist/images/5_181.png
Binary file added unit_test/mnist/images/5_184.png
Binary file added unit_test/mnist/images/5_195.png
Binary file added unit_test/mnist/images/5_196.png
Binary file added unit_test/mnist/images/5_24.png
Binary file added unit_test/mnist/images/5_25.png
Binary file added unit_test/mnist/images/5_33.png
Binary file added unit_test/mnist/images/5_34.png
Binary file added unit_test/mnist/images/5_37.png
Binary file added unit_test/mnist/images/5_56.png
Binary file added unit_test/mnist/images/5_60.png
Binary file added unit_test/mnist/images/5_64.png
Binary file added unit_test/mnist/images/5_76.png
Binary file added unit_test/mnist/images/5_84.png
Binary file added unit_test/mnist/images/5_87.png
Binary file added unit_test/mnist/images/5_9.png
Binary file added unit_test/mnist/images/6_110.png
Binary file added unit_test/mnist/images/6_122.png
Binary file added unit_test/mnist/images/6_125.png
Binary file added unit_test/mnist/images/6_136.png
Binary file added unit_test/mnist/images/6_142.png
Binary file added unit_test/mnist/images/6_175.png
Binary file added unit_test/mnist/images/6_199.png
Binary file added unit_test/mnist/images/6_36.png
Binary file added unit_test/mnist/images/6_41.png
Binary file added unit_test/mnist/images/6_42.png
Binary file added unit_test/mnist/images/6_44.png
Binary file added unit_test/mnist/images/6_52.png
Binary file added unit_test/mnist/images/6_53.png
Binary file added unit_test/mnist/images/6_54.png
Binary file added unit_test/mnist/images/6_62.png
Binary file added unit_test/mnist/images/6_68.png
Binary file added unit_test/mnist/images/6_69.png
Binary file added unit_test/mnist/images/7_10.png
Binary file added unit_test/mnist/images/7_112.png
Binary file added unit_test/mnist/images/7_12.png
Binary file added unit_test/mnist/images/7_130.png
Binary file added unit_test/mnist/images/7_147.png
Binary file added unit_test/mnist/images/7_154.png
Binary file added unit_test/mnist/images/7_168.png
Binary file added unit_test/mnist/images/7_170.png
Binary file added unit_test/mnist/images/7_3.png
Binary file added unit_test/mnist/images/7_61.png
Binary file added unit_test/mnist/images/7_65.png
Binary file added unit_test/mnist/images/7_70.png
Binary file added unit_test/mnist/images/7_78.png
Binary file added unit_test/mnist/images/7_83.png
Binary file added unit_test/mnist/images/8_102.png
Binary file added unit_test/mnist/images/8_104.png
Binary file added unit_test/mnist/images/8_117.png
Binary file added unit_test/mnist/images/8_123.png
Binary file added unit_test/mnist/images/8_131.png
Binary file added unit_test/mnist/images/8_143.png
Binary file added unit_test/mnist/images/8_153.png
Binary file added unit_test/mnist/images/8_155.png
Binary file added unit_test/mnist/images/8_156.png
Binary file added unit_test/mnist/images/8_157.png
Binary file added unit_test/mnist/images/8_159.png
Binary file added unit_test/mnist/images/8_160.png
Binary file added unit_test/mnist/images/8_163.png
Binary file added unit_test/mnist/images/8_190.png
Binary file added unit_test/mnist/images/8_193.png
Binary file added unit_test/mnist/images/8_2.png
Binary file added unit_test/mnist/images/8_23.png
Binary file added unit_test/mnist/images/8_43.png
Binary file added unit_test/mnist/images/8_6.png
Binary file added unit_test/mnist/images/8_88.png
Binary file added unit_test/mnist/images/8_97.png
Binary file added unit_test/mnist/images/9_100.png
Binary file added unit_test/mnist/images/9_107.png
Binary file added unit_test/mnist/images/9_111.png
Binary file added unit_test/mnist/images/9_115.png
Binary file added unit_test/mnist/images/9_121.png
Binary file added unit_test/mnist/images/9_180.png
Binary file added unit_test/mnist/images/9_183.png
Binary file added unit_test/mnist/images/9_187.png
Binary file added unit_test/mnist/images/9_197.png
Binary file added unit_test/mnist/images/9_198.png
Binary file added unit_test/mnist/images/9_35.png
Binary file added unit_test/mnist/images/9_38.png
Binary file added unit_test/mnist/images/9_46.png
Binary file added unit_test/mnist/images/9_85.png
Binary file added unit_test/mnist/images/9_90.png
Binary file added unit_test/mnist/images/9_91.png
Binary file added unit_test/mnist/images/9_96.png
Binary file added unit_test/mnist/onnx/mnist-8.onnx
Binary file not shown.
Binary file added unit_test/mnist/onnx/mnist.onnx
Binary file not shown.
153 changes: 153 additions & 0 deletions unit_test/mnist_inference.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

#include <duckdb.h>
#include <opencv2/opencv.hpp>

// Reads the image at `image_path` as single-channel grayscale, resizes it to
// 28x28 and converts it to 32-bit float with every pixel multiplied by 1/255.
// If the image cannot be read, prints an error to stderr and terminates the
// process with exit status 1.
cv::Mat preprocess_image(const std::string &image_path) {
  cv::Mat gray = cv::imread(image_path, cv::IMREAD_GRAYSCALE);
  if (gray.empty()) {
    std::cerr << "Error: Unable to load image!" << std::endl;
    exit(1);
  }

  // Stage 1: bring the picture to the fixed 28x28 size.
  cv::Mat resized;
  cv::resize(gray, resized, cv::Size(28, 28));

  // Stage 2: switch to float storage and apply the 1/255 scale factor.
  cv::Mat scaled;
  resized.convertTo(scaled, CV_32F, 1.0 / 255.0);
  return scaled;
}

int main() {
cv::Mat image = preprocess_image("mnist/images/7_12.png");
if (!image.isContinuous()) {
image = image.clone();
}
std::vector<float> input_data(image.ptr<float>(),
image.ptr<float>() + image.total());

duckdb_database db;
duckdb_connection con;

if (duckdb_open(NULL, &db) == DuckDBError) {
// handle error
}
if (duckdb_connect(db, &con) == DuckDBError) {
// handle error
}

// run queries...
duckdb_result res;

std::string tensor_value = "[";
for (int i = 0; i < input_data.size(); ++i) {
tensor_value += std::to_string(input_data[i]);
if (i != input_data.size() - 1) {
tensor_value += ",";
}
}

std::string sql = "SELECT "
"onnx('mnist/onnx/mnist-8.onnx',"
"{'shape':[1,1,28,28],'value':" +
tensor_value + "]}) as result";

duckdb_state state = duckdb_query(con, sql.c_str(), &res);
if (state == DuckDBError) {
// handle error
std::cerr << "Error" << std::endl;
exit(1);
}

std::vector<float> output_data;
output_data.reserve(10);
while (true) {
duckdb_data_chunk result = duckdb_fetch_chunk(res);
if (!result) {
// result is exhausted
break;
}
// get the number of rows from the data chunk
idx_t row_count = duckdb_data_chunk_get_size(result);
assert(row_count == 1);
// get the first column
duckdb_vector struct_col = duckdb_data_chunk_get_vector(result, 0);
uint64_t *struct_validity = duckdb_vector_get_validity(struct_col);

duckdb_vector col1_vector = duckdb_struct_vector_get_child(struct_col, 0);
duckdb_vector col2_vector = duckdb_struct_vector_get_child(struct_col, 1);

duckdb_list_entry *list_data_1 =
(duckdb_list_entry *)duckdb_vector_get_data(col1_vector);
uint64_t *list_validity_1 = duckdb_vector_get_validity(col1_vector);
// get the child column of the list
duckdb_vector list_child_1 = duckdb_list_vector_get_child(col1_vector);
int32_t *child_data_1 = (int32_t *)duckdb_vector_get_data(list_child_1);
uint64_t *child_validity_1 = duckdb_vector_get_validity(list_child_1);

duckdb_list_entry *list_data_2 =
(duckdb_list_entry *)duckdb_vector_get_data(col2_vector);
uint64_t *list_validity_2 = duckdb_vector_get_validity(col2_vector);
// get the child column of the list
duckdb_vector list_child_2 = duckdb_list_vector_get_child(col2_vector);
float *child_data_2 = (float *)duckdb_vector_get_data(list_child_2);
uint64_t *child_validity_2 = duckdb_vector_get_validity(list_child_2);

for (idx_t row = 0; row < row_count; row++) {
if (!duckdb_validity_row_is_valid(list_validity_1, row)) {
// entire list is NULL
printf("NULL\n");
continue;
}
// read the list offsets for this row
duckdb_list_entry list = list_data_1[row];
printf("[");
for (idx_t child_idx = list.offset; child_idx < list.offset + list.length;
child_idx++) {
if (child_idx > list.offset) {
printf(", ");
}
if (!duckdb_validity_row_is_valid(child_validity_1, child_idx)) {
// col1 is NULL
printf("NULL");
} else {
printf("%lld", child_data_1[child_idx]);
}
}
printf("]\n");

duckdb_list_entry list2 = list_data_2[row];
printf("[");
for (idx_t child_idx = list2.offset;
child_idx < list2.offset + list2.length; child_idx++) {
if (child_idx > list2.offset) {
printf(", ");
}
if (!duckdb_validity_row_is_valid(child_validity_2, child_idx)) {
// col2 is NULL
printf("NULL");
} else {
printf("%f", child_data_2[child_idx]);
output_data.push_back(child_data_2[child_idx]);
}
}
printf("]\n");
}
duckdb_destroy_data_chunk(&result);
}

// clean-up
duckdb_destroy_result(&res);
duckdb_disconnect(&con);
duckdb_close(&db);

std::cout << "Output probabilities: ";
for (int i = 0; i < 10; ++i) {
std::cout << output_data[i] << " ";
}
std::cout << std::endl;

//
int predicted_class =
std::distance(output_data.begin(),
std::max_element(output_data.begin(), output_data.end()));
std::cout << "Predicted Class: " << predicted_class << std::endl;

return 0;
}
3 changes: 2 additions & 1 deletion vcpkg.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"dependencies": [
"openssl"
"openssl",
"opencv"
]
}
Loading