Skip to content

Commit e7eb918

Browse files
authored
feat: support bq connection path format (#1550)
* feat: support bq connection path format. For example, a user-specified BigQuery connection such as "projects/project_id/locations/northamerica-northeast1/connections/conn-name" is now also supported.
* include path format in tests with connection mismatch
* pass cloud_function_service_account="default" in more tests
1 parent 3104fab commit e7eb918

File tree

7 files changed

+182
-96
lines changed

7 files changed

+182
-96
lines changed

bigframes/clients.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import logging
20+
import textwrap
2021
import time
2122
from typing import cast, Optional
2223

@@ -28,21 +29,46 @@
2829
logger = logging.getLogger(__name__)
2930

3031

31-
def resolve_full_bq_connection_name(
32-
connection_name: str, default_project: str, default_location: str
32+
def get_canonical_bq_connection_id(
33+
connection_id: str, default_project: str, default_location: str
3334
) -> str:
34-
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
35-
Use default project, location or connection_id when any of them are missing."""
36-
if connection_name.count(".") == 2:
37-
return connection_name
38-
39-
if connection_name.count(".") == 1:
40-
return f"{default_project}.{connection_name}"
41-
42-
if connection_name.count(".") == 0:
43-
return f"{default_project}.{default_location}.{connection_name}"
44-
45-
raise ValueError(f"Invalid connection name format: {connection_name}.")
35+
"""
36+
Retrieve the full connection id of the form
37+
<PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
38+
Use default project, location or connection_id when any of them are missing.
39+
"""
40+
41+
if "/" in connection_id:
42+
fields = connection_id.split("/")
43+
if (
44+
len(fields) == 6
45+
and fields[0] == "projects"
46+
and fields[2] == "locations"
47+
and fields[4] == "connections"
48+
):
49+
return ".".join((fields[1], fields[3], fields[5]))
50+
else:
51+
if connection_id.count(".") == 2:
52+
return connection_id
53+
54+
if connection_id.count(".") == 1:
55+
return f"{default_project}.{connection_id}"
56+
57+
if connection_id.count(".") == 0:
58+
return f"{default_project}.{default_location}.{connection_id}"
59+
60+
raise ValueError(
61+
textwrap.dedent(
62+
f"""
63+
Invalid connection id format: {connection_id}.
64+
Only the following formats are supported:
65+
<project-id>.<location>.<connection-id>,
66+
<location>.<connection-id>,
67+
<connection-id>,
68+
projects/<project-id>/locations/<location>/connections/<connection-id>
69+
"""
70+
).strip()
71+
)
4672

4773

4874
class BqConnectionManager:

bigframes/functions/_function_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _resolve_bigquery_connection_id(
167167
if not bigquery_connection:
168168
bigquery_connection = session._bq_connection # type: ignore
169169

170-
bigquery_connection = clients.resolve_full_bq_connection_name(
170+
bigquery_connection = clients.get_canonical_bq_connection_id(
171171
bigquery_connection,
172172
default_project=dataset_ref.project,
173173
default_location=bq_location,

bigframes/operations/blob.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def _resolve_connection(self, connection: Optional[str] = None) -> str:
297297
ValueError: If the connection cannot be resolved to a valid string.
298298
"""
299299
connection = connection or self._block.session._bq_connection
300-
return clients.resolve_full_bq_connection_name(
300+
return clients.get_canonical_bq_connection_id(
301301
connection,
302302
default_project=self._block.session._project,
303303
default_location=self._block.session._location,

bigframes/session/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1775,8 +1775,8 @@ def _create_bq_connection(
17751775
"""Create the connection with the session settings and try to attach iam role to the connection SA.
17761776
If any of project, location or connection isn't specified, use the session defaults. Returns fully-qualified connection name."""
17771777
connection = self._bq_connection if not connection else connection
1778-
connection = bigframes.clients.resolve_full_bq_connection_name(
1779-
connection_name=connection,
1778+
connection = bigframes.clients.get_canonical_bq_connection_id(
1779+
connection_id=connection,
17801780
default_project=self._project,
17811781
default_location=self._location,
17821782
)

tests/system/large/functions/test_remote_function.py

+30
Original file line numberDiff line numberDiff line change
@@ -2819,3 +2819,33 @@ def featurize(x: int) -> list[float]:
28192819
cleanup_function_assets(
28202820
featurize, session.bqclient, session.cloudfunctionsclient
28212821
)
2822+
2823+
2824+
@pytest.mark.flaky(retries=2, delay=120)
2825+
def test_remote_function_connection_path_format(
2826+
session, scalars_dfs, dataset_id, bq_cf_connection
2827+
):
2828+
try:
2829+
2830+
@session.remote_function(
2831+
dataset=dataset_id,
2832+
bigquery_connection=f"projects/{session.bqclient.project}/locations/{session._location}/connections/{bq_cf_connection}",
2833+
reuse=False,
2834+
cloud_function_service_account="default",
2835+
)
2836+
def foo(x: int) -> int:
2837+
return x + 1
2838+
2839+
scalars_df, scalars_pandas_df = scalars_dfs
2840+
2841+
bf_int64_col = scalars_df["int64_too"]
2842+
bf_result = bf_int64_col.apply(foo).to_pandas()
2843+
2844+
pd_int64_col = scalars_pandas_df["int64_too"]
2845+
pd_result = pd_int64_col.apply(foo)
2846+
2847+
# ignore any dtype disparity
2848+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
2849+
finally:
2850+
# clean up the gcp assets created for the remote function
2851+
cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)

tests/system/small/functions/test_remote_function.py

+78-66
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import test_utils.prefixer
2626

2727
import bigframes
28+
import bigframes.clients
2829
import bigframes.dtypes
2930
import bigframes.exceptions
3031
from bigframes.functions import _utils as bff_utils
@@ -93,6 +94,11 @@ def session_with_bq_connection(bq_cf_connection) -> bigframes.Session:
9394
return session
9495

9596

97+
def get_bq_connection_id_path_format(connection_id_dot_format):
98+
fields = connection_id_dot_format.split(".")
99+
return f"projects/{fields[0]}/locations/{fields[1]}/connections/{fields[2]}"
100+
101+
96102
@pytest.mark.flaky(retries=2, delay=120)
97103
def test_remote_function_direct_no_session_param(
98104
bigquery_client,
@@ -155,11 +161,8 @@ def square(x):
155161

156162

157163
@pytest.mark.flaky(retries=2, delay=120)
158-
def test_remote_function_direct_no_session_param_location_specified(
159-
bigquery_client,
160-
bigqueryconnection_client,
161-
cloudfunctions_client,
162-
resourcemanager_client,
164+
def test_remote_function_connection_w_location(
165+
session,
163166
scalars_dfs,
164167
dataset_id_permanent,
165168
bq_cf_connection_location,
@@ -170,10 +173,7 @@ def square(x):
170173
square = bff.remote_function(
171174
input_types=int,
172175
output_type=int,
173-
bigquery_client=bigquery_client,
174-
bigquery_connection_client=bigqueryconnection_client,
175-
cloud_functions_client=cloudfunctions_client,
176-
resource_manager_client=resourcemanager_client,
176+
session=session,
177177
dataset=dataset_id_permanent,
178178
bigquery_connection=bq_cf_connection_location,
179179
# See e2e tests for tests that actually deploy the Cloud Function.
@@ -210,11 +210,8 @@ def square(x):
210210

211211

212212
@pytest.mark.flaky(retries=2, delay=120)
213-
def test_remote_function_direct_no_session_param_location_mismatched(
214-
bigquery_client,
215-
bigqueryconnection_client,
216-
cloudfunctions_client,
217-
resourcemanager_client,
213+
def test_remote_function_connection_w_location_mismatched(
214+
session,
218215
dataset_id_permanent,
219216
bq_cf_connection_location_mismatched,
220217
):
@@ -223,32 +220,41 @@ def square(x):
223220
# connection doesn't match the location of the dataset.
224221
return x * x # pragma: NO COVER
225222

226-
with pytest.raises(
227-
ValueError,
228-
match=re.escape("The location does not match BigQuery connection location:"),
229-
):
230-
bff.remote_function(
231-
input_types=int,
232-
output_type=int,
233-
bigquery_client=bigquery_client,
234-
bigquery_connection_client=bigqueryconnection_client,
235-
cloud_functions_client=cloudfunctions_client,
236-
resource_manager_client=resourcemanager_client,
237-
dataset=dataset_id_permanent,
238-
bigquery_connection=bq_cf_connection_location_mismatched,
239-
# See e2e tests for tests that actually deploy the Cloud Function.
240-
reuse=True,
241-
name=get_function_name(square),
242-
cloud_function_service_account="default",
243-
)(square)
223+
bq_cf_connection_location_mismatched_path_fmt = get_bq_connection_id_path_format(
224+
bigframes.clients.get_canonical_bq_connection_id(
225+
bq_cf_connection_location_mismatched,
226+
session.bqclient.project,
227+
session._location,
228+
)
229+
)
230+
connection_ids = [
231+
bq_cf_connection_location_mismatched,
232+
bq_cf_connection_location_mismatched_path_fmt,
233+
]
234+
235+
for connection_id in connection_ids:
236+
with pytest.raises(
237+
ValueError,
238+
match=re.escape(
239+
"The location does not match BigQuery connection location:"
240+
),
241+
):
242+
bff.remote_function(
243+
input_types=int,
244+
output_type=int,
245+
session=session,
246+
dataset=dataset_id_permanent,
247+
bigquery_connection=connection_id,
248+
# See e2e tests for tests that actually deploy the Cloud Function.
249+
reuse=True,
250+
name=get_function_name(square),
251+
cloud_function_service_account="default",
252+
)(square)
244253

245254

246255
@pytest.mark.flaky(retries=2, delay=120)
247-
def test_remote_function_direct_no_session_param_location_project_specified(
248-
bigquery_client,
249-
bigqueryconnection_client,
250-
cloudfunctions_client,
251-
resourcemanager_client,
256+
def test_remote_function_connection_w_location_project(
257+
session,
252258
scalars_dfs,
253259
dataset_id_permanent,
254260
bq_cf_connection_location_project,
@@ -259,10 +265,7 @@ def square(x):
259265
square = bff.remote_function(
260266
input_types=int,
261267
output_type=int,
262-
bigquery_client=bigquery_client,
263-
bigquery_connection_client=bigqueryconnection_client,
264-
cloud_functions_client=cloudfunctions_client,
265-
resource_manager_client=resourcemanager_client,
268+
session=session,
266269
dataset=dataset_id_permanent,
267270
bigquery_connection=bq_cf_connection_location_project,
268271
# See e2e tests for tests that actually deploy the Cloud Function.
@@ -299,11 +302,8 @@ def square(x):
299302

300303

301304
@pytest.mark.flaky(retries=2, delay=120)
302-
def test_remote_function_direct_no_session_param_project_mismatched(
303-
bigquery_client,
304-
bigqueryconnection_client,
305-
cloudfunctions_client,
306-
resourcemanager_client,
305+
def test_remote_function_connection_w_project_mismatched(
306+
session,
307307
dataset_id_permanent,
308308
bq_cf_connection_location_project_mismatched,
309309
):
@@ -312,26 +312,38 @@ def square(x):
312312
# connection doesn't match the project of the dataset.
313313
return x * x # pragma: NO COVER
314314

315-
with pytest.raises(
316-
ValueError,
317-
match=re.escape(
318-
"The project_id does not match BigQuery connection gcp_project_id:"
319-
),
320-
):
321-
bff.remote_function(
322-
input_types=int,
323-
output_type=int,
324-
bigquery_client=bigquery_client,
325-
bigquery_connection_client=bigqueryconnection_client,
326-
cloud_functions_client=cloudfunctions_client,
327-
resource_manager_client=resourcemanager_client,
328-
dataset=dataset_id_permanent,
329-
bigquery_connection=bq_cf_connection_location_project_mismatched,
330-
# See e2e tests for tests that actually deploy the Cloud Function.
331-
reuse=True,
332-
name=get_function_name(square),
333-
cloud_function_service_account="default",
334-
)(square)
315+
bq_cf_connection_location_project_mismatched_path_fmt = (
316+
get_bq_connection_id_path_format(
317+
bigframes.clients.get_canonical_bq_connection_id(
318+
bq_cf_connection_location_project_mismatched,
319+
session.bqclient.project,
320+
session._location,
321+
)
322+
)
323+
)
324+
connection_ids = [
325+
bq_cf_connection_location_project_mismatched,
326+
bq_cf_connection_location_project_mismatched_path_fmt,
327+
]
328+
329+
for connection_id in connection_ids:
330+
with pytest.raises(
331+
ValueError,
332+
match=re.escape(
333+
"The project_id does not match BigQuery connection gcp_project_id:"
334+
),
335+
):
336+
bff.remote_function(
337+
input_types=int,
338+
output_type=int,
339+
session=session,
340+
dataset=dataset_id_permanent,
341+
bigquery_connection=connection_id,
342+
# See e2e tests for tests that actually deploy the Cloud Function.
343+
reuse=True,
344+
name=get_function_name(square),
345+
cloud_function_service_account="default",
346+
)(square)
335347

336348

337349
@pytest.mark.flaky(retries=2, delay=120)

0 commit comments

Comments
 (0)