@@ -187,7 +187,9 @@ class MatchNeighbor:
187
187
id (str):
188
188
Required. The id of the neighbor.
189
189
distance (float):
190
- Required. The distance to the query embedding.
190
+ Optional. The distance between the neighbor and the dense embedding query.
191
+ sparse_distance (float):
192
+ Optional. The distance between the neighbor and the query sparse_embedding.
191
193
feature_vector (List[float]):
192
194
Optional. The feature vector of the matching datapoint.
193
195
crowding_tag (Optional[str]):
@@ -210,7 +212,8 @@ class MatchNeighbor:
210
212
"""
211
213
212
214
id : str
213
- distance : float
215
+ distance : Optional [float ] = None
216
+ sparse_distance : Optional [float ] = None
214
217
feature_vector : Optional [List [float ]] = None
215
218
crowding_tag : Optional [str ] = None
216
219
restricts : Optional [List [Namespace ]] = None
@@ -316,6 +319,9 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
316
319
name = restrict .name , value_double = restrict .value_double
317
320
)
318
321
self .numeric_restricts .append (numeric_namespace )
322
+ if embedding .sparse_embedding :
323
+ self .sparse_embedding_values = embedding .sparse_embedding .float_val
324
+ self .sparse_embedding_dimensions = embedding .sparse_embedding .dimension
319
325
return self
320
326
321
327
@@ -1548,7 +1554,11 @@ def find_neighbors(
1548
1554
return [
1549
1555
[
1550
1556
MatchNeighbor (
1551
- id = neighbor .datapoint .datapoint_id , distance = neighbor .distance
1557
+ id = neighbor .datapoint .datapoint_id ,
1558
+ distance = neighbor .distance ,
1559
+ sparse_distance = neighbor .sparse_distance
1560
+ if neighbor .sparse_distance
1561
+ else None ,
1552
1562
).from_index_datapoint (index_datapoint = neighbor .datapoint )
1553
1563
for neighbor in embedding_neighbors .neighbors
1554
1564
]
@@ -1662,7 +1672,7 @@ def _batch_get_embeddings(
1662
1672
def match (
1663
1673
self ,
1664
1674
deployed_index_id : str ,
1665
- queries : List [List [float ]] = None ,
1675
+ queries : Union [ List [List [float ]], List [ HybridQuery ]] = None ,
1666
1676
num_neighbors : int = 1 ,
1667
1677
filter : Optional [List [Namespace ]] = None ,
1668
1678
per_crowding_attribute_num_neighbors : Optional [int ] = None ,
@@ -1677,8 +1687,14 @@ def match(
1677
1687
Args:
1678
1688
deployed_index_id (str):
1679
1689
Required. The ID of the DeployedIndex to match the queries against.
1680
- queries (List[List[float]]):
1681
- Optional. A list of queries. Each query is a list of floats, representing a single embedding.
1690
+ queries (Union[List[List[float]], List[HybridQuery]]):
1691
+ Optional. A list of queries.
1692
+
1693
+ For regular dense-only queries, each query is a list of floats,
1694
+ representing a single embedding.
1695
+
1696
+ For hybrid queries, each query is a hybrid query of type
1697
+ aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
1682
1698
num_neighbors (int):
1683
1699
Required. The number of nearest neighbors to be retrieved from database for
1684
1700
each query.
@@ -1759,16 +1775,28 @@ def match(
1759
1775
1760
1776
requests = []
1761
1777
if queries :
1778
+ query_is_hybrid = isinstance (queries [0 ], HybridQuery )
1762
1779
for query in queries :
1763
1780
request = match_service_pb2 .MatchRequest (
1764
1781
deployed_index_id = deployed_index_id ,
1765
- float_val = query ,
1782
+ float_val = query . dense_embedding if query_is_hybrid else query ,
1766
1783
num_neighbors = num_neighbors ,
1767
1784
restricts = restricts ,
1768
1785
per_crowding_attribute_num_neighbors = per_crowding_attribute_num_neighbors ,
1769
1786
approx_num_neighbors = approx_num_neighbors ,
1770
1787
fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1771
1788
numeric_restricts = numeric_restricts ,
1789
+ sparse_embedding = match_service_pb2 .SparseEmbedding (
1790
+ float_val = query .sparse_embedding_values ,
1791
+ dimension = query .sparse_embedding_dimensions ,
1792
+ )
1793
+ if query_is_hybrid
1794
+ else None ,
1795
+ rrf = match_service_pb2 .MatchRequest .RRF (
1796
+ alpha = query .rrf_ranking_alpha ,
1797
+ )
1798
+ if query_is_hybrid and query .rrf_ranking_alpha
1799
+ else None ,
1772
1800
)
1773
1801
requests .append (request )
1774
1802
else :
@@ -1789,7 +1817,11 @@ def match(
1789
1817
match_neighbors_id_map = {}
1790
1818
for neighbor in resp .neighbor :
1791
1819
match_neighbors_id_map [neighbor .id ] = MatchNeighbor (
1792
- id = neighbor .id , distance = neighbor .distance
1820
+ id = neighbor .id ,
1821
+ distance = neighbor .distance ,
1822
+ sparse_distance = neighbor .sparse_distance
1823
+ if neighbor .sparse_distance
1824
+ else None ,
1793
1825
)
1794
1826
for embedding in resp .embeddings :
1795
1827
if embedding .id in match_neighbors_id_map :
0 commit comments