Skip to content

Commit c1dcf67

Browse files
authored
feat(duckdb): add support for passing a subset of column types to read_csv (#9776)
Add support for the `types` argument to `read_csv` in the DuckDB backend.
1 parent a019dfd commit c1dcf67

File tree

2 files changed

+111
-23
lines changed

2 files changed

+111
-23
lines changed

ibis/backends/duckdb/__init__.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import ibis
2121
import ibis.backends.sql.compilers as sc
2222
import ibis.common.exceptions as exc
23+
import ibis.expr.datatypes as dt
2324
import ibis.expr.operations as ops
2425
import ibis.expr.schema as sch
2526
import ibis.expr.types as ir
@@ -637,27 +638,80 @@ def read_csv(
637638
self,
638639
source_list: str | list[str] | tuple[str],
639640
table_name: str | None = None,
641+
columns: Mapping[str, str | dt.DataType] | None = None,
642+
types: Mapping[str, str | dt.DataType] | None = None,
640643
**kwargs: Any,
641644
) -> ir.Table:
642645
"""Register a CSV file as a table in the current database.
643646
644647
Parameters
645648
----------
646649
source_list
647-
The data source(s). May be a path to a file or directory of CSV files, or an
648-
iterable of CSV files.
650+
The data source(s). May be a path to a file or directory of CSV
651+
files, or an iterable of CSV files.
649652
table_name
650-
An optional name to use for the created table. This defaults to
651-
a sequentially generated name.
653+
An optional name to use for the created table. This defaults to a
654+
sequentially generated name.
655+
columns
656+
An optional mapping of **all** column names to their types.
657+
types
658+
An optional mapping of a **subset** of column names to their types.
652659
**kwargs
653-
Additional keyword arguments passed to DuckDB loading function.
654-
See https://duckdb.org/docs/data/csv for more information.
660+
Additional keyword arguments passed to DuckDB loading function. See
661+
https://duckdb.org/docs/data/csv for more information.
655662
656663
Returns
657664
-------
658665
ir.Table
659666
The just-registered table
660667
668+
Examples
669+
--------
670+
Generate some data
671+
672+
>>> import tempfile
673+
>>> data = b'''
674+
... lat,lon,geom
675+
... 1.0,2.0,POINT (1 2)
676+
... 2.0,3.0,POINT (2 3)
677+
... '''
678+
>>> with tempfile.NamedTemporaryFile(delete=False) as f:
679+
... nbytes = f.write(data)
680+
681+
Import Ibis
682+
683+
>>> import ibis
684+
>>> from ibis import _
685+
>>> ibis.options.interactive = True
686+
>>> con = ibis.duckdb.connect()
687+
688+
Read the raw CSV file
689+
690+
>>> t = con.read_csv(f.name)
691+
>>> t
692+
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━┓
693+
┃ lat ┃ lon ┃ geom ┃
694+
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━┩
695+
│ float64 │ float64 │ string │
696+
├─────────┼─────────┼─────────────┤
697+
│ 1.0 │ 2.0 │ POINT (1 2) │
698+
│ 2.0 │ 3.0 │ POINT (2 3) │
699+
└─────────┴─────────┴─────────────┘
700+
701+
Load the `spatial` extension and read the CSV file again, using
702+
specific column types
703+
704+
>>> con.load_extension("spatial")
705+
>>> t = con.read_csv(f.name, types={"geom": "geometry"})
706+
>>> t
707+
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
708+
┃ lat ┃ lon ┃ geom ┃
709+
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
710+
│ float64 │ float64 │ geospatial:geometry │
711+
├─────────┼─────────┼──────────────────────┤
712+
│ 1.0 │ 2.0 │ <POINT (1 2)> │
713+
│ 2.0 │ 3.0 │ <POINT (2 3)> │
714+
└─────────┴─────────┴──────────────────────┘
661715
"""
662716
source_list = util.normalize_filenames(source_list)
663717

@@ -673,27 +727,35 @@ def read_csv(
673727
self._load_extensions(["httpfs"])
674728

675729
kwargs.setdefault("header", True)
676-
kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs)
730+
kwargs["auto_detect"] = kwargs.pop("auto_detect", columns is None)
677731
# TODO: clean this up
678732
# We want to _usually_ quote arguments but if we quote `columns` it messes
679733
# up DuckDB's struct parsing.
680-
options = [
681-
sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items()
682-
]
683-
684-
if (columns := kwargs.pop("columns", None)) is not None:
685-
options.append(
686-
sg.to_identifier("columns").eq(
687-
sge.Struct(
688-
expressions=[
689-
sge.PropertyEQ(
690-
this=sge.convert(key), expression=sge.convert(value)
691-
)
692-
for key, value in columns.items()
693-
]
694-
)
734+
options = [C[key].eq(sge.convert(val)) for key, val in kwargs.items()]
735+
736+
def make_struct_argument(obj: Mapping[str, str | dt.DataType]) -> sge.Struct:
737+
expressions = []
738+
geospatial = False
739+
type_mapper = self.compiler.type_mapper
740+
741+
for name, typ in obj.items():
742+
typ = dt.dtype(typ)
743+
geospatial |= typ.is_geospatial()
744+
sgtype = type_mapper.from_ibis(typ)
745+
prop = sge.PropertyEQ(
746+
this=sge.to_identifier(name), expression=sge.convert(sgtype)
695747
)
696-
)
748+
expressions.append(prop)
749+
750+
if geospatial:
751+
self._load_extensions(["spatial"])
752+
return sge.Struct(expressions=expressions)
753+
754+
if columns is not None:
755+
options.append(C.columns.eq(make_struct_argument(columns)))
756+
757+
if types is not None:
758+
options.append(C.types.eq(make_struct_argument(types)))
697759

698760
self._create_temp_view(
699761
table_name,

ibis/backends/duckdb/tests/test_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,29 @@ def test_multiple_tables_with_the_same_name(tmp_path):
375375
t3 = con.table("t", database="w.main")
376376

377377
assert t3.schema() == ibis.schema({"y": "array<float64>"})
378+
379+
380+
# Exercise the new `types` argument of `read_csv` alongside the existing
# `columns` argument: per the backend docstring, `columns` maps *all* column
# names to types, while `types` overrides only a *subset*.
@pytest.mark.parametrize(
    "input",
    [
        # full schema supplied via `columns`
        {"columns": {"lat": "float64", "lon": "float64", "geom": "geometry"}},
        # partial override supplied via `types`
        {"types": {"geom": "geometry"}},
    ],
)
@pytest.mark.parametrize("all_varchar", [True, False])
@pytest.mark.xfail(
    LINUX and SANDBOXED,
    reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
    raises=duckdb.IOException,
)
@pytest.mark.xdist_group(name="duckdb-extensions")
def test_read_csv_with_types(tmp_path, input, all_varchar):
    """Reading a CSV with an explicit geometry type yields a geospatial column.

    Requesting a `geometry` type triggers loading of the DuckDB `spatial`
    extension, hence the sandboxed-nix xfail and the xdist extension group.
    """
    con = ibis.duckdb.connect()
    data = b"""\
lat,lon,geom
1.0,2.0,POINT (1 2)
2.0,3.0,POINT (2 3)"""
    path = tmp_path / "data.csv"
    path.write_bytes(data)
    # `all_varchar` is parametrized both ways to check it does not defeat
    # the explicitly requested column types.
    t = con.read_csv(path, all_varchar=all_varchar, **input)
    assert t.schema()["geom"].is_geospatial()

0 commit comments

Comments (0)