Skip to content

Commit 3c14091

Browse files
committed
refactor(common): unify the node.find() and node.match() methods to transparently support types and patterns
1 parent bbc93c7 commit 3c14091

File tree

3 files changed

+54
-40
lines changed

3 files changed

+54
-40
lines changed

ibis/common/graph.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -162,38 +162,22 @@ def map(self, fn: Callable, filter: Optional[Any] = None) -> dict[Node, Any]:
162162
results[node] = fn(node, results, **kwargs)
163163
return results
164164

165-
def find(self, type: type | tuple[type], filter: Optional[Any] = None) -> set[Node]:
166-
"""Find all nodes of a given type in the graph.
167-
168-
Parameters
169-
----------
170-
type
171-
Type or tuple of types to find.
172-
filter
173-
Pattern-like object to filter out nodes from the traversal. The traversal
174-
will only visit nodes that match the given pattern and stop otherwise.
175-
176-
Returns
177-
-------
178-
The set of nodes matching the given type.
179-
"""
180-
nodes = Graph.from_bfs(self, filter=filter).nodes()
181-
return {node for node in nodes if isinstance(node, type)}
182-
183-
@experimental
184-
def match(
185-
self, pat: Any, filter: Optional[Any] = None, context: Optional[dict] = None
186-
) -> set[Node]:
187-
"""Find all nodes matching a given pattern in the graph.
188-
189-
A more advanced version of find, this method allows to match nodes based on
190-
the more flexible pattern matching system implemented in the pattern module.
165+
def find(
166+
self,
167+
pat: type | tuple[type],
168+
filter: Optional[Any] = None,
169+
context: Optional[dict] = None,
170+
) -> list[Node]:
171+
"""Find all nodes matching a given pattern or type in the graph.
172+
173+
Allow to match nodes based on the flexible pattern matching system implemented
174+
in the pattern module, but also provide a fast path for matching based on the
175+
type of the node.
191176
192177
Parameters
193178
----------
194179
pat
195-
Pattern to match. `ibis.common.pattern()` function is used to coerce the
196-
input value into a pattern. See the pattern module for more details.
180+
Python type or `Pattern` to match.
197181
filter
198182
Pattern-like object to filter out nodes from the traversal. The traversal
199183
will only visit nodes that match the given pattern and stop otherwise.
@@ -202,12 +186,16 @@ def match(
202186
203187
Returns
204188
-------
205-
The set of nodes matching the given pattern.
189+
The list of nodes matching the given pattern. The order of the nodes is
190+
determined by a breadth-first search.
206191
"""
207-
pat = pattern(pat)
208-
ctx = context or {}
209192
nodes = Graph.from_bfs(self, filter=filter).nodes()
210-
return {node for node in nodes if pat.match(node, ctx) is not NoMatch}
193+
if isinstance(pat, (tuple, type)):
194+
return [node for node in nodes if isinstance(node, pat)]
195+
else:
196+
pat = pattern(pat)
197+
ctx = context or {}
198+
return [node for node in nodes if pat.match(node, ctx) is not NoMatch]
211199

212200
@experimental
213201
def replace(

ibis/common/tests/test_graph.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,38 @@ def test_recursive_lookup():
309309
)
310310

311311

312-
def test_node_match():
313-
result = A.match(If(_.name == "C"))
314-
assert result == {C}
312+
def test_node_find_using_type():
313+
class FooNode(MyNode):
314+
pass
315+
316+
class BarNode(MyNode):
317+
pass
318+
319+
C = BarNode(name="C", children=[])
320+
D = FooNode(name="D", children=[])
321+
E = BarNode(name="E", children=[])
322+
B = FooNode(name="B", children=[D, E])
323+
A = MyNode(name="A", children=[B, C])
324+
325+
result = A.find(MyNode)
326+
assert result == [A, B, C, D, E]
327+
328+
result = A.find(FooNode)
329+
assert result == [B, D]
330+
331+
result = A.find(BarNode)
332+
assert result == [C, E]
333+
334+
result = A.find((FooNode, BarNode))
335+
assert result == [B, C, D, E]
336+
337+
338+
def test_node_find_using_pattern():
339+
result = A.find(If(_.name == "C"))
340+
assert result == [C]
315341

316-
result = A.match(Object(MyNode, name=Eq("D")))
317-
assert result == {D}
342+
result = A.find(Object(MyNode, name=Eq("D")))
343+
assert result == [D]
318344

319-
result = A.match(If(_.children))
320-
assert result == {A, B}
345+
result = A.find(If(_.children))
346+
assert result == [A, B]

ibis/expr/analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def pushdown_selection_filters(parent, predicates):
151151

152152
simplified = []
153153
for pred in predicates:
154-
if pred.match(conflicting_projection, filter=p.Value):
154+
if pred.find(conflicting_projection, filter=p.Value):
155155
return default
156156
try:
157157
simplified.append(pred.replace(pushdown_pattern))

0 commit comments

Comments
 (0)