From cf0d260195287e57451175556deb71811c44fbc7 Mon Sep 17 00:00:00 2001 From: Tim Swena Date: Tue, 19 Nov 2024 11:15:00 -0600 Subject: [PATCH 1/2] feat: `bigframes.bigquery.vector_search` supports `use_brute_force` and `fraction_lists_to_search` parameters --- bigframes/bigquery/_operations/search.py | 53 ++++++++-------- bigframes/core/sql.py | 61 ++++++++++--------- tests/unit/core/test_sql.py | 77 +++++++++++------------- 3 files changed, 91 insertions(+), 100 deletions(-) diff --git a/bigframes/bigquery/_operations/search.py b/bigframes/bigquery/_operations/search.py index 496e259944..45d54e352b 100644 --- a/bigframes/bigquery/_operations/search.py +++ b/bigframes/bigquery/_operations/search.py @@ -18,7 +18,6 @@ import typing from typing import Collection, Literal, Mapping, Optional, Union -import bigframes_vendored.constants as constants import google.cloud.bigquery as bigquery import bigframes.core.sql @@ -96,10 +95,10 @@ def vector_search( query: Union[dataframe.DataFrame, series.Series], *, query_column_to_search: Optional[str] = None, - top_k: Optional[int] = 10, - distance_type: Literal["euclidean", "cosine"] = "euclidean", + top_k: Optional[int] = None, + distance_type: Optional[Literal["euclidean", "cosine", "dot_product"]] = None, fraction_lists_to_search: Optional[float] = None, - use_brute_force: bool = False, + use_brute_force: Optional[bool] = None, ) -> dataframe.DataFrame: """ Conduct vector search which searches embeddings to find semantically similar entities. @@ -141,7 +140,8 @@ def vector_search( ... base_table="bigframes-dev.bigframes_tests_sys.base_table", ... column_to_search="my_embedding", ... query=search_query, - ... top_k=2) + ... top_k=2, + ... use_brute_force=True) embedding id my_embedding distance dog [1. 2.] 1 [1. 2.] 0.0 cat [3. 5.2] 5 [5. 5.4] 2.009975 @@ -185,17 +185,18 @@ def vector_search( find nearest neighbors. The column must have a type of ``ARRAY``. All elements in the array must be non-NULL and all values in the column must have the same array dimensions as the values in the ``column_to_search`` column. Can only be set when query is a DataFrame. - top_k (int, default 10): + top_k (int): Sepecifies the number of nearest neighbors to return. Default to 10. distance_type (str, defalt "euclidean"): Specifies the type of metric to use to compute the distance between two vectors. - Possible values are "euclidean" and "cosine". Default to "euclidean". + Possible values are "euclidean", "cosine" and "dot_product". + Default to "euclidean". fraction_lists_to_search (float, range in [0.0, 1.0]): Specifies the percentage of lists to search. Specifying a higher percentage leads to higher recall and slower performance, and the converse is true when specifying a lower percentage. It is only used when a vector index is also used. You can only specify ``fraction_lists_to_search`` when ``use_brute_force`` is set to False. - use_brute_force (bool, default False): + use_brute_force (bool): Determines whether to use brute force search by skipping the vector index if one is available. Default to False. @@ -204,10 +205,6 @@ def vector_search( """ import bigframes.series - if not fraction_lists_to_search and use_brute_force is True: - raise ValueError( - "You can't specify fraction_lists_to_search when use_brute_force is set to True." - ) if ( isinstance(query, bigframes.series.Series) and query_column_to_search is not None @@ -215,26 +212,28 @@ def vector_search( raise ValueError( "You can't specify query_column_to_search when query is a Series." ) - # TODO(ashleyxu): Support options in vector search. b/344019989 - if fraction_lists_to_search is not None or use_brute_force is True: - raise NotImplementedError( - f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}" - ) - options = { - "base_table": base_table, - "column_to_search": column_to_search, - "query_column_to_search": query_column_to_search, - "distance_type": distance_type, - "top_k": top_k, - "fraction_lists_to_search": fraction_lists_to_search, - "use_brute_force": use_brute_force, - } + + # Only populate options if not set to the default value. + # This avoids accidentally setting options that are mutually exclusive. + options = None + if fraction_lists_to_search is not None: + options = {} if options is None else options + options["fraction_lists_to_search"] = fraction_lists_to_search + if use_brute_force is not None: + options = {} if options is None else options + options["use_brute_force"] = use_brute_force (query,) = utils.convert_to_dataframe(query) sql_string, index_col_ids, index_labels = query._to_sql_query(include_index=True) sql = bigframes.core.sql.create_vector_search_sql( - sql_string=sql_string, options=options # type: ignore + sql_string=sql_string, + base_table=base_table, + column_to_search=column_to_search, + query_column_to_search=query_column_to_search, + top_k=top_k, + distance_type=distance_type, + options=options, ) if index_col_ids is not None: df = query._session.read_gbq(sql, index_col=index_col_ids) diff --git a/bigframes/core/sql.py b/bigframes/core/sql.py index d5dfc64ddd..79fe680958 100644 --- a/bigframes/core/sql.py +++ b/bigframes/core/sql.py @@ -18,8 +18,9 @@ """ import datetime +import json import math -from typing import cast, Collection, Iterable, Mapping, TYPE_CHECKING, Union +from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union import bigframes.core.compile.googlesql as googlesql @@ -157,43 +158,43 @@ def create_vector_index_ddl( def create_vector_search_sql( sql_string: str, - options: Mapping[str, Union[str | int | bool | float]] = {}, + *, + base_table: str, + column_to_search: str, + query_column_to_search: Optional[str] = None, + top_k: Optional[int] = None, + distance_type: Optional[str] = None, + options: Optional[Mapping[str, Union[str | int | bool | float]]] = None, ) -> str: """Encode the VECTOR SEARCH statement for BigQuery Vector Search.""" - base_table = options["base_table"] - column_to_search = options["column_to_search"] - distance_type = options["distance_type"] - top_k = options["top_k"] - query_column_to_search = options.get("query_column_to_search", None) + vector_search_args = [ + f"TABLE {googlesql.identifier(cast(str, base_table))}", + f"{simple_literal(column_to_search)}", + f"({sql_string})", + ] if query_column_to_search is not None: - query_str = f""" - SELECT - query.*, - base.*, - distance, - FROM VECTOR_SEARCH( - TABLE {googlesql.identifier(cast(str, base_table))}, - {simple_literal(column_to_search)}, - ({sql_string}), - {simple_literal(query_column_to_search)}, - distance_type => {simple_literal(distance_type)}, - top_k => {simple_literal(top_k)} - ) - """ - else: - query_str = f""" + vector_search_args.append( + f"query_column_to_search => {simple_literal(query_column_to_search)}" + ) + + if top_k is not None: + vector_search_args.append(f"top_k=> {simple_literal(top_k)}") + + if distance_type is not None: + vector_search_args.append(f"distance_type => {simple_literal(distance_type)}") + + if options is not None: + vector_search_args.append( + f"options => {simple_literal(json.dumps(options, indent=None))}" + ) + + return f""" SELECT query.*, base.*, distance, FROM VECTOR_SEARCH( - TABLE {googlesql.identifier(cast(str, base_table))}, - {simple_literal(column_to_search)}, - ({sql_string}), - distance_type => {simple_literal(distance_type)}, - top_k => {simple_literal(top_k)} - ) +{',\n'.join(vector_search_args)}) """ - return query_str diff --git a/tests/unit/core/test_sql.py b/tests/unit/core/test_sql.py index 29f1e48a70..72b26cf347 100644 --- a/tests/unit/core/test_sql.py +++ b/tests/unit/core/test_sql.py @@ -17,62 +17,53 @@ def test_create_vector_search_sql_simple(): - sql_string = "SELECT embedding FROM my_embeddings_table WHERE id = 1" - options = { - "base_table": "my_base_table", - "column_to_search": "my_embedding_column", - "distance_type": "COSINE", - "top_k": 10, - "use_brute_force": False, - } - - expected_query = f""" + result_query = sql.create_vector_search_sql( + sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1", + base_table="my_base_table", + column_to_search="my_embedding_column", + ) + assert ( + result_query + == """ SELECT query.*, base.*, distance, FROM VECTOR_SEARCH( - TABLE `my_base_table`, - 'my_embedding_column', - ({sql_string}), - distance_type => 'COSINE', - top_k => 10 - ) +TABLE `my_base_table`, +'my_embedding_column', +(SELECT embedding FROM my_embeddings_table WHERE id = 1)) """ - - result_query = sql.create_vector_search_sql( - sql_string, options # type:ignore ) - assert result_query == expected_query -def test_create_vector_search_sql_query_column_to_search(): - sql_string = "SELECT embedding FROM my_embeddings_table WHERE id = 1" - options = { - "base_table": "my_base_table", - "column_to_search": "my_embedding_column", - "distance_type": "COSINE", - "top_k": 10, - "query_column_to_search": "new_embedding_column", - "use_brute_force": False, - } - - expected_query = f""" +def test_create_vector_search_sql_all_named_parameters(): + result_query = sql.create_vector_search_sql( + sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1", + base_table="my_base_table", + column_to_search="my_embedding_column", + query_column_to_search="another_embedding_column", + top_k=10, + distance_type="cosine", + options={ + "fraction_lists_to_search": 0.1, + "use_brute_force": False, + }, + ) + assert ( + result_query + == """ SELECT query.*, base.*, distance, FROM VECTOR_SEARCH( - TABLE `my_base_table`, - 'my_embedding_column', - ({sql_string}), - 'new_embedding_column', - distance_type => 'COSINE', - top_k => 10 - ) +TABLE `my_base_table`, +'my_embedding_column', +(SELECT embedding FROM my_embeddings_table WHERE id = 1), +query_column_to_search => 'another_embedding_column', +top_k=> 10, +distance_type => 'cosine', +options => '{\\"fraction_lists_to_search\\": 0.1, \\"use_brute_force\\": false}') """ - - result_query = sql.create_vector_search_sql( - sql_string, options # type:ignore ) - assert result_query == expected_query From f684a7d53c68a7de26c243f44026e0f6eb018359 Mon Sep 17 00:00:00 2001 From: Tim Swena Date: Tue, 19 Nov 2024 11:22:52 -0600 Subject: [PATCH 2/2] fix f-string on lower python versions --- bigframes/core/sql.py | 4 ++-- tests/unit/core/test_sql.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bigframes/core/sql.py b/bigframes/core/sql.py index 79fe680958..ae0a0de2aa 100644 --- a/bigframes/core/sql.py +++ b/bigframes/core/sql.py @@ -190,11 +190,11 @@ def create_vector_search_sql( f"options => {simple_literal(json.dumps(options, indent=None))}" ) + args_str = ",\n".join(vector_search_args) return f""" SELECT query.*, base.*, distance, - FROM VECTOR_SEARCH( -{',\n'.join(vector_search_args)}) + FROM VECTOR_SEARCH({args_str}) """ diff --git a/tests/unit/core/test_sql.py b/tests/unit/core/test_sql.py index 72b26cf347..a2ee2f359e 100644 --- a/tests/unit/core/test_sql.py +++ b/tests/unit/core/test_sql.py @@ -29,8 +29,7 @@ def test_create_vector_search_sql_simple(): query.*, base.*, distance, - FROM VECTOR_SEARCH( -TABLE `my_base_table`, + FROM VECTOR_SEARCH(TABLE `my_base_table`, 'my_embedding_column', (SELECT embedding FROM my_embeddings_table WHERE id = 1)) """ @@ -57,8 +56,7 @@ def test_create_vector_search_sql_all_named_parameters(): query.*, base.*, distance, - FROM VECTOR_SEARCH( -TABLE `my_base_table`, + FROM VECTOR_SEARCH(TABLE `my_base_table`, 'my_embedding_column', (SELECT embedding FROM my_embeddings_table WHERE id = 1), query_column_to_search => 'another_embedding_column',