Skip to content

Commit d51ec4e

Browse files
committed
fix(ir): implicitly convert None literals with dt.Null type to the requested type during value coercion
1 parent 66fd69c commit d51ec4e

File tree

5 files changed

+59
-3
lines changed

5 files changed

+59
-3
lines changed

ibis/expr/operations/core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,19 @@ def __coerce__(
6868
) -> Self:
6969
# note that S=Shape is unused here since the pattern will check the
7070
# shape of the value expression after executing Value.__coerce__()
71-
from ibis.expr.operations import Literal
71+
from ibis.expr.operations.generic import NULL, Literal
7272
from ibis.expr.types import Expr
7373

7474
if isinstance(value, Expr):
7575
value = value.op()
76+
7677
if isinstance(value, Value):
77-
return value
78+
if value == NULL:
79+
# treat the NULL literal the same as None to implicitly cast to
80+
# the requested datatype if any
81+
value = None
82+
else:
83+
return value
7884

7985
if T is dt.Integer:
8086
dtype = dt.infer(int(value))

ibis/expr/operations/generic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def name(self):
195195
return repr(self.value)
196196

197197

198+
NULL = Literal(None, dt.null)
199+
200+
198201
@public
199202
class ScalarParameter(Scalar, Named):
200203
_counter = itertools.count()
@@ -313,3 +316,6 @@ def shape(self):
313316
def dtype(self):
314317
exprs = [*self.results, self.default]
315318
return rlz.highest_precedence_dtype(exprs)
319+
320+
321+
public(NULL=NULL)

ibis/expr/operations/tests/test_generic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,24 @@ def test_error_message_when_constructing_literal(call, error, snapshot):
122122
with pytest.raises(ValidationError) as exc:
123123
call()
124124
snapshot.assert_match(str(exc.value), f"{error}.txt")
125+
126+
127+
def test_implicit_coercion_of_null_literal():
128+
# GH #7775
129+
NULL = ops.Literal(None, dt.null)
130+
131+
value = ops.Value.__coerce__(None, dt.Int8)
132+
expected = ops.Literal(None, dt.int8)
133+
assert value == expected
134+
135+
value = ops.Value.__coerce__(NULL, dt.Float64)
136+
expected = ops.Literal(None, dt.float64)
137+
assert value == expected
138+
139+
140+
def test_NULL():
141+
assert isinstance(ops.NULL, ops.Literal)
142+
assert ops.NULL.value is None
143+
assert ops.NULL.dtype is dt.null
144+
assert ops.NULL == ops.Literal(None, dt.null)
145+
assert ops.NULL is not ops.Literal(None, dt.int8)

ibis/expr/tests/test_api.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import operator
34
from datetime import datetime
45

56
import pandas as pd
@@ -9,6 +10,7 @@
910

1011
import ibis
1112
import ibis.expr.datatypes as dt
13+
import ibis.expr.operations as ops
1214
import ibis.expr.schema as sch
1315
from ibis import _
1416
from ibis.common.exceptions import IbisInputError, IntegrityError
@@ -124,3 +126,24 @@ def test_duplicate_columns_in_memtable_not_allowed():
124126

125127
with pytest.raises(IbisInputError, match="Duplicate column names"):
126128
ibis.memtable(df)
129+
130+
131+
@pytest.mark.parametrize(
132+
"op",
133+
[
134+
operator.and_,
135+
operator.or_,
136+
operator.xor,
137+
],
138+
)
139+
def test_implicit_coercion_of_null_literal(op):
140+
# GH #7775
141+
expr1 = op(ibis.literal(True), ibis.null())
142+
expr2 = op(ibis.literal(True), None)
143+
144+
expected = expr1.op().__class__(
145+
ops.Literal(True, dtype=dt.boolean), ops.Literal(None, dtype=dt.boolean)
146+
)
147+
148+
assert expr1.op() == expected
149+
assert expr2.op() == expected

ibis/expr/types/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,7 @@ class NullColumn(Column, NullValue):
19961996
@public
19971997
def null():
19981998
"""Create a NULL/NA scalar."""
1999-
return literal(None)
1999+
return ops.NULL.to_expr()
20002000

20012001

20022002
@public

0 commit comments

Comments
 (0)