Skip to content

Commit 0aba430

Browse files
committed
test(pyspark): make parquet round-trip tests order-insensitive and handle PySpark's directory output
1 parent b0a10aa commit 0aba430

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

ibis/backends/pyspark/tests/test_import_export.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pandas as pd
77
import pytest
8+
from pandas.testing import assert_frame_equal
89

910
from ibis.backends.pyspark.datatypes import PySparkSchema
1011

@@ -83,8 +84,11 @@ def test_to_parquet_read_parquet(con, tmp_path):
8384

8485
t_in = con.read_parquet(tmp_path / "out_np")
8586

86-
assert t_out.to_pandas().shape == t_in.to_pandas().shape
87-
assert sorted(t_out.columns) == sorted(t_in.columns)
87+
cols = list(t_out.columns)
88+
expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
89+
result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
90+
91+
assert_frame_equal(expected, result)
8892

8993
# Partitions
9094
t_out = con.table("awards_players")
@@ -99,5 +103,8 @@ def test_to_parquet_read_parquet(con, tmp_path):
99103

100104
t_in = con.read_parquet(tmp_path / "out_p")
101105

102-
assert t_out.to_pandas().shape == t_in.to_pandas().shape
103-
assert sorted(t_out.columns) == sorted(t_in.columns)
106+
cols = list(t_out.columns)
107+
expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
108+
result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
109+
110+
assert_frame_equal(expected, result)

ibis/backends/tests/test_export.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,30 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_playe
257257
outparquet = tmp_path / "out.parquet"
258258
awards_players.to_parquet(outparquet, version=version)
259259

260-
df = pd.read_parquet(outparquet)
260+
if backend.name() == "pyspark":
261+
# PySpark writes more than one parquet file under `outparquet`, treating it as a directory
262+
parquet_files = sorted(outparquet.glob("*.parquet"))
263+
df = (
264+
pd.concat(map(pd.read_parquet, parquet_files))
265+
.sort_values(list(awards_players.columns))
266+
.reset_index(drop=True)
267+
)
268+
result = (
269+
awards_players.to_pandas()
270+
.sort_values(list(awards_players.columns))
271+
.reset_index(drop=True)
272+
)
273+
backend.assert_frame_equal(result, df)
274+
else:
275+
df = pd.read_parquet(outparquet)
261276

262-
backend.assert_frame_equal(
263-
awards_players.to_pandas().fillna(pd.NA), df.fillna(pd.NA)
264-
)
277+
backend.assert_frame_equal(
278+
awards_players.to_pandas().fillna(pd.NA), df.fillna(pd.NA)
279+
)
265280

266-
md = pa.parquet.read_metadata(outparquet)
281+
md = pa.parquet.read_metadata(outparquet)
267282

268-
assert md.format_version == version
283+
assert md.format_version == version
269284

270285

271286
@pytest.mark.notimpl(
@@ -333,7 +348,10 @@ def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):
333348

334349
getattr(con, f"to_{ftype}")(memtable, outfile)
335350

336-
assert outfile.is_file()
351+
if con.name == "pyspark" and ftype == "parquet":
352+
assert outfile.is_dir()
353+
else:
354+
assert outfile.is_file()
337355

338356

339357
def test_table_to_csv(tmp_path, backend, awards_players):

0 commit comments

Comments (0)