Skip to content

Commit f6f697a

Browse files
jialuooshobsi
andauthored
feat: support bigquery connection in managed function (#1554)
* feat: support bigquery connection in managed function * simplify a bit the intended change * restore pytestmark, remove a comment --------- Co-authored-by: Shobhit Singh <[email protected]>
1 parent e7eb918 commit f6f697a

File tree

4 files changed

+68
-10
lines changed

4 files changed

+68
-10
lines changed

bigframes/functions/_function_client.py

+10
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def provision_bq_managed_function(
196196
name,
197197
packages,
198198
is_row_processor,
199+
bq_connection_id,
199200
*,
200201
capture_references=False,
201202
):
@@ -273,12 +274,21 @@ def provision_bq_managed_function(
273274
udf_code = textwrap.dedent(inspect.getsource(func))
274275
udf_code = udf_code[udf_code.index("def") :]
275276

277+
with_connection_clause = (
278+
(
279+
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`"
280+
)
281+
if bq_connection_id
282+
else ""
283+
)
284+
276285
create_function_ddl = (
277286
textwrap.dedent(
278287
f"""
279288
CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)})
280289
RETURNS {bq_function_return_type}
281290
LANGUAGE python
291+
{with_connection_clause}
282292
OPTIONS ({managed_function_options_str})
283293
AS r'''
284294
__UDF_PLACE_HOLDER__

bigframes/functions/_function_session.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -807,9 +807,13 @@ def udf(
807807

808808
bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location)
809809

810-
# A connection is required for BQ managed function.
811-
bq_connection_id = self._resolve_bigquery_connection_id(
812-
session, dataset_ref, bq_location, bigquery_connection
810+
# A connection is optional for BQ managed function.
811+
bq_connection_id = (
812+
self._resolve_bigquery_connection_id(
813+
session, dataset_ref, bq_location, bigquery_connection
814+
)
815+
if bigquery_connection
816+
else None
813817
)
814818

815819
bq_connection_manager = session.bqconnectionmanager
@@ -907,6 +911,7 @@ def wrapper(func):
907911
name=name,
908912
packages=packages,
909913
is_row_processor=is_row_processor,
914+
bq_connection_id=bq_connection_id,
910915
)
911916

912917
# TODO(shobs): Find a better way to support udfs with param named

tests/system/conftest.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,13 @@ def session_tokyo(tokyo_location: str) -> Generator[bigframes.Session, None, Non
185185

186186

187187
@pytest.fixture(scope="session")
188-
def bq_connection(bigquery_client: bigquery.Client) -> str:
189-
return f"{bigquery_client.project}.{bigquery_client.location}.bigframes-rf-conn"
188+
def bq_connection_name() -> str:
189+
return "bigframes-rf-conn"
190+
191+
192+
@pytest.fixture(scope="session")
193+
def bq_connection(bigquery_client: bigquery.Client, bq_connection_name: str) -> str:
194+
return f"{bigquery_client.project}.{bigquery_client.location}.{bq_connection_name}"
190195

191196

192197
@pytest.fixture(scope="session", autouse=True)

tests/system/large/functions/test_managed_function.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,7 @@ def featurize(x: int) -> list[float]:
166166
cleanup_function_assets(featurize, session.bqclient, ignore_failures=False)
167167

168168

169-
def test_managed_function_series_apply(
170-
session,
171-
scalars_dfs,
172-
):
169+
def test_managed_function_series_apply(session, scalars_dfs):
173170
try:
174171

175172
@session.udf()
@@ -504,7 +501,10 @@ def test_managed_function_dataframe_apply_axis_1_array_output(session):
504501

505502
try:
506503

507-
@session.udf(input_types=[int, float, str], output_type=list[str])
504+
@session.udf(
505+
input_types=[int, float, str],
506+
output_type=list[str],
507+
)
508508
def foo(x, y, z):
509509
return [str(x), str(y), z]
510510

@@ -587,3 +587,41 @@ def foo(x, y, z):
587587
finally:
588588
# Clean up the gcp assets created for the managed function.
589589
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)
590+
591+
592+
@pytest.mark.parametrize(
593+
"connection_fixture",
594+
[
595+
"bq_connection_name",
596+
"bq_connection",
597+
],
598+
)
599+
def test_managed_function_with_connection(
600+
session, scalars_dfs, request, connection_fixture
601+
):
602+
try:
603+
bigquery_connection = request.getfixturevalue(connection_fixture)
604+
605+
@session.udf(bigquery_connection=bigquery_connection)
606+
def foo(x: int) -> int:
607+
return x + 10
608+
609+
# Function should still work normally.
610+
assert foo(-2) == 8
611+
612+
scalars_df, scalars_pandas_df = scalars_dfs
613+
614+
bf_result_col = scalars_df["int64_too"].apply(foo)
615+
bf_result = (
616+
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
617+
)
618+
619+
pd_result_col = scalars_pandas_df["int64_too"].apply(foo)
620+
pd_result = (
621+
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
622+
)
623+
624+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
625+
finally:
626+
# Clean up the gcp assets created for the managed function.
627+
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)

0 commit comments

Comments
 (0)