Commit 16945d5

Determine the schema in load_table_from_dataframe based on dtypes.
This PR updates `load_table_from_dataframe` to automatically determine the BigQuery schema based on the DataFrame's dtypes. If any field's type cannot be determined, fall back to the logic in the pandas `to_parquet` method.
1 parent 1a103d3 commit 16945d5
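In practice, this means a call site often needs no explicit schema at all. Below is a minimal sketch of the new call pattern; the project, dataset, and table names are hypothetical, not part of this commit:

import datetime

import pandas
from google.cloud import bigquery

client = bigquery.Client()
table_id = "my-project.my_dataset.my_table"  # hypothetical destination

# All three dtypes below are in the new mapping, so the client infers
# BOOLEAN, DATETIME, and INTEGER columns without an explicit schema.
dataframe = pandas.DataFrame(
    {
        "is_active": pandas.Series([True, False], dtype="bool"),
        "created": pandas.Series(
            [datetime.datetime(2019, 1, 1), datetime.datetime(2019, 6, 1)],
            dtype="datetime64[ns]",
        ),
        "score": pandas.Series([10, 20], dtype="int64"),
    }
)
load_job = client.load_table_from_dataframe(dataframe, table_id)
load_job.result()  # wait for the load to finish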

File tree

3 files changed: +143 −0 lines changed


bigquery/google/cloud/bigquery/_pandas_helpers.py (+40)
@@ -49,6 +49,21 @@
 
 _PROGRESS_INTERVAL = 0.2  # Maximum time between download status checks, in seconds.
 
+_PANDAS_DTYPE_TO_BQ = {
+    "bool": "BOOLEAN",
+    "datetime64[ns, UTC]": "TIMESTAMP",
+    "datetime64[ns]": "DATETIME",
+    "float32": "FLOAT",
+    "float64": "FLOAT",
+    "int8": "INTEGER",
+    "int16": "INTEGER",
+    "int32": "INTEGER",
+    "int64": "INTEGER",
+    "uint8": "INTEGER",
+    "uint16": "INTEGER",
+    "uint32": "INTEGER",
+}
+
 
 class _DownloadState(object):
     """Flag to indicate that a thread should exit early."""
@@ -172,6 +187,31 @@ def bq_to_arrow_array(series, bq_field):
     return pyarrow.array(series, type=arrow_type)
 
 
+def dataframe_to_bq_schema(dataframe):
+    """Convert a pandas DataFrame schema to a BigQuery schema.
+
+    TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
+    schema for a subset of columns.
+
+    Args:
+        dataframe (pandas.DataFrame):
+            DataFrame to convert to a Parquet file.
+
+    Returns:
+        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
+            The automatically determined schema. Returns None if the type of
+            any column cannot be determined.
+    """
+    bq_schema = []
+    for column, dtype in zip(dataframe.columns, dataframe.dtypes):
+        bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
+        if not bq_type:
+            return None
+        bq_field = schema.SchemaField(column, bq_type)
+        bq_schema.append(bq_field)
+    return tuple(bq_schema)
+
+
 def dataframe_to_arrow(dataframe, bq_schema):
     """Convert pandas dataframe to Arrow table, using BigQuery schema.
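Note that the helper is all-or-nothing: a single unmapped column makes it return None, and the caller then defers to pandas/pyarrow type detection during Parquet serialization. A quick sketch of both outcomes, assuming the private module is imported directly:

import pandas
from google.cloud.bigquery import _pandas_helpers

# Every column dtype is in the mapping, so a full schema comes back.
known = pandas.DataFrame({"x": pandas.Series([1, 2], dtype="int64")})
print(_pandas_helpers.dataframe_to_bq_schema(known))
# A 1-tuple holding an INTEGER SchemaField named "x".

# `object` dtype (e.g. Python strings) is absent from the mapping, so the
# helper bails out even though the other column is mappable.
mixed = pandas.DataFrame(
    {"x": ["a", "b"], "y": pandas.Series([1, 2], dtype="int64")}
)
print(_pandas_helpers.dataframe_to_bq_schema(mixed))  # None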

bigquery/google/cloud/bigquery/client.py (+9)
@@ -21,6 +21,7 @@
 except ImportError:  # Python 2.7
     import collections as collections_abc
 
+import copy
 import functools
 import gzip
 import io
@@ -1520,11 +1521,19 @@ def load_table_from_dataframe(
 
         if job_config is None:
             job_config = job.LoadJobConfig()
+        else:
+            # Make a copy so that the job config isn't modified in-place.
+            job_config_properties = copy.deepcopy(job_config._properties)
+            job_config = job.LoadJobConfig()
+            job_config._properties = job_config_properties
         job_config.source_format = job.SourceFormat.PARQUET
 
         if location is None:
             location = self.location
 
+        if not job_config.schema:
+            job_config.schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)
+
         tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
         os.close(tmpfd)

bigquery/tests/system.py (+94)
@@ -634,6 +634,100 @@ def test_load_table_from_local_avro_file_then_dump_table(self):
             sorted(row_tuples, key=by_wavelength), sorted(ROWS, key=by_wavelength)
         )
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_w_automatic_schema(self):
+        """Test that a DataFrame with dtypes that map well to BigQuery types
+        can be uploaded without specifying a schema.
+
+        https://github.com/googleapis/google-cloud-python/issues/9044
+        """
+        bool_col = pandas.Series([True, False, True], dtype="bool")
+        ts_col = pandas.Series(
+            [
+                datetime.datetime(2010, 1, 2, 3, 44, 50),
+                datetime.datetime(2011, 2, 3, 14, 50, 59),
+                datetime.datetime(2012, 3, 14, 15, 16),
+            ],
+            dtype="datetime64[ns]",
+        ).dt.tz_localize(pytz.utc)
+        dt_col = pandas.Series(
+            [
+                datetime.datetime(2010, 1, 2, 3, 44, 50),
+                datetime.datetime(2011, 2, 3, 14, 50, 59),
+                datetime.datetime(2012, 3, 14, 15, 16),
+            ],
+            dtype="datetime64[ns]",
+        )
+        float32_col = pandas.Series([1.0, 2.0, 3.0], dtype="float32")
+        float64_col = pandas.Series([4.0, 5.0, 6.0], dtype="float64")
+        int8_col = pandas.Series([-12, -11, -10], dtype="int8")
+        int16_col = pandas.Series([-9, -8, -7], dtype="int16")
+        int32_col = pandas.Series([-6, -5, -4], dtype="int32")
+        int64_col = pandas.Series([-3, -2, -1], dtype="int64")
+        uint8_col = pandas.Series([0, 1, 2], dtype="uint8")
+        uint16_col = pandas.Series([3, 4, 5], dtype="uint16")
+        uint32_col = pandas.Series([6, 7, 8], dtype="uint32")
+        dataframe = pandas.DataFrame(
+            {
+                "bool_col": bool_col,
+                "ts_col": ts_col,
+                "dt_col": dt_col,
+                "float32_col": float32_col,
+                "float64_col": float64_col,
+                "int8_col": int8_col,
+                "int16_col": int16_col,
+                "int32_col": int32_col,
+                "int64_col": int64_col,
+                "uint8_col": uint8_col,
+                "uint16_col": uint16_col,
+                "uint32_col": uint32_col,
+            },
+            columns=[
+                "bool_col",
+                "ts_col",
+                "dt_col",
+                "float32_col",
+                "float64_col",
+                "int8_col",
+                "int16_col",
+                "int32_col",
+                "int64_col",
+                "uint8_col",
+                "uint16_col",
+                "uint32_col",
+            ],
+        )
+
+        dataset_id = _make_dataset_id("bq_load_test")
+        self.temp_dataset(dataset_id)
+        table_id = "{}.{}.load_table_from_dataframe_w_automatic_schema".format(
+            Config.CLIENT.project, dataset_id
+        )
+
+        load_job = Config.CLIENT.load_table_from_dataframe(dataframe, table_id)
+        load_job.result()
+
+        table = Config.CLIENT.get_table(table_id)
+        self.assertEqual(
+            tuple(table.schema),
+            (
+                bigquery.SchemaField("bool_col", "BOOLEAN"),
+                bigquery.SchemaField("ts_col", "TIMESTAMP"),
+                bigquery.SchemaField("dt_col", "DATETIME"),
+                bigquery.SchemaField("float32_col", "FLOAT"),
+                bigquery.SchemaField("float64_col", "FLOAT"),
+                bigquery.SchemaField("int8_col", "INTEGER"),
+                bigquery.SchemaField("int16_col", "INTEGER"),
+                bigquery.SchemaField("int32_col", "INTEGER"),
+                bigquery.SchemaField("int64_col", "INTEGER"),
+                bigquery.SchemaField("uint8_col", "INTEGER"),
+                bigquery.SchemaField("uint16_col", "INTEGER"),
+                bigquery.SchemaField("uint32_col", "INTEGER"),
+            ),
+        )
+        self.assertEqual(table.num_rows, 3)
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_w_nulls(self):
