
Commit 7bd22af

gforsyth authored and cpcloud committed
feat(pyspark): add read_csv, read_parquet, and register
1 parent d6235ee commit 7bd22af

3 files changed: +233 −10 lines changed

ibis/backends/duckdb/__init__.py
Lines changed: 2 additions & 3 deletions

@@ -187,9 +187,8 @@ def register(
             parquet/csv files, an iterable of parquet or CSV files, a pandas
             dataframe, a pyarrow table or dataset, or a postgres URI.
         table_name
-            An optional name to use for the created table. This defaults to the
-            filename if a path (with hyphens replaced with underscores), or
-            sequentially generated name otherwise.
+            An optional name to use for the created table. This defaults to a
+            sequentially generated name.
         **kwargs
             Additional keyword arguments passed to DuckDB loading functions for
             CSV or parquet. See https://duckdb.org/docs/data/csv and
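
For illustration, a minimal sketch of the updated default-naming behavior on the DuckDB backend; the connection and file path below are assumptions, not part of this diff:

    import ibis

    con = ibis.duckdb.connect()  # in-memory database

    # hypothetical CSV path; with no table_name given, the table is now
    # registered under a sequentially generated name (e.g. "ibis_read_csv_0")
    # rather than a name derived from the filename
    t = con.register("data/diamonds.csv")
    print(con.list_tables())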

ibis/backends/pyspark/__init__.py
Lines changed: 130 additions & 0 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import itertools
 from pathlib import Path
 from typing import TYPE_CHECKING, Any

@@ -13,6 +14,7 @@
 import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as ir
+from ibis import util
 from ibis.backends.base.df.scope import Scope
 from ibis.backends.base.df.timecontext import canonicalize_context, localize_context
 from ibis.backends.base.sql import BaseSQLBackend
@@ -37,6 +39,16 @@
     'escape': '"',
 }

+pa_n = itertools.count(0)
+csv_n = itertools.count(0)
+
+
+def normalize_filenames(source_list):
+    # Promote to list
+    source_list = util.promote_list(source_list)
+
+    return list(map(util.normalize_filename, source_list))
+

 class _PySparkCursor:
     """Spark cursor.
@@ -574,3 +586,121 @@ def _clean_up_cached_table(self, op):
         assert t.is_cached
         t.unpersist()
         assert not t.is_cached
+
+    def read_parquet(
+        self,
+        source: str | Path,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a parquet file as a table in the current database.
+
+        Parameters
+        ----------
+        source
+            The data source. May be a path to a file or directory of parquet files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        kwargs
+            Additional keyword arguments passed to PySpark.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.parquet.html
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        source = util.normalize_filename(source)
+        spark_df = self._session.read.parquet(source, **kwargs)
+        table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"
+
+        spark_df.createOrReplaceTempView(table_name)
+        return self.table(table_name)
+
+    def read_csv(
+        self,
+        source_list: str | list[str] | tuple[str],
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a CSV file as a table in the current database.
+
+        Parameters
+        ----------
+        source_list
+            The data source(s). May be a path to a file or directory of CSV files, or an
+            iterable of CSV files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        kwargs
+            Additional keyword arguments passed to PySpark loading function.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.csv.html
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        source_list = normalize_filenames(source_list)
+        spark_df = self._session.read.csv(source_list, **kwargs)
+        table_name = table_name or f"ibis_read_csv_{next(csv_n)}"
+
+        spark_df.createOrReplaceTempView(table_name)
+        return self.table(table_name)
+
+    def register(
+        self,
+        source: str | Path | Any,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a data source as a table in the current database.
+
+        Parameters
+        ----------
+        source
+            The data source(s). May be a path to a file or directory of
+            parquet/csv files, or an iterable of CSV files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        **kwargs
+            Additional keyword arguments passed to PySpark loading functions for
+            CSV or parquet.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+
+        if isinstance(source, (str, Path)):
+            first = str(source)
+        elif isinstance(source, (list, tuple)):
+            first = source[0]
+        else:
+            self._register_failure()
+
+        if first.startswith(("parquet://", "parq://")) or first.endswith(
+            ("parq", "parquet")
+        ):
+            return self.read_parquet(source, table_name=table_name, **kwargs)
+        elif first.startswith(
+            ("csv://", "csv.gz://", "txt://", "txt.gz://")
+        ) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")):
+            return self.read_csv(source, table_name=table_name, **kwargs)
+        else:
+            self._register_failure()  # noqa: RET503
+
+    def _register_failure(self):
+        import inspect
+
+        msg = ", ".join(
+            name for name, _ in inspect.getmembers(self) if name.startswith("read_")
+        )
+        raise ValueError(
+            f"Cannot infer appropriate read function for input, "
+            f"please call one of {msg} directly"
+        )
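
A usage sketch of the methods added above, assuming a local SparkSession, that ibis.pyspark.connect accepts an existing session, and hypothetical file paths; extra keyword arguments are forwarded to the underlying PySpark DataFrameReader:

    from pyspark.sql import SparkSession

    import ibis

    session = SparkSession.builder.getOrCreate()
    con = ibis.pyspark.connect(session)

    # explicit readers; table_name is optional and defaults to a sequentially
    # generated name such as ibis_read_parquet_0 / ibis_read_csv_0
    funk = con.read_parquet("/data/functional_alltypes.parquet", table_name="funk_all")
    diamonds = con.read_csv("/data/diamonds.csv", header=True)

    # register() inspects the prefix/extension and dispatches to the matching
    # read_* method; anything it cannot infer raises a ValueError listing the
    # available read_* functions
    t = con.register("parquet:///data/functional_alltypes.parquet")
    print(con.list_tables())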

ibis/backends/tests/test_register.py
Lines changed: 101 additions & 7 deletions

@@ -41,7 +41,15 @@ def gzip_csv(data_directory, tmp_path):
     ("fname", "in_table_name", "out_table_name"),
     [
         param("diamonds.csv", None, "ibis_read_csv_", id="default"),
-        param("csv://diamonds.csv", "Diamonds2", "Diamonds2", id="csv_name"),
+        param(
+            "csv://diamonds.csv",
+            "Diamonds2",
+            "Diamonds2",
+            id="csv_name",
+            marks=pytest.mark.notyet(
+                ["pyspark"], reason="pyspark lowercases view names"
+            ),
+        ),
         param(
             "file://diamonds.csv",
             "fancy_stones",
@@ -53,11 +61,14 @@ def gzip_csv(data_directory, tmp_path):
             "fancy stones",
             "fancy stones",
             id="file_atypical_name",
+            marks=pytest.mark.notyet(
+                ["pyspark"], reason="no spaces allowed in view names"
+            ),
         ),
         param(
             ["file://diamonds.csv", "diamonds.csv"],
-            "fancy stones",
-            "fancy stones",
+            "fancy_stones2",
+            "fancy_stones2",
             id="multi_csv",
             marks=pytest.mark.notyet(
                 ["polars", "datafusion"],
@@ -76,7 +87,6 @@ def gzip_csv(data_directory, tmp_path):
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -102,7 +112,6 @@ def test_register_csv(con, data_directory, fname, in_table_name, out_table_name)
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -125,7 +134,6 @@ def test_register_csv_gz(con, data_directory, gzip_csv):
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -179,7 +187,6 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -381,3 +388,90 @@ def test_register_garbage(con, monkeypatch):

     with pytest.raises(FileNotFoundError):
         con.read_parquet("garbage_notafile")
+
+
+@pytest.mark.parametrize(
+    ("fname", "in_table_name", "out_table_name"),
+    [
+        (
+            "functional_alltypes.parquet",
+            None,
+            "ibis_read_parquet",
+        ),
+        ("functional_alltypes.parquet", "funk_all", "funk_all"),
+    ],
+)
+@pytest.mark.notyet(
+    [
+        "bigquery",
+        "clickhouse",
+        "dask",
+        "impala",
+        "mssql",
+        "mysql",
+        "pandas",
+        "postgres",
+        "snowflake",
+        "sqlite",
+        "trino",
+    ]
+)
+def test_read_parquet(
+    con, tmp_path, data_directory, fname, in_table_name, out_table_name
+):
+    pq = pytest.importorskip("pyarrow.parquet")
+
+    fname = Path(fname)
+    table = read_table(data_directory / fname.name)
+
+    pq.write_table(table, tmp_path / fname.name)
+
+    with pushd(data_directory):
+        if con.name == "pyspark":
+            # pyspark doesn't respect CWD
+            fname = str(Path(fname).absolute())
+        table = con.read_parquet(fname, table_name=in_table_name)
+
+    assert any(t.startswith(out_table_name) for t in con.list_tables())
+
+    if con.name != "datafusion":
+        table.count().execute()
+
+
+@pytest.mark.parametrize(
+    ("fname", "in_table_name", "out_table_name"),
+    [
+        param("diamonds.csv", None, "ibis_read_csv_", id="default"),
+        param(
+            "diamonds.csv",
+            "fancy_stones",
+            "fancy_stones",
+            id="file_name",
+        ),
+    ],
+)
+@pytest.mark.notyet(
+    [
+        "bigquery",
+        "clickhouse",
+        "dask",
+        "impala",
+        "mssql",
+        "mysql",
+        "pandas",
+        "postgres",
+        "snowflake",
+        "sqlite",
+        "trino",
+    ]
+)
+def test_read_csv(con, data_directory, fname, in_table_name, out_table_name):
    with pushd(data_directory):
+        if con.name == "pyspark":
+            # pyspark doesn't respect CWD
+            fname = str(Path(fname).absolute())
+        table = con.read_csv(fname, table_name=in_table_name)
+
+    assert any(t.startswith(out_table_name) for t in con.list_tables())
+    if con.name != "datafusion":
+        table.count().execute()
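
Note that the default-name cases above assert only a prefix (t.startswith(out_table_name)) because the generated name carries a counter suffix. A minimal sketch of that scheme, mirroring the module-level counters added in the pyspark backend:

    import itertools

    csv_n = itertools.count(0)

    # each call consumes the next integer, so repeated reads get distinct views
    names = [f"ibis_read_csv_{next(csv_n)}" for _ in range(3)]
    print(names)  # ['ibis_read_csv_0', 'ibis_read_csv_1', 'ibis_read_csv_2']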
