Skip to content

Commit dbc34bb

Browse files
gerrymanoimcpcloud
authored andcommitted
feat: allow column_of to take a column expression
1 parent 8f4bc79 commit dbc34bb

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed

ibis/expr/rules.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,11 +424,40 @@ def table(arg, *, schema=None, **kwargs):
424424

425425
@validator
426426
def column_from(name, column, *, this):
427-
if not isinstance(column, (str, int)):
428-
raise com.IbisTypeError(
429-
f"value must be an int or str, got {type(column).__name__}"
430-
)
431-
return getattr(this, name)[column]
427+
"""A column from a named table.
428+
429+
This validator accepts columns passed as string, integer, or column
430+
expression. In the case of a column expression, this validator
431+
checks if the column in the table is equal to the column being
432+
passed.
433+
"""
434+
if not hasattr(this, name):
435+
raise com.IbisTypeError(f"Could not get table {name} from {this}")
436+
table = getattr(this, name)
437+
438+
if isinstance(column, (str, int)):
439+
return table[column]
440+
elif isinstance(column, ir.AnyColumn):
441+
if not column.has_name():
442+
raise com.IbisTypeError(f"Passed column {column} has no name")
443+
444+
maybe_column = column.get_name()
445+
try:
446+
if column.equals(table[maybe_column]):
447+
return column
448+
else:
449+
raise com.IbisTypeError(
450+
f"Passed column is not a column in {table}"
451+
)
452+
except com.IbisError:
453+
raise com.IbisTypeError(
454+
f"Cannot get column {maybe_column} from {table}"
455+
)
456+
457+
raise com.IbisTypeError(
458+
"value must be an int or str or AnyColumn, got "
459+
f"{type(column).__name__}"
460+
)
432461

433462

434463
@validator

ibis/tests/expr/test_rules.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
[('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')]
1515
)
1616

17+
similar_table = ibis.table(
18+
[('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')]
19+
)
20+
1721

1822
@pytest.mark.parametrize(
1923
('value', 'expected'),
@@ -290,6 +294,43 @@ def test_invalid_column_or_scalar(validator, value, expected):
290294
validator(value)
291295

292296

297+
@pytest.mark.parametrize(
298+
('check_table', 'value', 'expected'),
299+
[
300+
(table, "int_col", table.int_col),
301+
(table, table.int_col, table.int_col),
302+
],
303+
)
304+
def test_valid_column_from(check_table, value, expected):
305+
class Test:
306+
table = check_table
307+
308+
validator = rlz.column_from("table")
309+
assert validator(value, this=Test()).equals(expected)
310+
311+
312+
@pytest.mark.parametrize(
313+
('check_table', 'validator', 'value'),
314+
[
315+
(table, rlz.column_from("not_table"), "int_col"),
316+
(table, rlz.column_from("table"), "col_not_in_table"),
317+
(
318+
table,
319+
rlz.column_from("table"),
320+
similar_table.int_col,
321+
),
322+
],
323+
)
324+
def test_invalid_column_from(check_table, validator, value):
325+
class Test:
326+
table = check_table
327+
328+
test = Test()
329+
330+
with pytest.raises(IbisTypeError):
331+
validator(value, this=test)
332+
333+
293334
@pytest.mark.parametrize(
294335
'table',
295336
[

0 commit comments

Comments
 (0)