Skip to content

Commit bbc93c7

Browse files
committed
feat(common): node.replace() now supports mappings for quick lookup-like substitutions
1 parent 1d314f7 commit bbc93c7

File tree

2 files changed

+43
-20
lines changed

2 files changed

+43
-20
lines changed

ibis/common/graph.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from abc import abstractmethod
55
from collections import deque
6-
from collections.abc import Iterable, Iterator, KeysView, Sequence
6+
from collections.abc import Iterable, Iterator, KeysView, Mapping, Sequence
77
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
88

99
from ibis.common.bases import Hashable
@@ -233,24 +233,33 @@ def replace(
233233
-------
234234
The root node of the graph with the replaced nodes.
235235
"""
236-
pat = pattern(pat)
237-
ctx = context or {}
238236

239-
def fn(node, _, **kwargs):
240-
# need to first reconstruct the node from the possible rewritten
241-
# children, so we can match on the new node containing the rewritten
242-
# child arguments, this way we can propagate the rewritten nodes
243-
# upward in the hierarchy
244-
# TODO(kszucs): add a __recreate__() method to the Node interface
245-
# with a default implementation that uses the __class__ constructor
246-
# which is supposed to provide an implementation for quick object
247-
# reconstruction (the __recreate__ implementation in grounds.py
248-
# should be sped up as well by totally avoiding the validation)
249-
recreated = node.__class__(**kwargs)
250-
if (result := pat.match(recreated, ctx)) is NoMatch:
251-
return recreated
252-
else:
253-
return result
237+
if isinstance(pat, Mapping):
238+
239+
def fn(node, _, **kwargs):
240+
try:
241+
return pat[node]
242+
except KeyError:
243+
return node.__class__(**kwargs)
244+
else:
245+
pat = pattern(pat)
246+
ctx = context or {}
247+
248+
def fn(node, _, **kwargs):
249+
# need to first reconstruct the node from the possible rewritten
250+
# children, so we can match on the new node containing the rewritten
251+
# child arguments, this way we can propagate the rewritten nodes
252+
# upward in the hierarchy
253+
# TODO(kszucs): add a __recreate__() method to the Node interface
254+
# with a default implementation that uses the __class__ constructor
255+
# which is supposed to provide an implementation for quick object
256+
# reconstruction (the __recreate__ implementation in grounds.py
257+
# should be sped up as well by totally avoiding the validation)
258+
recreated = node.__class__(**kwargs)
259+
if (result := pat.match(recreated, ctx)) is NoMatch:
260+
return recreated
261+
else:
262+
return result
254263

255264
results = self.map(fn, filter=filter)
256265
return results.get(self, self)

ibis/common/tests/test_graph.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def __init__(self, name, children):
2828

2929
@property
3030
def __args__(self):
31-
return (self.children,)
31+
return (self.name, self.children)
3232

3333
@property
3434
def __argnames__(self):
35-
return ("children",)
35+
return ("name", "children")
3636

3737
def __repr__(self):
3838
return f"{self.__class__.__name__}({self.name})"
@@ -145,6 +145,20 @@ def test_replace_with_filtering_out_root():
145145
assert result == A
146146

147147

148+
def test_replace_with_mapping():
149+
new_E = MyNode(name="e", children=[])
150+
new_D = MyNode(name="d", children=[])
151+
new_B = MyNode(name="B", children=[new_D, new_E])
152+
new_A = MyNode(name="A", children=[new_B, C])
153+
154+
subs = {
155+
E: new_E,
156+
D: new_D,
157+
}
158+
result = A.replace(subs)
159+
assert result == new_A
160+
161+
148162
def test_example():
149163
class Example(Annotable, Node):
150164
def __hash__(self):

0 commit comments

Comments
 (0)