Skip to content

Commit e4e2993

Browse files
committed
refactor(common): remove traverse() function's filter argument since it can be expressed using the visitor
1 parent 442199a commit e4e2993

File tree

2 files changed

+3
-12
lines changed

2 files changed

+3
-12
lines changed

ibis/common/graph.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ibis.common.bases import Hashable
1010
from ibis.common.collections import frozendict
11-
from ibis.common.patterns import NoMatch, Pattern, pattern
11+
from ibis.common.patterns import NoMatch, pattern
1212
from ibis.util import experimental
1313

1414
if TYPE_CHECKING:
@@ -485,9 +485,7 @@ def toposort(node: Node) -> Graph:
485485

486486

487487
def traverse(
488-
fn: Callable[[Node], tuple[bool | Iterable, Any]],
489-
node: Iterable[Node] | Node,
490-
filter: Optional[Any] = None,
488+
fn: Callable[[Node], tuple[bool | Iterable, Any]], node: Iterable[Node] | Node
491489
) -> Iterator[Any]:
492490
"""Utility for generic expression tree traversal.
493491
@@ -498,24 +496,17 @@ def traverse(
498496
the traversal, and the second is the result if its not `None`.
499497
node
500498
The Node expression or a list of expressions.
501-
filter
502-
Pattern-like object to filter out nodes from the traversal. The traversal will
503-
only visit nodes that match the given pattern and stop otherwise.
504499
"""
505500

506501
args = reversed(node) if isinstance(node, Sequence) else [node]
507502
todo: deque[Node] = deque(args)
508503
seen: set[Node] = set()
509-
filter: Pattern = pattern(filter or ...)
510504

511505
while todo:
512506
node = todo.pop()
513507

514508
if node in seen:
515509
continue
516-
if filter.match(node, {}) is NoMatch:
517-
continue
518-
519510
seen.add(node)
520511

521512
control, result = fn(node)

ibis/expr/analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,4 +470,4 @@ def finder(node):
470470
find_first_base_table(node) if isinstance(node, ops.Unnest) else None,
471471
)
472472

473-
return g.traverse(finder, nodes, filter=ops.Node)
473+
return g.traverse(finder, nodes)

0 commit comments

Comments
 (0)