Skip to content

Commit 3aa643f

Browse files
authored
feat: allow single input type in remote_function (#641)
* feat: allow single input type in `remote_function` * say sequence instead of list in the remote_function docstring * fix more doc
1 parent a5c94ec commit 3aa643f

File tree

7 files changed

+50
-16
lines changed

7 files changed

+50
-16
lines changed

bigframes/functions/remote_function.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import sys
2525
import tempfile
2626
import textwrap
27-
from typing import List, NamedTuple, Optional, Sequence, TYPE_CHECKING
27+
from typing import List, NamedTuple, Optional, Sequence, TYPE_CHECKING, Union
2828

2929
import ibis
3030
import requests
@@ -623,7 +623,7 @@ def get_routine_reference(
623623
# which has moved as @js to the ibis package
624624
# https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/udf/__init__.py
625625
def remote_function(
626-
input_types: Sequence[type],
626+
input_types: Union[type, Sequence[type]],
627627
output_type: type,
628628
session: Optional[Session] = None,
629629
bigquery_client: Optional[bigquery.Client] = None,
@@ -686,9 +686,10 @@ def remote_function(
686686
`$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`.
687687
688688
Args:
689-
input_types list(type):
690-
List of input data types in the user defined function.
691-
output_type type:
689+
input_types (type or sequence(type)):
690+
Input data type, or sequence of input data types in the user
691+
defined function.
692+
output_type (type):
692693
Data type of the output in the user defined function.
693694
session (bigframes.Session, Optional):
694695
BigQuery DataFrames session to use for getting default project,
@@ -778,6 +779,9 @@ def remote_function(
778779
By default BigQuery DataFrames uses a 10 minute timeout. `None`
779780
can be passed to let the cloud functions default timeout take effect.
780781
"""
782+
if isinstance(input_types, type):
783+
input_types = [input_types]
784+
781785
import bigframes.pandas as bpd
782786

783787
session = session or bpd.get_global_session()

bigframes/pandas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def read_parquet(
633633

634634

635635
def remote_function(
636-
input_types: List[type],
636+
input_types: Union[type, Sequence[type]],
637637
output_type: type,
638638
dataset: Optional[str] = None,
639639
bigquery_connection: Optional[str] = None,

bigframes/session/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,7 @@ def _ibis_to_temp_table(
15381538

15391539
def remote_function(
15401540
self,
1541-
input_types: List[type],
1541+
input_types: Union[type, Sequence[type]],
15421542
output_type: type,
15431543
dataset: Optional[str] = None,
15441544
bigquery_connection: Optional[str] = None,
@@ -1592,8 +1592,9 @@ def remote_function(
15921592
`$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`.
15931593
15941594
Args:
1595-
input_types (list(type)):
1596-
List of input data types in the user defined function.
1595+
input_types (type or sequence(type)):
1596+
Input data type, or sequence of input data types in the user
1597+
defined function.
15971598
output_type (type):
15981599
Data type of the output in the user defined function.
15991600
dataset (str, Optional):

samples/snippets/remote_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def run_remote_function_and_read_gbq_function(project_id: str):
4747
# of the penguins, which is a real number, into a category, which is a
4848
# string.
4949
@bpd.remote_function(
50-
[float],
50+
float,
5151
str,
5252
reuse=False,
5353
)
@@ -91,7 +91,7 @@ def get_bucket(num):
9191
# as a remote function. The custom function in this example has external
9292
# package dependency, which can be specified via `packages` parameter.
9393
@bpd.remote_function(
94-
[str],
94+
str,
9595
str,
9696
reuse=False,
9797
packages=["cryptography"],

tests/system/large/test_remote_function.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,35 @@ def add_one(x):
310310
)
311311

312312

313+
@pytest.mark.parametrize(
314+
("input_types"),
315+
[
316+
pytest.param([int], id="list-of-int"),
317+
pytest.param(int, id="int"),
318+
],
319+
)
320+
@pytest.mark.flaky(retries=2, delay=120)
321+
def test_remote_function_input_types(session, scalars_dfs, input_types):
322+
try:
323+
324+
def add_one(x):
325+
return x + 1
326+
327+
remote_add_one = session.remote_function(input_types, int)(add_one)
328+
329+
scalars_df, scalars_pandas_df = scalars_dfs
330+
331+
bf_result = scalars_df.int64_too.map(remote_add_one).to_pandas()
332+
pd_result = scalars_pandas_df.int64_too.map(add_one)
333+
334+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
335+
finally:
336+
# clean up the gcp assets created for the remote function
337+
cleanup_remote_function_assets(
338+
session.bqclient, session.cloudfunctionsclient, remote_add_one
339+
)
340+
341+
313342
@pytest.mark.flaky(retries=2, delay=120)
314343
def test_remote_function_explicit_dataset_not_created(
315344
session,

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3892,7 +3892,7 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
38923892
to potentially reuse a previously deployed ``remote_function`` from
38933893
the same user defined function.
38943894
3895-
>>> @bpd.remote_function([int], float, reuse=False)
3895+
>>> @bpd.remote_function(int, float, reuse=False)
38963896
... def minutes_to_hours(x):
38973897
... return x/60
38983898

third_party/bigframes_vendored/pandas/core/series.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,7 +1181,7 @@ def apply(
11811181
to potentially reuse a previously deployed `remote_function` from
11821182
the same user defined function.
11831183
1184-
>>> @bpd.remote_function([int], float, reuse=False)
1184+
>>> @bpd.remote_function(int, float, reuse=False)
11851185
... def minutes_to_hours(x):
11861186
... return x/60
11871187
@@ -1208,7 +1208,7 @@ def apply(
12081208
`packages` param.
12091209
12101210
>>> @bpd.remote_function(
1211-
... [str],
1211+
... str,
12121212
... str,
12131213
... reuse=False,
12141214
... packages=["cryptography"],
@@ -3341,7 +3341,7 @@ def mask(self, cond, other):
33413341
condition is evaluated based on a complicated business logic which cannot
33423342
be expressed in form of a Series.
33433343
3344-
>>> @bpd.remote_function([str], bool, reuse=False)
3344+
>>> @bpd.remote_function(str, bool, reuse=False)
33453345
... def should_mask(name):
33463346
... hash = 0
33473347
... for char_ in name:
@@ -3860,7 +3860,7 @@ def map(
38603860
38613861
It also accepts a remote function:
38623862
3863-
>>> @bpd.remote_function([str], str)
3863+
>>> @bpd.remote_function(str, str)
38643864
... def my_mapper(val):
38653865
... vowels = ["a", "e", "i", "o", "u"]
38663866
... if val:

0 commit comments

Comments
 (0)