Skip to content

Commit 86f6a51

Browse files
tswast and steffnay authored
perf: cache first page of jobs.getQueryResults rows (#374)
Co-authored-by: Steffany Brown <[email protected]>
1 parent cd9febd commit 86f6a51

File tree

6 files changed

+115
-60
lines changed

6 files changed

+115
-60
lines changed

google/cloud/bigquery/client.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1534,7 +1534,7 @@ def _get_query_results(
15341534
A new ``_QueryResults`` instance.
15351535
"""
15361536

1537-
extra_params = {"maxResults": 0}
1537+
extra_params = {}
15381538

15391539
if project is None:
15401540
project = self.project
@@ -3187,6 +3187,7 @@ def _list_rows_from_query_results(
31873187
page_size=None,
31883188
retry=DEFAULT_RETRY,
31893189
timeout=None,
3190+
first_page_response=None,
31903191
):
31913192
"""List the rows of a completed query.
31923193
See
@@ -3247,6 +3248,7 @@ def _list_rows_from_query_results(
32473248
table=destination,
32483249
extra_params=params,
32493250
total_rows=total_rows,
3251+
first_page_response=first_page_response,
32503252
)
32513253
return row_iterator
32523254

google/cloud/bigquery/job/query.py

+54-31
Original file line numberDiff line numberDiff line change
@@ -990,48 +990,22 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True):
990990
Returns:
991991
bool: True if the job is complete, False otherwise.
992992
"""
993-
is_done = (
994-
# Only consider a QueryJob complete when we know we have the final
995-
# query results available.
996-
self._query_results is not None
997-
and self._query_results.complete
998-
and self.state == _DONE_STATE
999-
)
1000993
# Do not refresh if the state is already done, as the job will not
1001994
# change once complete.
995+
is_done = self.state == _DONE_STATE
1002996
if not reload or is_done:
1003997
return is_done
1004998

1005-
# Since the API to getQueryResults can hang up to the timeout value
1006-
# (default of 10 seconds), set the timeout parameter to ensure that
1007-
# the timeout from the futures API is respected. See:
1008-
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
1009-
timeout_ms = None
1010-
if self._done_timeout is not None:
1011-
# Subtract a buffer for context switching, network latency, etc.
1012-
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
1013-
api_timeout = max(min(api_timeout, 10), 0)
1014-
self._done_timeout -= api_timeout
1015-
self._done_timeout = max(0, self._done_timeout)
1016-
timeout_ms = int(api_timeout * 1000)
999+
self._reload_query_results(retry=retry, timeout=timeout)
10171000

10181001
# If an explicit timeout is not given, fall back to the transport timeout
10191002
# stored in _blocking_poll() in the process of polling for job completion.
10201003
transport_timeout = timeout if timeout is not None else self._transport_timeout
10211004

1022-
self._query_results = self._client._get_query_results(
1023-
self.job_id,
1024-
retry,
1025-
project=self.project,
1026-
timeout_ms=timeout_ms,
1027-
location=self.location,
1028-
timeout=transport_timeout,
1029-
)
1030-
10311005
# Only reload the job once we know the query is complete.
10321006
# This will ensure that fields such as the destination table are
10331007
# correctly populated.
1034-
if self._query_results.complete and self.state != _DONE_STATE:
1008+
if self._query_results.complete:
10351009
self.reload(retry=retry, timeout=transport_timeout)
10361010

10371011
return self.state == _DONE_STATE
@@ -1098,6 +1072,45 @@ def _begin(self, client=None, retry=DEFAULT_RETRY, timeout=None):
10981072
exc.query_job = self
10991073
raise
11001074

1075+
def _reload_query_results(self, retry=DEFAULT_RETRY, timeout=None):
1076+
"""Refresh the cached query results.
1077+
1078+
Args:
1079+
retry (Optional[google.api_core.retry.Retry]):
1080+
How to retry the call that retrieves query results.
1081+
timeout (Optional[float]):
1082+
The number of seconds to wait for the underlying HTTP transport
1083+
before using ``retry``.
1084+
"""
1085+
if self._query_results and self._query_results.complete:
1086+
return
1087+
1088+
# Since the API to getQueryResults can hang up to the timeout value
1089+
# (default of 10 seconds), set the timeout parameter to ensure that
1090+
# the timeout from the futures API is respected. See:
1091+
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
1092+
timeout_ms = None
1093+
if self._done_timeout is not None:
1094+
# Subtract a buffer for context switching, network latency, etc.
1095+
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
1096+
api_timeout = max(min(api_timeout, 10), 0)
1097+
self._done_timeout -= api_timeout
1098+
self._done_timeout = max(0, self._done_timeout)
1099+
timeout_ms = int(api_timeout * 1000)
1100+
1101+
# If an explicit timeout is not given, fall back to the transport timeout
1102+
# stored in _blocking_poll() in the process of polling for job completion.
1103+
transport_timeout = timeout if timeout is not None else self._transport_timeout
1104+
1105+
self._query_results = self._client._get_query_results(
1106+
self.job_id,
1107+
retry,
1108+
project=self.project,
1109+
timeout_ms=timeout_ms,
1110+
location=self.location,
1111+
timeout=transport_timeout,
1112+
)
1113+
11011114
def result(
11021115
self,
11031116
page_size=None,
@@ -1144,6 +1157,11 @@ def result(
11441157
"""
11451158
try:
11461159
super(QueryJob, self).result(retry=retry, timeout=timeout)
1160+
1161+
# Since the job could already be "done" (e.g. got a finished job
1162+
# via client.get_job), the superclass call to done() might not
1163+
# set the self._query_results cache.
1164+
self._reload_query_results(retry=retry, timeout=timeout)
11471165
except exceptions.GoogleAPICallError as exc:
11481166
exc.message += self._format_for_exception(self.query, self.job_id)
11491167
exc.query_job = self
@@ -1158,10 +1176,14 @@ def result(
11581176
if self._query_results.total_rows is None:
11591177
return _EmptyRowIterator()
11601178

1179+
first_page_response = None
1180+
if max_results is None and page_size is None and start_index is None:
1181+
first_page_response = self._query_results._properties
1182+
11611183
rows = self._client._list_rows_from_query_results(
1162-
self._query_results.job_id,
1184+
self.job_id,
11631185
self.location,
1164-
self._query_results.project,
1186+
self.project,
11651187
self._query_results.schema,
11661188
total_rows=self._query_results.total_rows,
11671189
destination=self.destination,
@@ -1170,6 +1192,7 @@ def result(
11701192
start_index=start_index,
11711193
retry=retry,
11721194
timeout=timeout,
1195+
first_page_response=first_page_response,
11731196
)
11741197
rows._preserve_order = _contains_order_by(self.query)
11751198
return rows

google/cloud/bigquery/table.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,9 @@ class RowIterator(HTTPIterator):
13081308
A subset of columns to select from this table.
13091309
total_rows (Optional[int]):
13101310
Total number of rows in the table.
1311-
1311+
first_page_response (Optional[dict]):
1312+
API response for the first page of results. These are returned when
1313+
the first page is requested.
13121314
"""
13131315

13141316
def __init__(
@@ -1324,6 +1326,7 @@ def __init__(
13241326
table=None,
13251327
selected_fields=None,
13261328
total_rows=None,
1329+
first_page_response=None,
13271330
):
13281331
super(RowIterator, self).__init__(
13291332
client,
@@ -1346,6 +1349,7 @@ def __init__(
13461349
self._selected_fields = selected_fields
13471350
self._table = table
13481351
self._total_rows = total_rows
1352+
self._first_page_response = first_page_response
13491353

13501354
def _get_next_page_response(self):
13511355
"""Requests the next page from the path provided.
@@ -1354,6 +1358,11 @@ def _get_next_page_response(self):
13541358
Dict[str, object]:
13551359
The parsed JSON response of the next page's contents.
13561360
"""
1361+
if self._first_page_response:
1362+
response = self._first_page_response
1363+
self._first_page_response = None
1364+
return response
1365+
13571366
params = self._get_query_params()
13581367
if self._page_size is not None:
13591368
if self.page_number and "startIndex" in params:

tests/unit/job/test_query.py

+42-13
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,9 @@ def test_result(self):
787787
"location": "EU",
788788
},
789789
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
790-
"totalRows": "2",
790+
"totalRows": "3",
791+
"rows": [{"f": [{"v": "abc"}]}],
792+
"pageToken": "next-page",
791793
}
792794
job_resource = self._make_resource(started=True, location="EU")
793795
job_resource_done = self._make_resource(started=True, ended=True, location="EU")
@@ -799,9 +801,9 @@ def test_result(self):
799801
query_page_resource = {
800802
# Explicitly set totalRows to be different from the initial
801803
# response to test update during iteration.
802-
"totalRows": "1",
804+
"totalRows": "2",
803805
"pageToken": None,
804-
"rows": [{"f": [{"v": "abc"}]}],
806+
"rows": [{"f": [{"v": "def"}]}],
805807
}
806808
conn = _make_connection(
807809
query_resource, query_resource_done, job_resource_done, query_page_resource
@@ -812,19 +814,20 @@ def test_result(self):
812814
result = job.result()
813815

814816
self.assertIsInstance(result, RowIterator)
815-
self.assertEqual(result.total_rows, 2)
817+
self.assertEqual(result.total_rows, 3)
816818
rows = list(result)
817-
self.assertEqual(len(rows), 1)
819+
self.assertEqual(len(rows), 2)
818820
self.assertEqual(rows[0].col1, "abc")
821+
self.assertEqual(rows[1].col1, "def")
819822
# Test that the total_rows property has changed during iteration, based
820823
# on the response from tabledata.list.
821-
self.assertEqual(result.total_rows, 1)
824+
self.assertEqual(result.total_rows, 2)
822825

823826
query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
824827
query_results_call = mock.call(
825828
method="GET",
826829
path=query_results_path,
827-
query_params={"maxResults": 0, "location": "EU"},
830+
query_params={"location": "EU"},
828831
timeout=None,
829832
)
830833
reload_call = mock.call(
@@ -839,6 +842,7 @@ def test_result(self):
839842
query_params={
840843
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
841844
"location": "EU",
845+
"pageToken": "next-page",
842846
},
843847
timeout=None,
844848
)
@@ -851,7 +855,9 @@ def test_result_with_done_job_calls_get_query_results(self):
851855
"jobComplete": True,
852856
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
853857
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
854-
"totalRows": "1",
858+
"totalRows": "2",
859+
"rows": [{"f": [{"v": "abc"}]}],
860+
"pageToken": "next-page",
855861
}
856862
job_resource = self._make_resource(started=True, ended=True, location="EU")
857863
job_resource["configuration"]["query"]["destinationTable"] = {
@@ -860,9 +866,9 @@ def test_result_with_done_job_calls_get_query_results(self):
860866
"tableId": "dest_table",
861867
}
862868
results_page_resource = {
863-
"totalRows": "1",
869+
"totalRows": "2",
864870
"pageToken": None,
865-
"rows": [{"f": [{"v": "abc"}]}],
871+
"rows": [{"f": [{"v": "def"}]}],
866872
}
867873
conn = _make_connection(query_resource_done, results_page_resource)
868874
client = _make_client(self.PROJECT, connection=conn)
@@ -871,14 +877,15 @@ def test_result_with_done_job_calls_get_query_results(self):
871877
result = job.result()
872878

873879
rows = list(result)
874-
self.assertEqual(len(rows), 1)
880+
self.assertEqual(len(rows), 2)
875881
self.assertEqual(rows[0].col1, "abc")
882+
self.assertEqual(rows[1].col1, "def")
876883

877884
query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
878885
query_results_call = mock.call(
879886
method="GET",
880887
path=query_results_path,
881-
query_params={"maxResults": 0, "location": "EU"},
888+
query_params={"location": "EU"},
882889
timeout=None,
883890
)
884891
query_results_page_call = mock.call(
@@ -887,6 +894,7 @@ def test_result_with_done_job_calls_get_query_results(self):
887894
query_params={
888895
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
889896
"location": "EU",
897+
"pageToken": "next-page",
890898
},
891899
timeout=None,
892900
)
@@ -900,6 +908,12 @@ def test_result_with_max_results(self):
900908
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
901909
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
902910
"totalRows": "5",
911+
# These rows are discarded because max_results is set.
912+
"rows": [
913+
{"f": [{"v": "xyz"}]},
914+
{"f": [{"v": "uvw"}]},
915+
{"f": [{"v": "rst"}]},
916+
],
903917
}
904918
query_page_resource = {
905919
"totalRows": "5",
@@ -925,6 +939,7 @@ def test_result_with_max_results(self):
925939
rows = list(result)
926940

927941
self.assertEqual(len(rows), 3)
942+
self.assertEqual(rows[0].col1, "abc")
928943
self.assertEqual(len(connection.api_request.call_args_list), 2)
929944
query_page_request = connection.api_request.call_args_list[1]
930945
self.assertEqual(
@@ -979,7 +994,7 @@ def test_result_w_retry(self):
979994
query_results_call = mock.call(
980995
method="GET",
981996
path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}",
982-
query_params={"maxResults": 0, "location": "asia-northeast1"},
997+
query_params={"location": "asia-northeast1"},
983998
timeout=None,
984999
)
9851000
reload_call = mock.call(
@@ -1079,6 +1094,12 @@ def test_result_w_page_size(self):
10791094
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
10801095
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
10811096
"totalRows": "4",
1097+
# These rows are discarded because page_size is set.
1098+
"rows": [
1099+
{"f": [{"v": "xyz"}]},
1100+
{"f": [{"v": "uvw"}]},
1101+
{"f": [{"v": "rst"}]},
1102+
],
10821103
}
10831104
job_resource = self._make_resource(started=True, ended=True, location="US")
10841105
q_config = job_resource["configuration"]["query"]
@@ -1109,6 +1130,7 @@ def test_result_w_page_size(self):
11091130
# Assert
11101131
actual_rows = list(result)
11111132
self.assertEqual(len(actual_rows), 4)
1133+
self.assertEqual(actual_rows[0].col1, "row1")
11121134

11131135
query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
11141136
query_page_1_call = mock.call(
@@ -1142,6 +1164,12 @@ def test_result_with_start_index(self):
11421164
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
11431165
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
11441166
"totalRows": "5",
1167+
# These rows are discarded because start_index is set.
1168+
"rows": [
1169+
{"f": [{"v": "xyz"}]},
1170+
{"f": [{"v": "uvw"}]},
1171+
{"f": [{"v": "rst"}]},
1172+
],
11451173
}
11461174
tabledata_resource = {
11471175
"totalRows": "5",
@@ -1168,6 +1196,7 @@ def test_result_with_start_index(self):
11681196
rows = list(result)
11691197

11701198
self.assertEqual(len(rows), 4)
1199+
self.assertEqual(rows[0].col1, "abc")
11711200
self.assertEqual(len(connection.api_request.call_args_list), 2)
11721201
tabledata_list_request = connection.api_request.call_args_list[1]
11731202
self.assertEqual(

0 commit comments

Comments (0)