Skip to content

Commit 4ed3ad4

Browse files
authored
Simplifier Client Args by specifying default args (#150)
* WIP: Using Typebuilder to create simple struct params for rust client * Refactor: Use struct params as args for ai and db clients * update client readme
1 parent a3aa36c commit 4ed3ad4

File tree

16 files changed

+1032
-731
lines changed

16 files changed

+1032
-731
lines changed

ahnlich/Cargo.lock

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ahnlich/ai/src/server/task.rs

+95-72
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::engine::ai::models::Model;
2-
use ahnlich_client_rs::db::DbClient;
2+
use ahnlich_client_rs::{builders::db as db_params, db::DbClient};
33
use ahnlich_types::ai::{
44
AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction,
55
};
@@ -80,18 +80,15 @@ impl AhnlichProtocol for AIProxyTask {
8080
predicates.insert(default_metadata_key.clone());
8181
}
8282
let model: Model = (&index_model).into();
83-
match self
84-
.db_client
85-
.create_store(
86-
store.clone(),
87-
model.embedding_size,
88-
predicates,
89-
non_linear_indices,
90-
false,
91-
parent_id.clone(),
92-
)
93-
.await
94-
{
83+
let create_store_params = db_params::CreateStoreParams::builder()
84+
.store(store.clone().to_string())
85+
.dimension(model.embedding_size.into())
86+
.create_predicates(predicates)
87+
.non_linear_indices(non_linear_indices)
88+
.error_if_exists(false)
89+
.tracing_id(parent_id.clone())
90+
.build();
91+
match self.db_client.create_store(create_store_params).await {
9592
Err(err) => Err(err.to_string()),
9693
Ok(_) => self
9794
.store_handler
@@ -130,9 +127,20 @@ impl AhnlichProtocol for AIProxyTask {
130127
key: default_metadatakey.clone(),
131128
value: del_hashset,
132129
});
133-
pipeline.del_pred(store.clone(), delete_condition);
130+
let del_pred_params = db_params::DelPredParams::builder()
131+
.store(store.clone().to_string())
132+
.condition(delete_condition)
133+
.tracing_id(parent_id.clone())
134+
.build();
135+
pipeline.del_pred(del_pred_params);
134136
}
135-
pipeline.set(store, db_inputs);
137+
let set_params = db_params::SetParams::builder()
138+
.store(store.to_string())
139+
.inputs(db_inputs)
140+
.tracing_id(parent_id.clone())
141+
.build();
142+
143+
pipeline.set(set_params);
136144
match pipeline.exec().await {
137145
Ok(res) => match res.into_inner().as_slice() {
138146
[Ok(ServerResponse::Set(upsert))]
@@ -164,11 +172,12 @@ impl AhnlichProtocol for AIProxyTask {
164172
key: default_metadatakey.clone(),
165173
value: HashSet::from_iter([metadata_value]),
166174
});
167-
match self
168-
.db_client
169-
.del_pred(store, delete_condition, parent_id.clone())
170-
.await
171-
{
175+
let del_pred_params = db_params::DelPredParams::builder()
176+
.store(store.to_string())
177+
.condition(delete_condition)
178+
.tracing_id(parent_id.clone())
179+
.build();
180+
match self.db_client.del_pred(del_pred_params).await {
172181
Ok(res) => {
173182
if let ServerResponse::Del(num) = res {
174183
Ok(AIServerResponse::Del(num))
@@ -188,22 +197,30 @@ impl AhnlichProtocol for AIProxyTask {
188197
AIQuery::DropStore {
189198
store,
190199
error_if_not_exists,
191-
} => match self
192-
.db_client
193-
.drop_store(store.clone(), error_if_not_exists, parent_id.clone())
194-
.await
195-
{
196-
Ok(_) => self
197-
.store_handler
198-
.drop_store(store, error_if_not_exists)
199-
.map(AIServerResponse::Del)
200-
.map_err(|e| e.to_string()),
201-
Err(err) => Err(format!("{err}")),
202-
},
200+
} => {
201+
let drop_store_params = db_params::DropStoreParams::builder()
202+
.store(store.to_string())
203+
.error_if_not_exists(error_if_not_exists)
204+
.tracing_id(parent_id.clone())
205+
.build();
206+
match self.db_client.drop_store(drop_store_params).await {
207+
Ok(_) => self
208+
.store_handler
209+
.drop_store(store, error_if_not_exists)
210+
.map(AIServerResponse::Del)
211+
.map_err(|e| e.to_string()),
212+
Err(err) => Err(format!("{err}")),
213+
}
214+
}
203215
AIQuery::CreatePredIndex { store, predicates } => {
216+
let create_pred_index_params = db_params::CreatePredIndexParams::builder()
217+
.store(store.to_string())
218+
.predicates(predicates)
219+
.tracing_id(parent_id.clone())
220+
.build();
204221
match self
205222
.db_client
206-
.create_pred_index(store, predicates, parent_id.clone())
223+
.create_pred_index(create_pred_index_params)
207224
.await
208225
{
209226
Ok(res) => {
@@ -221,13 +238,15 @@ impl AhnlichProtocol for AIProxyTask {
221238
store,
222239
non_linear_indices,
223240
} => {
241+
let create_non_linear_algo_params =
242+
db_params::CreateNonLinearAlgorithmIndexParams::builder()
243+
.store(store.to_string())
244+
.non_linear_indices(non_linear_indices)
245+
.tracing_id(parent_id.clone())
246+
.build();
224247
match self
225248
.db_client
226-
.create_non_linear_algorithm_index(
227-
store,
228-
non_linear_indices,
229-
parent_id.clone(),
230-
)
249+
.create_non_linear_algorithm_index(create_non_linear_algo_params)
231250
.await
232251
{
233252
Ok(res) => {
@@ -250,35 +269,41 @@ impl AhnlichProtocol for AIProxyTask {
250269
if predicates.contains(default_metadatakey) {
251270
let _ = predicates.remove(default_metadatakey);
252271
}
253-
match self
254-
.db_client
255-
.drop_pred_index(store, predicates, error_if_not_exists, parent_id.clone())
256-
.await
257272
{
258-
Ok(res) => {
259-
if let ServerResponse::Del(num) = res {
260-
Ok(AIServerResponse::Del(num))
261-
} else {
262-
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
263-
.to_string())
273+
let drop_pred_index_params = db_params::DropPredIndexParams::builder()
274+
.store(store.to_string())
275+
.predicates(predicates)
276+
.error_if_not_exists(error_if_not_exists)
277+
.tracing_id(parent_id.clone())
278+
.build();
279+
match self.db_client.drop_pred_index(drop_pred_index_params).await {
280+
Ok(res) => {
281+
if let ServerResponse::Del(num) = res {
282+
Ok(AIServerResponse::Del(num))
283+
} else {
284+
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
285+
.to_string())
286+
}
264287
}
288+
Err(err) => Err(format!("{err}")),
265289
}
266-
Err(err) => Err(format!("{err}")),
267290
}
268291
}
269292
AIQuery::DropNonLinearAlgorithmIndex {
270293
store,
271294
non_linear_indices,
272295
error_if_not_exists,
273296
} => {
297+
let drop_non_linear_algorithm_index_params =
298+
db_params::DropNonLinearAlgorithmIndexParams::builder()
299+
.store(store.to_string())
300+
.non_linear_indices(non_linear_indices)
301+
.error_if_not_exists(error_if_not_exists)
302+
.tracing_id(parent_id.clone())
303+
.build();
274304
match self
275305
.db_client
276-
.drop_non_linear_algorithm_index(
277-
store,
278-
non_linear_indices,
279-
error_if_not_exists,
280-
parent_id.clone(),
281-
)
306+
.drop_non_linear_algorithm_index(drop_non_linear_algorithm_index_params)
282307
.await
283308
{
284309
Ok(res) => {
@@ -293,11 +318,12 @@ impl AhnlichProtocol for AIProxyTask {
293318
}
294319
}
295320
AIQuery::GetPred { store, condition } => {
296-
match self
297-
.db_client
298-
.get_pred(store, condition, parent_id.clone())
299-
.await
300-
{
321+
let get_pred_params = db_params::GetPredParams::builder()
322+
.store(store.to_string())
323+
.condition(condition)
324+
.tracing_id(parent_id.clone())
325+
.build();
326+
match self.db_client.get_pred(get_pred_params).await {
301327
Ok(res) => {
302328
if let ServerResponse::Get(response) = res {
303329
// conversion to store input here
@@ -337,18 +363,15 @@ impl AhnlichProtocol for AIProxyTask {
337363
.await;
338364
match repr {
339365
Ok(store_key) => {
340-
match self
341-
.db_client
342-
.get_sim_n(
343-
store,
344-
store_key,
345-
closest_n,
346-
algorithm,
347-
condition,
348-
parent_id.clone(),
349-
)
350-
.await
351-
{
366+
let get_sim_n_params = db_params::GetSimNParams::builder()
367+
.store(store.to_string())
368+
.search_input(store_key)
369+
.closest_n(closest_n.into())
370+
.algorithm(algorithm)
371+
.condition(condition)
372+
.tracing_id(parent_id.clone())
373+
.build();
374+
match self.db_client.get_sim_n(get_sim_n_params).await {
352375
Ok(res) => {
353376
if let ServerResponse::GetSimN(response) = res {
354377
let (store_key_input, similarities): (Vec<_>, Vec<_>) =

ahnlich/client/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ async-trait.workspace = true
2323
tokio.workspace = true
2424
deadpool.workspace = true
2525
fallible_collections.workspace = true
26+
typed-builder = "0.20.0"
27+
2628
[dev-dependencies]
2729
db = { path = "../db", version = "*" }
2830
ai = { path = "../ai", version = "*" }
2931
pretty_assertions.workspace = true
3032
ndarray.workspace = true
3133
utils = { path = "../utils", version = "*" }
34+

0 commit comments

Comments
 (0)