@@ -51,10 +51,13 @@ class MatchNeighbor:
51
51
Required. The id of the neighbor.
52
52
distance (float):
53
53
Required. The distance to the query embedding.
54
+ feature_vector (List[float]):
55
+ Optional. The feature vector of the matching datapoint.
54
56
"""
55
57
56
58
id : str
57
59
distance : float
60
+ feature_vector : Optional [List [float ]] = None
58
61
59
62
60
63
@dataclass
@@ -1185,14 +1188,15 @@ def find_neighbors(
1185
1188
self ,
1186
1189
* ,
1187
1190
deployed_index_id : str ,
1188
- queries : List [List [float ]],
1191
+ queries : Optional [ List [List [float ]]] = None ,
1189
1192
num_neighbors : int = 10 ,
1190
1193
filter : Optional [List [Namespace ]] = None ,
1191
1194
per_crowding_attribute_neighbor_count : Optional [int ] = None ,
1192
1195
approx_num_neighbors : Optional [int ] = None ,
1193
1196
fraction_leaf_nodes_to_search_override : Optional [float ] = None ,
1194
1197
return_full_datapoint : bool = False ,
1195
1198
numeric_filter : Optional [List [NumericNamespace ]] = None ,
1199
+ embedding_ids : Optional [List [str ]] = None ,
1196
1200
) -> List [List [MatchNeighbor ]]:
1197
1201
"""Retrieves nearest neighbors for the given embedding queries on the
1198
1202
specified deployed index which is deployed to either public or private
@@ -1243,11 +1247,18 @@ def find_neighbors(
1243
1247
Note that returning full datapoint will significantly increase the
1244
1248
latency and cost of the query.
1245
1249
1246
- numeric_filter (Optional[ list[NumericNamespace] ]):
1250
+ numeric_filter (list[NumericNamespace]):
1247
1251
Optional. A list of NumericNamespaces for filtering the matching
1248
1252
results. For example:
1249
1253
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
1250
1254
will match datapoints whose cost is greater than 5.
1255
+
1256
+ embedding_ids (List[str]):
1257
+ Optional. If `queries` is set, `queries` is used for the nearest
1258
+ neighbor search. If `queries` isn't set, `embedding_ids` is first used
1259
+ to look up embedding values from the dataset; if embeddings with those
1260
+ ids exist in the dataset, the nearest neighbor search is done with them.
1261
+
1251
1262
Returns:
1252
1263
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
1253
1264
"""
@@ -1262,7 +1273,6 @@ def find_neighbors(
1262
1273
per_crowding_attribute_num_neighbors = per_crowding_attribute_neighbor_count ,
1263
1274
approx_num_neighbors = approx_num_neighbors ,
1264
1275
fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1265
- return_full_datapoint = return_full_datapoint ,
1266
1276
)
1267
1277
1268
1278
# Create the FindNeighbors request
@@ -1271,50 +1281,60 @@ def find_neighbors(
1271
1281
find_neighbors_request .deployed_index_id = deployed_index_id
1272
1282
find_neighbors_request .return_full_datapoint = return_full_datapoint
1273
1283
1274
- for query in queries :
1275
- find_neighbors_query = (
1276
- gca_match_service_v1beta1 .FindNeighborsRequest .Query ()
1277
- )
1278
- find_neighbors_query .neighbor_count = num_neighbors
1279
- find_neighbors_query .per_crowding_attribute_neighbor_count = (
1280
- per_crowding_attribute_neighbor_count
1281
- )
1282
- find_neighbors_query .approximate_neighbor_count = approx_num_neighbors
1283
- find_neighbors_query .fraction_leaf_nodes_to_search_override = (
1284
- fraction_leaf_nodes_to_search_override
1284
+ # Token restricts
1285
+ restricts = []
1286
+ if filter :
1287
+ for namespace in filter :
1288
+ restrict = gca_index_v1beta1 .IndexDatapoint .Restriction ()
1289
+ restrict .namespace = namespace .name
1290
+ restrict .allow_list .extend (namespace .allow_tokens )
1291
+ restrict .deny_list .extend (namespace .deny_tokens )
1292
+ restricts .append (restrict )
1293
+ # Numeric restricts
1294
+ numeric_restricts = []
1295
+ if numeric_filter :
1296
+ for numeric_namespace in numeric_filter :
1297
+ numeric_restrict = gca_index_v1beta1 .IndexDatapoint .NumericRestriction ()
1298
+ numeric_restrict .namespace = numeric_namespace .name
1299
+ numeric_restrict .op = numeric_namespace .op
1300
+ numeric_restrict .value_int = numeric_namespace .value_int
1301
+ numeric_restrict .value_float = numeric_namespace .value_float
1302
+ numeric_restrict .value_double = numeric_namespace .value_double
1303
+ numeric_restricts .append (numeric_restrict )
1304
+ # Queries
1305
+ query_by_id = False if queries else True
1306
+ queries = queries if queries else embedding_ids
1307
+ if queries :
1308
+ for query in queries :
1309
+ find_neighbors_query = gca_match_service_v1beta1 .FindNeighborsRequest .Query (
1310
+ neighbor_count = num_neighbors ,
1311
+ per_crowding_attribute_neighbor_count = per_crowding_attribute_neighbor_count ,
1312
+ approximate_neighbor_count = approx_num_neighbors ,
1313
+ fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1314
+ )
1315
+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1316
+ datapoint_id = query if query_by_id else None ,
1317
+ feature_vector = None if query_by_id else query ,
1318
+ )
1319
+ datapoint .restricts .extend (restricts )
1320
+ datapoint .numeric_restricts .extend (numeric_restricts )
1321
+ find_neighbors_query .datapoint = datapoint
1322
+ find_neighbors_request .queries .append (find_neighbors_query )
1323
+ else :
1324
+ raise ValueError (
1325
+ "To find neighbors using matching engine,"
1326
+ "please specify `queries` or `embedding_ids`"
1285
1327
)
1286
- datapoint = gca_index_v1beta1 .IndexDatapoint (feature_vector = query )
1287
- # Token restricts
1288
- if filter :
1289
- for namespace in filter :
1290
- restrict = gca_index_v1beta1 .IndexDatapoint .Restriction ()
1291
- restrict .namespace = namespace .name
1292
- restrict .allow_list .extend (namespace .allow_tokens )
1293
- restrict .deny_list .extend (namespace .deny_tokens )
1294
- datapoint .restricts .append (restrict )
1295
- # Numeric restricts
1296
- if numeric_filter :
1297
- for numeric_namespace in numeric_filter :
1298
- numeric_restrict = (
1299
- gca_index_v1beta1 .IndexDatapoint .NumericRestriction ()
1300
- )
1301
- numeric_restrict .namespace = numeric_namespace .name
1302
- numeric_restrict .op = numeric_namespace .op
1303
- numeric_restrict .value_int = numeric_namespace .value_int
1304
- numeric_restrict .value_float = numeric_namespace .value_float
1305
- numeric_restrict .value_double = numeric_namespace .value_double
1306
- datapoint .numeric_restricts .append (numeric_restrict )
1307
-
1308
- find_neighbors_query .datapoint = datapoint
1309
- find_neighbors_request .queries .append (find_neighbors_query )
1310
1328
1311
1329
response = self ._public_match_client .find_neighbors (find_neighbors_request )
1312
1330
1313
1331
# Wrap the results in MatchNeighbor objects and return
1314
1332
return [
1315
1333
[
1316
1334
MatchNeighbor (
1317
- id = neighbor .datapoint .datapoint_id , distance = neighbor .distance
1335
+ id = neighbor .datapoint .datapoint_id ,
1336
+ distance = neighbor .distance ,
1337
+ feature_vector = neighbor .datapoint .feature_vector ,
1318
1338
)
1319
1339
for neighbor in embedding_neighbors .neighbors
1320
1340
]
@@ -1429,13 +1449,12 @@ def _batch_get_embeddings(
1429
1449
def match (
1430
1450
self ,
1431
1451
deployed_index_id : str ,
1432
- queries : Optional [ List [List [float ] ]] = None ,
1452
+ queries : List [List [float ]] = None ,
1433
1453
num_neighbors : int = 1 ,
1434
1454
filter : Optional [List [Namespace ]] = None ,
1435
1455
per_crowding_attribute_num_neighbors : Optional [int ] = None ,
1436
1456
approx_num_neighbors : Optional [int ] = None ,
1437
1457
fraction_leaf_nodes_to_search_override : Optional [float ] = None ,
1438
- return_full_datapoint : bool = False ,
1439
1458
low_level_batch_size : int = 0 ,
1440
1459
) -> List [List [MatchNeighbor ]]:
1441
1460
"""Retrieves nearest neighbors for the given embedding queries on the
@@ -1468,11 +1487,6 @@ def match(
1468
1487
query time allows user to tune search performance. This value
1469
1488
increase result in both search accuracy and latency increase.
1470
1489
The value should be between 0.0 and 1.0.
1471
- return_full_datapoint (bool):
1472
- Optional. If set to true, the full datapoints (including all
1473
- vector values and of the nearest neighbors are returned.
1474
- Note that returning full datapoint will significantly increase the
1475
- latency and cost of the query.
1476
1490
low_level_batch_size (int):
1477
1491
Optional. Selects the optimal batch size to use for low-level
1478
1492
batching. Queries within each low level batch are executed
@@ -1518,9 +1532,13 @@ def match(
1518
1532
per_crowding_attribute_num_neighbors = per_crowding_attribute_num_neighbors ,
1519
1533
approx_num_neighbors = approx_num_neighbors ,
1520
1534
fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1521
- embedding_enabled = return_full_datapoint ,
1522
1535
)
1523
1536
requests .append (request )
1537
+ else :
1538
+ raise ValueError (
1539
+ "To find neighbors using matching engine,"
1540
+ "please specify `queries` or `embedding_ids`"
1541
+ )
1524
1542
1525
1543
batch_request_for_index .requests .extend (requests )
1526
1544
batch_request .requests .append (batch_request_for_index )
@@ -1531,8 +1549,11 @@ def match(
1531
1549
# Wrap the results in MatchNeighbor objects and return
1532
1550
return [
1533
1551
[
1534
- MatchNeighbor (id = neighbor .id , distance = neighbor .distance )
1535
- for neighbor in embedding_neighbors .neighbor
1552
+ MatchNeighbor (
1553
+ id = embedding_neighbors .neighbor [i ].id ,
1554
+ distance = embedding_neighbors .neighbor [i ].distance ,
1555
+ )
1556
+ for i in range (len (embedding_neighbors .neighbor ))
1536
1557
]
1537
1558
for embedding_neighbors in response .responses [0 ].responses
1538
1559
]
0 commit comments