
Commit 7372ad6

feat: updates to allow users to set max_stream_count (#2039)
Adds a function `determine_requested_streams()` that compares `preserve_order` and the new argument `max_stream_count` to determine how many streams to request.

```
preserve_order (bool): Whether to preserve the order of streams. If True, this
    limits the number of streams to one (more than one cannot guarantee order).
max_stream_count (Union[int, None]): The maximum number of streams allowed. Must
    be a non-negative number or None, where None indicates the value is unset.
    If `preserve_order` is True, it takes precedence over `max_stream_count`.
```

Fixes #2030 🦕
1 parent 1d8d0a0 commit 7372ad6
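
To make the precedence rules concrete, here is a small illustrative sketch (it assumes only the `determine_requested_streams` helper added to `google.cloud.bigquery._pandas_helpers` in the diff below):

```python
from google.cloud.bigquery._pandas_helpers import determine_requested_streams

# preserve_order wins: order can only be guaranteed with a single stream.
assert determine_requested_streams(preserve_order=True, max_stream_count=10) == 1

# With preserve_order False, max_stream_count is used as the requested cap.
assert determine_requested_streams(preserve_order=False, max_stream_count=10) == 10

# Neither constraint set: 0 is returned, i.e. the stream count is unbounded.
assert determine_requested_streams(preserve_order=False, max_stream_count=None) == 0
```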

File tree

2 files changed (+130 −19 lines)


google/cloud/bigquery/_pandas_helpers.py (+99 −19)
@@ -21,13 +21,14 @@
 import logging
 import queue
 import warnings
-from typing import Any, Union
+from typing import Any, Union, Optional, Callable, Generator, List
 
 
 from google.cloud.bigquery import _pyarrow_helpers
 from google.cloud.bigquery import _versions_helpers
 from google.cloud.bigquery import schema
 
+
 try:
     import pandas  # type: ignore
 
@@ -75,7 +76,7 @@ def _to_wkb(v):
 _to_wkb = _to_wkb()
 
 try:
-    from google.cloud.bigquery_storage import ArrowSerializationOptions
+    from google.cloud.bigquery_storage_v1.types import ArrowSerializationOptions
 except ImportError:
     _ARROW_COMPRESSION_SUPPORT = False
 else:
@@ -816,18 +817,54 @@ def _nowait(futures):
 
 
 def _download_table_bqstorage(
-    project_id,
-    table,
-    bqstorage_client,
-    preserve_order=False,
-    selected_fields=None,
-    page_to_item=None,
-    max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
-):
-    """Use (faster, but billable) BQ Storage API to construct DataFrame."""
+    project_id: str,
+    table: Any,
+    bqstorage_client: Any,
+    preserve_order: bool = False,
+    selected_fields: Optional[List[Any]] = None,
+    page_to_item: Optional[Callable] = None,
+    max_queue_size: Any = _MAX_QUEUE_SIZE_DEFAULT,
+    max_stream_count: Optional[int] = None,
+) -> Generator[Any, None, None]:
+    """Downloads a BigQuery table using the BigQuery Storage API.
+
+    This method uses the faster, but potentially more expensive, BigQuery
+    Storage API to download a table as a Pandas DataFrame. It supports
+    parallel downloads and optional data transformations.
+
+    Args:
+        project_id (str): The ID of the Google Cloud project containing
+            the table.
+        table (Any): The BigQuery table to download.
+        bqstorage_client (Any): An authenticated
+            BigQuery Storage API client.
+        preserve_order (bool, optional): Whether to preserve the order
+            of the rows as they are read from BigQuery. If True, this limits
+            the number of streams to one and overrides `max_stream_count`.
+            Defaults to False.
+        selected_fields (Optional[List[SchemaField]]):
+            A list of BigQuery schema fields to select for download. If None,
+            all fields are downloaded. Defaults to None.
+        page_to_item (Optional[Callable]): An optional callable that
+            converts a page of data from the BigQuery Storage API into items.
+        max_stream_count (Optional[int]): The maximum number of
+            concurrent streams to use for downloading data. If `preserve_order`
+            is True, the requested streams are limited to 1 regardless of the
+            `max_stream_count` value. If 0 or None, then the number of
+            requested streams will be unbounded. Defaults to None.
+
+    Yields:
+        pandas.DataFrame: Pandas DataFrames, one for each chunk of data
+            downloaded from BigQuery.
+
+    Raises:
+        ValueError: If attempting to read from a specific partition or snapshot.
+
+    Note:
+        This method requires the `google-cloud-bigquery-storage` library
+        to be installed.
+    """
 
-    # Passing a BQ Storage client in implies that the BigQuery Storage library
-    # is available and can be imported.
     from google.cloud import bigquery_storage
 
     if "$" in table.table_id:
@@ -837,18 +874,20 @@ def _download_table_bqstorage(
     if "@" in table.table_id:
         raise ValueError("Reading from a specific snapshot is not currently supported.")
 
-    requested_streams = 1 if preserve_order else 0
+    requested_streams = determine_requested_streams(preserve_order, max_stream_count)
 
-    requested_session = bigquery_storage.types.ReadSession(
-        table=table.to_bqstorage(), data_format=bigquery_storage.types.DataFormat.ARROW
+    requested_session = bigquery_storage.types.stream.ReadSession(
+        table=table.to_bqstorage(),
+        data_format=bigquery_storage.types.stream.DataFormat.ARROW,
     )
     if selected_fields is not None:
         for field in selected_fields:
             requested_session.read_options.selected_fields.append(field.name)
 
     if _ARROW_COMPRESSION_SUPPORT:
         requested_session.read_options.arrow_serialization_options.buffer_compression = (
-            ArrowSerializationOptions.CompressionCodec.LZ4_FRAME
+            # CompressionCodec(1) -> LZ4_FRAME
+            ArrowSerializationOptions.CompressionCodec(1)
         )
 
     session = bqstorage_client.create_read_session(
@@ -884,7 +923,7 @@
     elif max_queue_size is None:
         max_queue_size = 0  # unbounded
 
-    worker_queue = queue.Queue(maxsize=max_queue_size)
+    worker_queue: queue.Queue[int] = queue.Queue(maxsize=max_queue_size)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=total_streams) as pool:
         try:
@@ -910,7 +949,7 @@
                 # we want to block on the queue's get method, instead. This
                 # prevents the queue from filling up, because the main thread
                 # has smaller gaps in time between calls to the queue's get
-                # method. For a detailed explaination, see:
+                # method. For a detailed explanation, see:
                 # https://friendliness.dev/2019/06/18/python-nowait/
                 done, not_done = _nowait(not_done)
                 for future in done:
@@ -949,6 +988,7 @@ def download_arrow_bqstorage(
     preserve_order=False,
     selected_fields=None,
     max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
+    max_stream_count=None,
 ):
     return _download_table_bqstorage(
         project_id,
@@ -958,6 +998,7 @@ def download_arrow_bqstorage(
         selected_fields=selected_fields,
         page_to_item=_bqstorage_page_to_arrow,
         max_queue_size=max_queue_size,
+        max_stream_count=max_stream_count,
     )
 
 
@@ -970,6 +1011,7 @@ def download_dataframe_bqstorage(
     preserve_order=False,
     selected_fields=None,
     max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
+    max_stream_count=None,
 ):
     page_to_item = functools.partial(_bqstorage_page_to_dataframe, column_names, dtypes)
     return _download_table_bqstorage(
@@ -980,6 +1022,7 @@ def download_dataframe_bqstorage(
         selected_fields=selected_fields,
         page_to_item=page_to_item,
         max_queue_size=max_queue_size,
+        max_stream_count=max_stream_count,
     )
 
 
@@ -1024,3 +1067,40 @@ def verify_pandas_imports():
         raise ValueError(_NO_PANDAS_ERROR) from pandas_import_exception
     if db_dtypes is None:
         raise ValueError(_NO_DB_TYPES_ERROR) from db_dtypes_import_exception
+
+
+def determine_requested_streams(
+    preserve_order: bool,
+    max_stream_count: Union[int, None],
+) -> int:
+    """Determines the value of requested_streams based on the values of
+    `preserve_order` and `max_stream_count`.
+
+    Args:
+        preserve_order (bool): Whether to preserve the order of streams. If True,
+            this limits the number of streams to one. `preserve_order` takes
+            precedence over `max_stream_count`.
+        max_stream_count (Union[int, None]): The maximum number of streams
+            allowed. Must be a non-negative number or None, where None indicates
+            the value is unset. NOTE: if `preserve_order` is also set, it takes
+            precedence over `max_stream_count`, thus to ensure that `max_stream_count`
+            is used, ensure that `preserve_order` is False.
+
+    Returns:
+        (int) The appropriate value for requested_streams.
+    """
+
+    if preserve_order:
+        # If preserve_order is set, it takes precedence.
+        # Limit the requested streams to 1, to ensure that order
+        # is preserved.
+        return 1
+
+    elif max_stream_count is not None:
+        # If preserve_order is not set, only then do we consider max_stream_count.
+        if max_stream_count <= -1:
+            raise ValueError("max_stream_count must be non-negative OR None")
+        return max_stream_count
+
+    # Default to zero requested streams (unbounded).
+    return 0

tests/unit/test__pandas_helpers.py (+31)
@@ -18,6 +18,7 @@
 import functools
 import operator
 import queue
+from typing import Union
 from unittest import mock
 import warnings
 
@@ -46,6 +47,7 @@
 from google.cloud.bigquery import _pyarrow_helpers
 from google.cloud.bigquery import _versions_helpers
 from google.cloud.bigquery import schema
+from google.cloud.bigquery._pandas_helpers import determine_requested_streams
 
 pyarrow = _versions_helpers.PYARROW_VERSIONS.try_import()
 
@@ -2053,3 +2055,32 @@ def test_verify_pandas_imports_no_db_dtypes(module_under_test, monkeypatch):
     monkeypatch.setattr(module_under_test, "db_dtypes", None)
     with pytest.raises(ValueError, match="Please install the 'db-dtypes' package"):
         module_under_test.verify_pandas_imports()
+
+
+@pytest.mark.parametrize(
+    "preserve_order, max_stream_count, expected_requested_streams",
+    [
+        # If preserve_order is set/True, it takes precedence:
+        (True, 10, 1),  # use 1
+        (True, None, 1),  # use 1
+        # If preserve_order is not set, check max_stream_count:
+        (False, 10, 10),  # max_stream_count takes precedence
+        (False, None, 0),  # unbounded (0) when both are unset
+    ],
+)
+def test_determine_requested_streams(
+    preserve_order: bool,
+    max_stream_count: Union[int, None],
+    expected_requested_streams: int,
+):
+    """Tests various combinations of preserve_order and max_stream_count."""
+    actual_requested_streams = determine_requested_streams(
+        preserve_order, max_stream_count
+    )
+    assert actual_requested_streams == expected_requested_streams
+
+
+def test_determine_requested_streams_invalid_max_stream_count():
+    """Tests that a ValueError is raised if max_stream_count is negative."""
+    with pytest.raises(ValueError):
+        determine_requested_streams(preserve_order=False, max_stream_count=-1)
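
To run just the new tests locally, a pytest keyword filter along these lines should select them (assuming the repository's usual unit-test environment):

```
pytest tests/unit/test__pandas_helpers.py -k determine_requested_streams
```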

0 commit comments
