Skip to content

Commit 34704c3

Browse files
committed
Make total_rows available on RowIterator before iteration
After running a query, the total number of rows is available from the response to the getQueryResults API call. This commit plumbs the total row count through to the faux Table created in QueryJob.result and then on through to the RowIterator created by list_rows. This commit also calls get_table in list_rows. TODO: split that out to a separate PR.
1 parent 1c2cee7 commit 34704c3

File tree

6 files changed

+102
-56
lines changed

6 files changed

+102
-56
lines changed

bigquery/docs/snippets.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2246,8 +2246,7 @@ def test_client_query_total_rows(client, capsys):
22462246
location="US",
22472247
) # API request - starts the query
22482248

2249-
results = query_job.result() # Waits for query to complete.
2250-
next(iter(results)) # Fetch the first page of results, which contains total_rows.
2249+
results = query_job.result() # Wait for query to complete.
22512250
print("Got {} rows.".format(results.total_rows))
22522251
# [END bigquery_query_total_rows]
22532252

bigquery/google/cloud/bigquery/job.py

+1
Original file line numberDiff line numberDiff line change
@@ -2808,6 +2808,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY):
28082808
schema = self._query_results.schema
28092809
dest_table_ref = self.destination
28102810
dest_table = Table(dest_table_ref, schema=schema)
2811+
dest_table._properties["numRows"] = self._query_results.total_rows
28112812
return self._client.list_rows(dest_table, retry=retry)
28122813

28132814
def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):

bigquery/google/cloud/bigquery/table.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,11 @@ def __init__(
13001300
)
13011301
self._schema = schema
13021302
self._field_to_index = _helpers._field_to_index_mapping(schema)
1303+
13031304
self._total_rows = None
1305+
if table is not None and hasattr(table, "num_rows"):
1306+
self._total_rows = table.num_rows
1307+
13041308
self._page_size = page_size
13051309
self._table = table
13061310
self._selected_fields = selected_fields
@@ -1422,9 +1426,7 @@ def _get_progress_bar(self, progress_bar_type):
14221426
desc=description, total=self.total_rows, unit=unit
14231427
)
14241428
elif progress_bar_type == "tqdm_gui":
1425-
return tqdm.tqdm_gui(
1426-
desc=description, total=self.total_rows, unit=unit
1427-
)
1429+
return tqdm.tqdm_gui(desc=description, total=self.total_rows, unit=unit)
14281430
except (KeyError, TypeError):
14291431
# Protect ourselves from any tqdm errors. In case of
14301432
# unexpected tqdm behavior, just fall back to showing

bigquery/tests/unit/test_client.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -4115,18 +4115,21 @@ def test_list_rows_empty_table(self):
41154115
client._connection = _make_connection(response, response)
41164116

41174117
# Table that has no schema because it's an empty table.
4118-
rows = tuple(
4119-
client.list_rows(
4120-
# Test with using a string for the table ID.
4121-
"{}.{}.{}".format(
4122-
self.TABLE_REF.project,
4123-
self.TABLE_REF.dataset_id,
4124-
self.TABLE_REF.table_id,
4125-
),
4126-
selected_fields=[],
4127-
)
4118+
rows = client.list_rows(
4119+
# Test with using a string for the table ID.
4120+
"{}.{}.{}".format(
4121+
self.TABLE_REF.project,
4122+
self.TABLE_REF.dataset_id,
4123+
self.TABLE_REF.table_id,
4124+
),
4125+
selected_fields=[],
41284126
)
4129-
self.assertEqual(rows, ())
4127+
4128+
# When a table reference / string and selected_fields are provided,
4129+
# total_rows can't be populated until iteration starts.
4130+
self.assertIsNone(rows.total_rows)
4131+
self.assertEqual(tuple(rows), ())
4132+
self.assertEqual(rows.total_rows, 0)
41304133

41314134
def test_list_rows_query_params(self):
41324135
from google.cloud.bigquery.table import Table, SchemaField
@@ -4329,7 +4332,7 @@ def test_list_rows_with_missing_schema(self):
43294332

43304333
conn.api_request.assert_called_once_with(method="GET", path=table_path)
43314334
conn.api_request.reset_mock()
4332-
self.assertIsNone(row_iter.total_rows, msg=repr(table))
4335+
self.assertEqual(row_iter.total_rows, 2, msg=repr(table))
43334336

43344337
rows = list(row_iter)
43354338
conn.api_request.assert_called_once_with(

bigquery/tests/unit/test_job.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -4008,21 +4008,37 @@ def test_estimated_bytes_processed(self):
40084008
self.assertEqual(job.estimated_bytes_processed, est_bytes)
40094009

40104010
def test_result(self):
4011+
from google.cloud.bigquery.table import RowIterator
4012+
40114013
query_resource = {
40124014
"jobComplete": True,
40134015
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
40144016
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
4017+
"totalRows": "2",
40154018
}
4016-
connection = _make_connection(query_resource, query_resource)
4019+
tabledata_resource = {
4020+
"totalRows": "1",
4021+
"pageToken": None,
4022+
"rows": [{"f": [{"v": "abc"}]}],
4023+
}
4024+
connection = _make_connection(query_resource, tabledata_resource)
40174025
client = _make_client(self.PROJECT, connection=connection)
40184026
resource = self._make_resource(ended=True)
40194027
job = self._get_target_class().from_api_repr(resource, client)
40204028

40214029
result = job.result()
40224030

4023-
self.assertEqual(list(result), [])
4031+
self.assertIsInstance(result, RowIterator)
4032+
self.assertEqual(result.total_rows, 2)
4033+
4034+
rows = list(result)
4035+
self.assertEqual(len(rows), 1)
4036+
self.assertEqual(rows[0].col1, "abc")
4037+
self.assertEqual(result.total_rows, 1)
40244038

40254039
def test_result_w_empty_schema(self):
4040+
from google.cloud.bigquery.table import _EmptyRowIterator
4041+
40264042
# Destination table may have no schema for some DDL and DML queries.
40274043
query_resource = {
40284044
"jobComplete": True,
@@ -4036,6 +4052,7 @@ def test_result_w_empty_schema(self):
40364052

40374053
result = job.result()
40384054

4055+
self.assertIsInstance(result, _EmptyRowIterator)
40394056
self.assertEqual(list(result), [])
40404057

40414058
def test_result_invokes_begins(self):

0 commit comments

Comments
 (0)