Skip to content

Commit fcf99ce

Browse files
authored
Determine the schema in load_table_from_dataframe based on dtypes. (#9049)
* Determine the schema in `load_table_from_dataframe` based on dtypes. This PR updates `load_table_from_dataframe` to automatically determine the BigQuery schema based on the DataFrame's dtypes. If any field's type cannot be determined, fall back to the logic in the pandas `to_parquet` method. * Fix test coverage. * Reduce duplication by using OrderedDict. * Add columns option to DataFrame constructor to ensure correct column order.
1 parent c927a72 commit fcf99ce

File tree

4 files changed

+203
-2
lines changed

4 files changed

+203
-2
lines changed

bigquery/google/cloud/bigquery/_pandas_helpers.py

+40
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@
4949

5050
_PROGRESS_INTERVAL = 0.2 # Maximum time between download status checks, in seconds.
5151

52+
# Mapping from a pandas dtype name (``dtype.name``) to the corresponding
# BigQuery standard SQL type. Used by ``dataframe_to_bq_schema`` to
# autodetect a schema from a DataFrame's dtypes.
# NOTE(review): uint64 and object/string dtypes are absent — presumably
# because uint64 values can exceed BigQuery's signed 64-bit INTEGER and
# object columns have no single unambiguous type; callers fall back to
# parquet-based detection for unmapped dtypes. TODO: confirm rationale.
_PANDAS_DTYPE_TO_BQ = {
    "bool": "BOOLEAN",
    # Timezone-aware nanosecond timestamps (UTC) map to TIMESTAMP.
    "datetime64[ns, UTC]": "TIMESTAMP",
    # Naive nanosecond timestamps map to DATETIME (no timezone).
    "datetime64[ns]": "DATETIME",
    "float32": "FLOAT",
    "float64": "FLOAT",
    "int8": "INTEGER",
    "int16": "INTEGER",
    "int32": "INTEGER",
    "int64": "INTEGER",
    "uint8": "INTEGER",
    "uint16": "INTEGER",
    "uint32": "INTEGER",
}
66+
5267

5368
class _DownloadState(object):
5469
"""Flag to indicate that a thread should exit early."""
@@ -172,6 +187,31 @@ def bq_to_arrow_array(series, bq_field):
172187
return pyarrow.array(series, type=arrow_type)
173188

174189

190+
def dataframe_to_bq_schema(dataframe):
    """Convert a pandas DataFrame schema to a BigQuery schema.

    TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
    schema for a subset of columns.

    Args:
        dataframe (pandas.DataFrame):
            DataFrame for which to determine a BigQuery schema.

    Returns:
        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
            The automatically determined schema. Returns None if the type of
            any column cannot be determined.
    """
    bq_schema = []
    for column, dtype in zip(dataframe.columns, dataframe.dtypes):
        # Look up the BigQuery type by the dtype's canonical name. Dtypes
        # with no safe mapping (e.g. object, uint64) return None here.
        bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
        if not bq_type:
            # Signal the caller to fall back to parquet-based detection
            # rather than sending a partially-detected schema.
            return None
        bq_field = schema.SchemaField(column, bq_type)
        bq_schema.append(bq_field)
    return tuple(bq_schema)
213+
214+
175215
def dataframe_to_arrow(dataframe, bq_schema):
176216
"""Convert pandas dataframe to Arrow table, using BigQuery schema.
177217

bigquery/google/cloud/bigquery/client.py

+15
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
except ImportError: # Python 2.7
2222
import collections as collections_abc
2323

24+
import copy
2425
import functools
2526
import gzip
2627
import io
@@ -1521,11 +1522,25 @@ def load_table_from_dataframe(
15211522

15221523
if job_config is None:
15231524
job_config = job.LoadJobConfig()
1525+
else:
1526+
# Make a copy so that the job config isn't modified in-place.
1527+
job_config_properties = copy.deepcopy(job_config._properties)
1528+
job_config = job.LoadJobConfig()
1529+
job_config._properties = job_config_properties
15241530
job_config.source_format = job.SourceFormat.PARQUET
15251531

15261532
if location is None:
15271533
location = self.location
15281534

1535+
if not job_config.schema:
1536+
autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)
1537+
1538+
# Only use an explicit schema if we were able to determine one
1539+
# matching the dataframe. If not, fallback to the pandas to_parquet
1540+
# method.
1541+
if autodetected_schema:
1542+
job_config.schema = autodetected_schema
1543+
15291544
tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
15301545
os.close(tmpfd)
15311546

bigquery/tests/system.py

+76
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import collections
1617
import concurrent.futures
1718
import csv
1819
import datetime
@@ -634,6 +635,81 @@ def test_load_table_from_local_avro_file_then_dump_table(self):
634635
sorted(row_tuples, key=by_wavelength), sorted(ROWS, key=by_wavelength)
635636
)
636637

638+
@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
    """Test that a DataFrame with dtypes that map well to BigQuery types
    can be uploaded without specifying a schema.

    https://github.com/googleapis/google-cloud-python/issues/9044
    """
    # OrderedDict (plus columns= below) keeps the column order stable so
    # the resulting table schema can be asserted positionally.
    df_data = collections.OrderedDict(
        [
            ("bool_col", pandas.Series([True, False, True], dtype="bool")),
            (
                # Timezone-aware datetimes: expected to become TIMESTAMP.
                "ts_col",
                pandas.Series(
                    [
                        datetime.datetime(2010, 1, 2, 3, 44, 50),
                        datetime.datetime(2011, 2, 3, 14, 50, 59),
                        datetime.datetime(2012, 3, 14, 15, 16),
                    ],
                    dtype="datetime64[ns]",
                ).dt.tz_localize(pytz.utc),
            ),
            (
                # Naive datetimes: expected to become DATETIME.
                "dt_col",
                pandas.Series(
                    [
                        datetime.datetime(2010, 1, 2, 3, 44, 50),
                        datetime.datetime(2011, 2, 3, 14, 50, 59),
                        datetime.datetime(2012, 3, 14, 15, 16),
                    ],
                    dtype="datetime64[ns]",
                ),
            ),
            ("float32_col", pandas.Series([1.0, 2.0, 3.0], dtype="float32")),
            ("float64_col", pandas.Series([4.0, 5.0, 6.0], dtype="float64")),
            ("int8_col", pandas.Series([-12, -11, -10], dtype="int8")),
            ("int16_col", pandas.Series([-9, -8, -7], dtype="int16")),
            ("int32_col", pandas.Series([-6, -5, -4], dtype="int32")),
            ("int64_col", pandas.Series([-3, -2, -1], dtype="int64")),
            ("uint8_col", pandas.Series([0, 1, 2], dtype="uint8")),
            ("uint16_col", pandas.Series([3, 4, 5], dtype="uint16")),
            ("uint32_col", pandas.Series([6, 7, 8], dtype="uint32")),
        ]
    )
    dataframe = pandas.DataFrame(df_data, columns=df_data.keys())

    dataset_id = _make_dataset_id("bq_load_test")
    self.temp_dataset(dataset_id)
    table_id = "{}.{}.load_table_from_dataframe_w_automatic_schema".format(
        Config.CLIENT.project, dataset_id
    )

    # Note: no schema is supplied -- the client must determine it.
    load_job = Config.CLIENT.load_table_from_dataframe(dataframe, table_id)
    load_job.result()

    table = Config.CLIENT.get_table(table_id)
    # The created table's schema should reflect dtype-based autodetection,
    # one field per DataFrame column, in order.
    self.assertEqual(
        tuple(table.schema),
        (
            bigquery.SchemaField("bool_col", "BOOLEAN"),
            bigquery.SchemaField("ts_col", "TIMESTAMP"),
            bigquery.SchemaField("dt_col", "DATETIME"),
            bigquery.SchemaField("float32_col", "FLOAT"),
            bigquery.SchemaField("float64_col", "FLOAT"),
            bigquery.SchemaField("int8_col", "INTEGER"),
            bigquery.SchemaField("int16_col", "INTEGER"),
            bigquery.SchemaField("int32_col", "INTEGER"),
            bigquery.SchemaField("int64_col", "INTEGER"),
            bigquery.SchemaField("uint8_col", "INTEGER"),
            bigquery.SchemaField("uint16_col", "INTEGER"),
            bigquery.SchemaField("uint32_col", "INTEGER"),
        ),
    )
    self.assertEqual(table.num_rows, 3)
712+
637713
@unittest.skipIf(pandas is None, "Requires `pandas`")
638714
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
639715
def test_load_table_from_dataframe_w_nulls(self):

bigquery/tests/unit/test_client.py

+72-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import collections
1617
import datetime
1718
import decimal
1819
import email
@@ -5325,9 +5326,78 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
53255326
)
53265327

53275328
sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
5328-
assert sent_config is job_config
53295329
assert sent_config.source_format == job.SourceFormat.PARQUET
53305330

5331+
@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
    """load_table_from_dataframe should pass a schema autodetected from
    the DataFrame's dtypes to load_table_from_file when the caller does
    not provide one."""
    from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
    from google.cloud.bigquery import job
    from google.cloud.bigquery.schema import SchemaField

    client = self._make_client()
    # OrderedDict (plus columns= below) keeps the column order stable so
    # the sent schema can be asserted positionally.
    df_data = collections.OrderedDict(
        [
            ("int_col", [1, 2, 3]),
            ("float_col", [1.0, 2.0, 3.0]),
            ("bool_col", [True, False, True]),
            (
                # Naive datetimes: expected to map to DATETIME.
                "dt_col",
                pandas.Series(
                    [
                        datetime.datetime(2010, 1, 2, 3, 44, 50),
                        datetime.datetime(2011, 2, 3, 14, 50, 59),
                        datetime.datetime(2012, 3, 14, 15, 16),
                    ],
                    dtype="datetime64[ns]",
                ),
            ),
            (
                # Timezone-aware datetimes: expected to map to TIMESTAMP.
                "ts_col",
                pandas.Series(
                    [
                        datetime.datetime(2010, 1, 2, 3, 44, 50),
                        datetime.datetime(2011, 2, 3, 14, 50, 59),
                        datetime.datetime(2012, 3, 14, 15, 16),
                    ],
                    dtype="datetime64[ns]",
                ).dt.tz_localize(pytz.utc),
            ),
        ]
    )
    dataframe = pandas.DataFrame(df_data, columns=df_data.keys())
    # Stub out the actual file upload; only the job config it receives
    # is inspected below.
    load_patch = mock.patch(
        "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
    )

    with load_patch as load_table_from_file:
        client.load_table_from_dataframe(
            dataframe, self.TABLE_REF, location=self.LOCATION
        )

    load_table_from_file.assert_called_once_with(
        client,
        mock.ANY,
        self.TABLE_REF,
        num_retries=_DEFAULT_NUM_RETRIES,
        rewind=True,
        job_id=mock.ANY,
        job_id_prefix=None,
        location=self.LOCATION,
        project=None,
        job_config=mock.ANY,
    )

    # The client builds its own job config; verify the autodetected
    # schema it sent, one field per column, in order.
    sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
    assert sent_config.source_format == job.SourceFormat.PARQUET
    assert tuple(sent_config.schema) == (
        SchemaField("int_col", "INTEGER"),
        SchemaField("float_col", "FLOAT"),
        SchemaField("bool_col", "BOOLEAN"),
        SchemaField("dt_col", "DATETIME"),
        SchemaField("ts_col", "TIMESTAMP"),
    )
5400+
53315401
@unittest.skipIf(pandas is None, "Requires `pandas`")
53325402
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
53335403
def test_load_table_from_dataframe_struct_fields_error(self):
@@ -5509,7 +5579,7 @@ def test_load_table_from_dataframe_w_nulls(self):
55095579
)
55105580

55115581
sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
5512-
assert sent_config is job_config
5582+
assert sent_config.schema == schema
55135583
assert sent_config.source_format == job.SourceFormat.PARQUET
55145584

55155585
# Low-level tests

0 commit comments

Comments
 (0)