Skip to content

Commit 4b28ff1

Browse files
gforsythcpcloud
authored andcommitted
fix(substitute): allow mappings with None keys
We don't need to sort this dictionary anymore.
1 parent 6e3219f commit 4b28ff1

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

ibis/backends/tests/test_generic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,3 +1608,19 @@ def test_sample_with_seed(backend):
16081608
df1 = expr.to_pandas()
16091609
df2 = expr.to_pandas()
16101610
backend.assert_frame_equal(df1, df2)
1611+
1612+
1613+
@pytest.mark.broken(
1614+
["dask"], reason="implementation somehow differs from pandas", raises=ValueError
1615+
)
1616+
def test_substitute(backend):
1617+
val = "400"
1618+
t = backend.functional_alltypes
1619+
expr = (
1620+
t.string_col.nullif("1")
1621+
.substitute({None: val})
1622+
.name("subs")
1623+
.value_counts()
1624+
.filter(lambda t: t.subs == val)
1625+
)
1626+
assert expr["subs_count"].execute()[0] == t.count().execute() // 10

ibis/expr/types/generic.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence
45

56
from public import public
@@ -683,12 +684,18 @@ def substitute(
683684
│ torg │ 52 │
684685
└────────┴──────────────┘
685686
"""
686-
expr = self.case()
687687
if isinstance(value, dict):
688-
for k, v in sorted(value.items()):
689-
expr = expr.when(k, v)
688+
expr = ibis.case()
689+
try:
690+
null_replacement = value.pop(None)
691+
except KeyError:
692+
pass
693+
else:
694+
expr = expr.when(self.isnull(), null_replacement)
695+
for k, v in value.items():
696+
expr = expr.when(self == k, v)
690697
else:
691-
expr = expr.when(value, replacement)
698+
expr = self.case().when(value, replacement)
692699

693700
return expr.else_(else_ if else_ is not None else self).end()
694701

ibis/tests/expr/test_value_exprs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,13 +792,21 @@ def test_substitute_dict():
792792

793793
result = table.foo.substitute(subs)
794794
expected = (
795-
table.foo.case().when("a", "one").when("b", table.bar).else_(table.foo).end()
795+
ibis.case()
796+
.when(table.foo == "a", "one")
797+
.when(table.foo == "b", table.bar)
798+
.else_(table.foo)
799+
.end()
796800
)
797801
assert_equal(result, expected)
798802

799803
result = table.foo.substitute(subs, else_=ibis.NA)
800804
expected = (
801-
table.foo.case().when("a", "one").when("b", table.bar).else_(ibis.NA).end()
805+
ibis.case()
806+
.when(table.foo == "a", "one")
807+
.when(table.foo == "b", table.bar)
808+
.else_(ibis.NA)
809+
.end()
802810
)
803811
assert_equal(result, expected)
804812

0 commit comments

Comments
 (0)