Skip to content

Commit ce65eab

Browse files
lingyinw authored and copybara-github committed
feat: Add support for hybrid queries for private endpoint in Matching Engine Index Endpoint.
PiperOrigin-RevId: 644459987
1 parent 536f1d5 commit ce65eab

File tree

2 files changed

+108
-15
lines changed

2 files changed

+108
-15
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+40-8
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,9 @@ class MatchNeighbor:
187187
id (str):
188188
Required. The id of the neighbor.
189189
distance (float):
190-
Required. The distance to the query embedding.
190+
Optional. The distance between the neighbor and the dense embedding query.
191+
sparse_distance (float):
192+
Optional. The distance between the neighbor and the query sparse_embedding.
191193
feature_vector (List[float]):
192194
Optional. The feature vector of the matching datapoint.
193195
crowding_tag (Optional[str]):
@@ -210,7 +212,8 @@ class MatchNeighbor:
210212
"""
211213

212214
id: str
213-
distance: float
215+
distance: Optional[float] = None
216+
sparse_distance: Optional[float] = None
214217
feature_vector: Optional[List[float]] = None
215218
crowding_tag: Optional[str] = None
216219
restricts: Optional[List[Namespace]] = None
@@ -316,6 +319,9 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
316319
name=restrict.name, value_double=restrict.value_double
317320
)
318321
self.numeric_restricts.append(numeric_namespace)
322+
if embedding.sparse_embedding:
323+
self.sparse_embedding_values = embedding.sparse_embedding.float_val
324+
self.sparse_embedding_dimensions = embedding.sparse_embedding.dimension
319325
return self
320326

321327

@@ -1548,7 +1554,11 @@ def find_neighbors(
15481554
return [
15491555
[
15501556
MatchNeighbor(
1551-
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
1557+
id=neighbor.datapoint.datapoint_id,
1558+
distance=neighbor.distance,
1559+
sparse_distance=neighbor.sparse_distance
1560+
if neighbor.sparse_distance
1561+
else None,
15521562
).from_index_datapoint(index_datapoint=neighbor.datapoint)
15531563
for neighbor in embedding_neighbors.neighbors
15541564
]
@@ -1662,7 +1672,7 @@ def _batch_get_embeddings(
16621672
def match(
16631673
self,
16641674
deployed_index_id: str,
1665-
queries: List[List[float]] = None,
1675+
queries: Union[List[List[float]], List[HybridQuery]] = None,
16661676
num_neighbors: int = 1,
16671677
filter: Optional[List[Namespace]] = None,
16681678
per_crowding_attribute_num_neighbors: Optional[int] = None,
@@ -1677,8 +1687,14 @@ def match(
16771687
Args:
16781688
deployed_index_id (str):
16791689
Required. The ID of the DeployedIndex to match the queries against.
1680-
queries (List[List[float]]):
1681-
Optional. A list of queries. Each query is a list of floats, representing a single embedding.
1690+
queries (Union[List[List[float]], List[HybridQuery]]):
1691+
Optional. A list of queries.
1692+
1693+
For regular dense-only queries, each query is a list of floats,
1694+
representing a single embedding.
1695+
1696+
For hybrid queries, each query is a hybrid query of type
1697+
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
16821698
num_neighbors (int):
16831699
Required. The number of nearest neighbors to be retrieved from database for
16841700
each query.
@@ -1759,16 +1775,28 @@ def match(
17591775

17601776
requests = []
17611777
if queries:
1778+
query_is_hybrid = isinstance(queries[0], HybridQuery)
17621779
for query in queries:
17631780
request = match_service_pb2.MatchRequest(
17641781
deployed_index_id=deployed_index_id,
1765-
float_val=query,
1782+
float_val=query.dense_embedding if query_is_hybrid else query,
17661783
num_neighbors=num_neighbors,
17671784
restricts=restricts,
17681785
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
17691786
approx_num_neighbors=approx_num_neighbors,
17701787
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
17711788
numeric_restricts=numeric_restricts,
1789+
sparse_embedding=match_service_pb2.SparseEmbedding(
1790+
float_val=query.sparse_embedding_values,
1791+
dimension=query.sparse_embedding_dimensions,
1792+
)
1793+
if query_is_hybrid
1794+
else None,
1795+
rrf=match_service_pb2.MatchRequest.RRF(
1796+
alpha=query.rrf_ranking_alpha,
1797+
)
1798+
if query_is_hybrid and query.rrf_ranking_alpha
1799+
else None,
17721800
)
17731801
requests.append(request)
17741802
else:
@@ -1789,7 +1817,11 @@ def match(
17891817
match_neighbors_id_map = {}
17901818
for neighbor in resp.neighbor:
17911819
match_neighbors_id_map[neighbor.id] = MatchNeighbor(
1792-
id=neighbor.id, distance=neighbor.distance
1820+
id=neighbor.id,
1821+
distance=neighbor.distance,
1822+
sparse_distance=neighbor.sparse_distance
1823+
if neighbor.sparse_distance
1824+
else None,
17931825
)
17941826
for embedding in resp.embeddings:
17951827
if embedding.id in match_neighbors_id_map:

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+68-7
Original file line number | Diff line number | Diff line change
@@ -1137,7 +1137,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
11371137
index_endpoint_match_queries_mock.assert_called_with(batch_request)
11381138

11391139
@pytest.mark.usefixtures("get_index_endpoint_mock")
1140-
def test_private_service_access_service_access_index_endpoint_match_queries(
1140+
def test_private_service_access_hybrid_search_match_queries(
11411141
self, index_endpoint_match_queries_mock
11421142
):
11431143
aiplatform.init(project=_TEST_PROJECT)
@@ -1146,7 +1146,72 @@ def test_private_service_access_service_access_index_endpoint_match_queries(
11461146
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
11471147
)
11481148

1149-
response = my_index_endpoint.match(
1149+
my_index_endpoint.match(
1150+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1151+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1152+
filter=_TEST_FILTER,
1153+
queries=_TEST_HYBRID_QUERIES,
1154+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1155+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1156+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1157+
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
1158+
numeric_filter=_TEST_NUMERIC_FILTER,
1159+
)
1160+
1161+
batch_request = match_service_pb2.BatchMatchRequest(
1162+
requests=[
1163+
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
1164+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1165+
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
1166+
requests=[
1167+
match_service_pb2.MatchRequest(
1168+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1169+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1170+
float_val=_TEST_HYBRID_QUERIES[i].dense_embedding,
1171+
restricts=[
1172+
match_service_pb2.Namespace(
1173+
name="class",
1174+
allow_tokens=["token_1"],
1175+
deny_tokens=["token_2"],
1176+
)
1177+
],
1178+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1179+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1180+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1181+
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
1182+
sparse_embedding=match_service_pb2.SparseEmbedding(
1183+
float_val=_TEST_HYBRID_QUERIES[
1184+
i
1185+
].sparse_embedding_values,
1186+
dimension=_TEST_HYBRID_QUERIES[
1187+
i
1188+
].sparse_embedding_dimensions,
1189+
),
1190+
rrf=match_service_pb2.MatchRequest.RRF(
1191+
alpha=_TEST_HYBRID_QUERIES[i].rrf_ranking_alpha,
1192+
)
1193+
if _TEST_HYBRID_QUERIES[i].rrf_ranking_alpha
1194+
else None,
1195+
)
1196+
for i in range(len(_TEST_HYBRID_QUERIES))
1197+
],
1198+
)
1199+
]
1200+
)
1201+
1202+
index_endpoint_match_queries_mock.assert_called_with(batch_request)
1203+
1204+
@pytest.mark.usefixtures("get_index_endpoint_mock")
1205+
def test_private_service_access_index_endpoint_match_queries(
1206+
self, index_endpoint_match_queries_mock
1207+
):
1208+
aiplatform.init(project=_TEST_PROJECT)
1209+
1210+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1211+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1212+
)
1213+
1214+
my_index_endpoint.match(
11501215
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
11511216
num_neighbors=_TEST_NUM_NEIGHBOURS,
11521217
filter=_TEST_FILTER,
@@ -1188,8 +1253,6 @@ def test_private_service_access_service_access_index_endpoint_match_queries(
11881253

11891254
index_endpoint_match_queries_mock.assert_called_with(batch_request)
11901255

1191-
assert response == _TEST_PRIVATE_MATCH_NEIGHBOR_RESPONSE
1192-
11931256
@pytest.mark.usefixtures("get_index_endpoint_mock")
11941257
def test_index_private_service_access_endpoint_find_neighbor_queries(
11951258
self, index_endpoint_match_queries_mock
@@ -1200,7 +1263,7 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
12001263
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
12011264
)
12021265

1203-
response = my_private_index_endpoint.find_neighbors(
1266+
my_private_index_endpoint.find_neighbors(
12041267
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12051268
queries=_TEST_QUERIES,
12061269
num_neighbors=_TEST_NUM_NEIGHBOURS,
@@ -1240,8 +1303,6 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
12401303
)
12411304
index_endpoint_match_queries_mock.assert_called_with(batch_match_request)
12421305

1243-
assert response == _TEST_PRIVATE_MATCH_NEIGHBOR_RESPONSE
1244-
12451306
@pytest.mark.usefixtures("get_index_endpoint_mock")
12461307
def test_index_private_service_connect_endpoint_match_queries(
12471308
self, index_endpoint_match_queries_mock

0 commit comments

Comments
 (0)