Skip to content

Commit c288b19

Browse files
authored
Removing ndarray from DB (#188)
1 parent 558d881 commit c288b19

File tree

30 files changed

+316
-412
lines changed

30 files changed

+316
-412
lines changed

ahnlich/Cargo.lock

-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ahnlich/ai/src/engine/ai/providers/ort/mod.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::engine::ai::providers::processors::preprocessor::{
2727
ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor,
2828
};
2929
use ahnlich_types::keyval::StoreKey;
30-
use ndarray::{Array, Array1, Axis, Ix2, Ix4};
30+
use ndarray::{Array, Axis, Ix2, Ix4};
3131
use std::convert::TryFrom;
3232
use std::default::Default;
3333
use std::path::{Path, PathBuf};
@@ -482,7 +482,7 @@ impl ProviderTrait for ORTProvider {
482482
let new_store_keys: Vec<StoreKey> = embeddings
483483
.axis_iter(Axis(0))
484484
.into_par_iter()
485-
.map(|embedding| StoreKey(<Array1<f32>>::from(embedding.to_owned())))
485+
.map(|embedding| StoreKey(embedding.to_vec()))
486486
.collect();
487487
store_keys.extend(new_store_keys);
488488
}
@@ -502,7 +502,7 @@ impl ProviderTrait for ORTProvider {
502502
let new_store_keys: Vec<StoreKey> = embeddings
503503
.axis_iter(Axis(0))
504504
.into_par_iter()
505-
.map(|embedding| StoreKey(<Array1<f32>>::from(embedding.to_owned())))
505+
.map(|embedding| StoreKey(embedding.to_vec()))
506506
.collect();
507507
store_keys.extend(new_store_keys);
508508
}

ahnlich/client/src/db.rs

+13-14
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,6 @@ mod tests {
439439
use super::*;
440440
use ahnlich_db::cli::ServerConfig;
441441
use ahnlich_db::server::handler::Server;
442-
use ndarray::array;
443442
use once_cell::sync::Lazy;
444443
use pretty_assertions::assert_eq;
445444
use std::collections::HashMap;
@@ -585,7 +584,7 @@ mod tests {
585584
assert!(db_client.create_store(create_store_params).await.is_ok());
586585
let del_key_params = db_params::DelKeyParams::builder()
587586
.store("Main".to_string())
588-
.keys(vec![StoreKey(array![1.0, 1.1, 1.2, 1.3])])
587+
.keys(vec![StoreKey(vec![1.0, 1.1, 1.2, 1.3])])
589588
.build();
590589
assert_eq!(
591590
db_client.del_key(del_key_params).await.unwrap(),
@@ -594,8 +593,8 @@ mod tests {
594593
let set_key_params = db_params::SetParams::builder()
595594
.store("Main".to_string())
596595
.inputs(vec![
597-
(StoreKey(array![1.0, 1.1, 1.2, 1.3]), HashMap::new()),
598-
(StoreKey(array![1.1, 1.2, 1.3, 1.4]), HashMap::new()),
596+
(StoreKey(vec![1.0, 1.1, 1.2, 1.3]), HashMap::new()),
597+
(StoreKey(vec![1.1, 1.2, 1.3, 1.4]), HashMap::new()),
599598
])
600599
.build();
601600
assert!(db_client.set(set_key_params).await.is_ok());
@@ -604,20 +603,20 @@ mod tests {
604603
ServerResponse::StoreList(HashSet::from_iter([StoreInfo {
605604
name: StoreName("Main".to_string()),
606605
len: 2,
607-
size_in_bytes: 2160,
606+
size_in_bytes: 2016,
608607
},]))
609608
);
610609
// error as different dimensions
611610

612611
let del_key_params = db_params::DelKeyParams::builder()
613612
.store("Main".to_string())
614-
.keys(vec![StoreKey(array![1.0, 1.2])])
613+
.keys(vec![StoreKey(vec![1.0, 1.2])])
615614
.build();
616615
assert!(db_client.del_key(del_key_params).await.is_err());
617616

618617
let del_key_params = db_params::DelKeyParams::builder()
619618
.store("Main".to_string())
620-
.keys(vec![StoreKey(array![1.0, 1.1, 1.2, 1.3])])
619+
.keys(vec![StoreKey(vec![1.0, 1.1, 1.2, 1.3])])
621620
.build();
622621

623622
assert_eq!(
@@ -629,7 +628,7 @@ mod tests {
629628
ServerResponse::StoreList(HashSet::from_iter([StoreInfo {
630629
name: StoreName("Main".to_string()),
631630
len: 1,
632-
size_in_bytes: 1976,
631+
size_in_bytes: 1904,
633632
},]))
634633
);
635634
}
@@ -661,21 +660,21 @@ mod tests {
661660
.store("Main".to_string())
662661
.inputs(vec![
663662
(
664-
StoreKey(array![1.2, 1.3, 1.4]),
663+
StoreKey(vec![1.2, 1.3, 1.4]),
665664
HashMap::from_iter([(
666665
MetadataKey::new("medal".into()),
667666
MetadataValue::RawString("silver".into()),
668667
)]),
669668
),
670669
(
671-
StoreKey(array![2.0, 2.1, 2.2]),
670+
StoreKey(vec![2.0, 2.1, 2.2]),
672671
HashMap::from_iter([(
673672
MetadataKey::new("medal".into()),
674673
MetadataValue::RawString("gold".into()),
675674
)]),
676675
),
677676
(
678-
StoreKey(array![5.0, 5.1, 5.2]),
677+
StoreKey(vec![5.0, 5.1, 5.2]),
679678
HashMap::from_iter([(
680679
MetadataKey::new("medal".into()),
681680
MetadataValue::RawString("bronze".into()),
@@ -687,15 +686,15 @@ mod tests {
687686
// error due to dimension mismatch
688687
let get_sim_n_params = db_params::GetSimNParams::builder()
689688
.store("Main".to_string())
690-
.search_input(StoreKey(array![1.1, 2.0]))
689+
.search_input(StoreKey(vec![1.1, 2.0]))
691690
.closest_n(2)
692691
.algorithm(Algorithm::EuclideanDistance)
693692
.build();
694693
assert!(db_client.get_sim_n(get_sim_n_params).await.is_err());
695694

696695
let get_sim_n_params = db_params::GetSimNParams::builder()
697696
.store("Main".to_string())
698-
.search_input(StoreKey(array![5.0, 2.1, 2.2]))
697+
.search_input(StoreKey(vec![5.0, 2.1, 2.2]))
699698
.closest_n(2)
700699
.algorithm(Algorithm::CosineSimilarity)
701700
.condition(Some(PredicateCondition::Value(Predicate::Equals {
@@ -707,7 +706,7 @@ mod tests {
707706
assert_eq!(
708707
db_client.get_sim_n(get_sim_n_params).await.unwrap(),
709708
ServerResponse::GetSimN(vec![(
710-
StoreKey(array![2.0, 2.1, 2.2]),
709+
StoreKey(vec![2.0, 2.1, 2.2]),
711710
HashMap::from_iter([(
712711
MetadataKey::new("medal".into()),
713712
MetadataValue::RawString("gold".into()),

ahnlich/db/benches/database.rs

+9-15
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use ahnlich_types::keyval::StoreValue;
55
use ahnlich_types::similarity::Algorithm;
66
use ahnlich_types::similarity::NonLinearAlgorithm;
77
use criterion::{criterion_group, criterion_main, Criterion};
8-
use ndarray::Array;
9-
use ndarray::Array1;
108
use rayon::iter::ParallelIterator;
119
use rayon::slice::ParallelSlice;
1210
use std::collections::HashMap;
@@ -22,7 +20,7 @@ fn generate_storekey_store_value(size: usize, dimension: usize) -> Vec<(StoreKey
2220
// Use Rayon to process the buffer in parallel
2321
buffer
2422
.par_chunks_exact(dimension)
25-
.map(|chunk| (StoreKey(Array::from(chunk.to_owned())), HashMap::new()))
23+
.map(|chunk| (StoreKey(chunk.to_owned()), HashMap::new()))
2624
.collect()
2725
}
2826

@@ -42,8 +40,8 @@ fn bench_retrieval(c: &mut Criterion) {
4240
let dimension = 1024;
4341
let bulk_insert: Vec<_> = (0..size)
4442
.map(|_| {
45-
let random_array: Array1<f32> =
46-
Array::from((0..dimension).map(|_| rand::random()).collect::<Vec<f32>>());
43+
let random_array: Vec<f32> =
44+
(0..dimension).map(|_| rand::random()).collect::<Vec<f32>>();
4745
(StoreKey(random_array), HashMap::new())
4846
})
4947
.collect();
@@ -59,9 +57,7 @@ fn bench_retrieval(c: &mut Criterion) {
5957
no_condition_handler
6058
.set_in_store(&StoreName(store_name.to_string()), bulk_insert.clone())
6159
.unwrap();
62-
let random_input = StoreKey(Array::from(
63-
(0..dimension).map(|_| rand::random()).collect::<Vec<f32>>(),
64-
));
60+
let random_input = StoreKey((0..dimension).map(|_| rand::random()).collect::<Vec<f32>>());
6561
group_no_condition.sampling_mode(criterion::SamplingMode::Flat);
6662
group_no_condition.bench_function(format!("size_{size}"), |b| {
6763
b.iter(|| {
@@ -85,8 +81,8 @@ fn bench_retrieval(c: &mut Criterion) {
8581
let dimension = 1024;
8682
let bulk_insert: Vec<_> = (0..size)
8783
.map(|_| {
88-
let random_array: Array1<f32> =
89-
Array::from((0..dimension).map(|_| rand::random()).collect::<Vec<f32>>());
84+
let random_array: Vec<f32> =
85+
(0..dimension).map(|_| rand::random()).collect::<Vec<f32>>();
9086
(StoreKey(random_array), HashMap::new())
9187
})
9288
.collect();
@@ -102,9 +98,7 @@ fn bench_retrieval(c: &mut Criterion) {
10298
non_linear_handler
10399
.set_in_store(&StoreName(store_name.to_string()), bulk_insert.clone())
104100
.unwrap();
105-
let random_input = StoreKey(Array::from(
106-
(0..dimension).map(|_| rand::random()).collect::<Vec<f32>>(),
107-
));
101+
let random_input = StoreKey((0..dimension).map(|_| rand::random()).collect::<Vec<f32>>());
108102
group_non_linear_kdtree.sampling_mode(criterion::SamplingMode::Flat);
109103
group_non_linear_kdtree.bench_function(format!("size_{size}"), |b| {
110104
b.iter(|| {
@@ -143,11 +137,11 @@ fn bench_insertion(c: &mut Criterion) {
143137
.unwrap();
144138
let dimension = dimension.clone();
145139
let random_array = vec![(
146-
StoreKey(Array::from(
140+
StoreKey(
147141
(0..dimension)
148142
.map(|_| fastrand::f32())
149143
.collect::<Vec<f32>>(),
150-
)),
144+
),
151145
HashMap::new(),
152146
)];
153147
group.bench_function(format!("size_{size}"), |b| {

ahnlich/db/src/algorithm/heap.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ mod tests {
140140
fn test_min_heap_ordering_works() {
141141
let mut heap = MinHeap::new(NonZeroUsize::new(3).unwrap());
142142
let mut count = 0.0;
143-
let first_vector = StoreKey(ndarray::Array1::<f32>::zeros(2).map(|x| x + 2.0));
143+
let first_vector = StoreKey(vec![2.0, 2.0]);
144144

145145
// If we pop these scores now, they should come back in the reverse order.
146146
while count < 5.0 {
@@ -162,7 +162,7 @@ mod tests {
162162
fn test_max_heap_ordering_works() {
163163
let mut heap = MaxHeap::new(NonZeroUsize::new(3).unwrap());
164164
let mut count = 0.0;
165-
let first_vector = StoreKey(ndarray::Array1::<f32>::zeros(2).map(|x| x + 2.0));
165+
let first_vector = StoreKey(vec![2.0, 2.0]);
166166

167167
// If we pop these scores now, they should come back in the right order (max first).
168168
while count < 5.0 {

ahnlich/db/src/algorithm/non_linear.rs

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
use super::super::errors::ServerError;
22
use super::FindSimilarN;
33
use ahnlich_similarity::kdtree::KDTree;
4-
use ahnlich_similarity::utils::Array1F32Ordered;
4+
use ahnlich_similarity::utils::VecF32Ordered;
55
use ahnlich_types::keyval::StoreKey;
66
use ahnlich_types::similarity::NonLinearAlgorithm;
77
use flurry::HashMap as ConcurrentHashMap;
8-
use ndarray::Array1;
98
use serde::Deserialize;
109
use serde::Serialize;
1110
use std::collections::HashSet;
@@ -28,7 +27,7 @@ impl NonLinearAlgorithmWithIndex {
2827
}
2928

3029
#[tracing::instrument(skip_all)]
31-
fn insert(&self, new: &[Array1<f32>]) {
30+
fn insert(&self, new: &[Vec<f32>]) {
3231
match self {
3332
NonLinearAlgorithmWithIndex::KDTree(kdtree) => {
3433
kdtree
@@ -39,7 +38,7 @@ impl NonLinearAlgorithmWithIndex {
3938
}
4039

4140
#[tracing::instrument(skip_all)]
42-
fn delete(&self, new: &[Array1<f32>]) {
41+
fn delete(&self, new: &[Vec<f32>]) {
4342
match self {
4443
NonLinearAlgorithmWithIndex::KDTree(kdtree) => {
4544
kdtree
@@ -71,7 +70,7 @@ impl FindSimilarN for NonLinearAlgorithmWithIndex {
7170
} else {
7271
Some(
7372
search_list
74-
.map(|key| Array1F32Ordered(key.0.clone()))
73+
.map(|key| VecF32Ordered(key.0.clone()))
7574
.collect(),
7675
)
7776
};
@@ -121,7 +120,7 @@ impl NonLinearAlgorithmIndices {
121120
pub fn insert_indices(
122121
&self,
123122
indices: HashSet<NonLinearAlgorithm>,
124-
values: &[Array1<f32>],
123+
values: &[Vec<f32>],
125124
dimension: NonZeroUsize,
126125
) {
127126
let pinned = self.algorithm_to_index.pin();
@@ -156,7 +155,7 @@ impl NonLinearAlgorithmIndices {
156155

157156
/// insert new entries into the non linear algorithm indices
158157
#[tracing::instrument(skip_all)]
159-
pub(crate) fn insert(&self, new: Vec<Array1<f32>>) {
158+
pub(crate) fn insert(&self, new: Vec<Vec<f32>>) {
160159
let pinned = self.algorithm_to_index.pin();
161160
for (_, algo) in pinned.iter() {
162161
algo.insert(&new);
@@ -165,7 +164,7 @@ impl NonLinearAlgorithmIndices {
165164

166165
/// delete old entries from the non linear algorithm indices
167166
#[tracing::instrument(skip_all)]
168-
pub(crate) fn delete(&self, old: &[Array1<f32>]) {
167+
pub(crate) fn delete(&self, old: &[Vec<f32>]) {
169168
let pinned = self.algorithm_to_index.pin();
170169
for (_, algo) in pinned.iter() {
171170
algo.delete(old);

0 commit comments

Comments
 (0)