Skip to content

Commit faf99df

Browse files
gforsythcpcloud
authored andcommitted
fix(memtable): ensure column names match provided data
Previously, the `pandas` memtable constructor was performing two tasks inconsistently, one was to subselect columns out of the provided dataframe, the other was to rename those columns. This led to some weird behavior where a mismatch in provided names could lead to a dataframe consisting of NaNs. Now, `columns` can only be provided to rename existing columns and there is no subselection behavior. If the length of the `columns` iterable does not match the number of columns in the provided data, we error.
1 parent 241c8be commit faf99df

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

ibis/backends/tests/test_generic.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,39 @@ def test_memtable_construct(backend, con, monkeypatch):
947947
)
948948

949949

950+
@pytest.mark.parametrize(
951+
"df, columns, expected",
952+
[
953+
(pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a", "b"], ["a", "b"]),
954+
(pd.DataFrame([("a", 1.0)]), ["A", "B"], ["A", "B"]),
955+
(pd.DataFrame([("a", 1.0)], columns=["c", "d"]), None, ["c", "d"]),
956+
([("a", "1.0")], None, ["col0", "col1"]),
957+
([("a", "1.0")], ["d", "e"], ["d", "e"]),
958+
],
959+
)
960+
def test_memtable_column_naming(backend, con, monkeypatch, df, columns, expected):
961+
monkeypatch.setattr(ibis.options, "default_backend", con)
962+
963+
t = ibis.memtable(df, columns=columns)
964+
assert all(t.to_pandas().columns == expected)
965+
966+
967+
@pytest.mark.parametrize(
968+
"df, columns",
969+
[
970+
(pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a"]),
971+
(pd.DataFrame([("a", 1.0)]), ["A", "B", "C"]),
972+
([("a", "1.0")], ["col0", "col1", "col2"]),
973+
([("a", "1.0")], ["d"]),
974+
],
975+
)
976+
def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns):
977+
monkeypatch.setattr(ibis.options, "default_backend", con)
978+
979+
with pytest.raises(ValueError):
980+
ibis.memtable(df, columns=columns)
981+
982+
950983
@pytest.mark.notimpl(
951984
["dask", "datafusion", "pandas", "polars"],
952985
raises=NotImplementedError,

ibis/expr/api.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ def memtable(
371371
Do not depend on the underlying storage type (e.g., pyarrow.Table), it's subject
372372
to change across non-major releases.
373373
columns
374-
Optional [](`typing.Iterable`) of [](`str`) column names.
374+
Optional [](`typing.Iterable`) of [](`str`) column names. If provided,
375+
must match the number of columns in `data`.
375376
schema
376377
Optional [`Schema`](./schemas.qmd#ibis.expr.schema.Schema).
377378
The functions use `data` to infer a schema if not passed.
@@ -468,7 +469,11 @@ def _memtable_from_dataframe(
468469

469470
from ibis.expr.operations.relations import PandasDataFrameProxy
470471

471-
df = pd.DataFrame(data, columns=columns)
472+
if not isinstance(data, pd.DataFrame):
473+
df = pd.DataFrame(data, columns=columns)
474+
else:
475+
df = data
476+
472477
if df.columns.inferred_type != "string":
473478
cols = df.columns
474479
newcols = getattr(
@@ -478,6 +483,15 @@ def _memtable_from_dataframe(
478483
)
479484
df = df.rename(columns=dict(zip(cols, newcols)))
480485

486+
if columns is not None:
487+
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
488+
raise ValueError(
489+
"Provided `columns` must have an entry for each column in `data`.\n"
490+
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
491+
)
492+
493+
df = df.rename(columns=dict(zip(df.columns, columns)))
494+
481495
# verify that the DataFrame has no duplicate column names because ibis
482496
# doesn't allow that
483497
cols = df.columns

0 commit comments

Comments
 (0)