Skip to content

Commit ac79604

Browse files
committed
perf(ir): don't recreate nodes in replace if their children haven't changed
1 parent f86515c commit ac79604

File tree

4 files changed

+132
-44
lines changed

4 files changed

+132
-44
lines changed

ibis/backends/sql/rewrites.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def sqlize(
370370

371371
# lower the expression graph to a SQL-like relational algebra
372372
context = {"params": params}
373-
sqlized = node.replace(
373+
result = node.replace(
374374
replace_parameter
375375
| project_to_select
376376
| filter_to_select
@@ -385,24 +385,23 @@ def sqlize(
385385

386386
# squash subsequent Select nodes into one
387387
if fuse_selects:
388-
simplified = sqlized.replace(merge_select_select)
389-
else:
390-
simplified = sqlized
388+
result = result.replace(merge_select_select)
391389

392390
if post_rewrites:
393-
simplified = simplified.replace(reduce(operator.or_, post_rewrites))
391+
result = result.replace(reduce(operator.or_, post_rewrites))
394392

395393
# extract common table expressions while wrapping them in a CTE node
396-
ctes = extract_ctes(simplified)
394+
ctes = extract_ctes(result)
397395

398-
def wrap(node, _, **kwargs):
399-
new = node.__recreate__(kwargs)
400-
return CTE(new) if node in ctes else new
396+
if ctes:
401397

402-
result = simplified.replace(wrap)
403-
ctes = [cte.parent for cte in result.find(CTE, ordered=True)]
398+
def apply_ctes(node, kwargs):
399+
new = node.__recreate__(kwargs) if kwargs else node
400+
return CTE(new) if node in ctes else new
404401

405-
return result, ctes
402+
result = result.replace(apply_ctes)
403+
return result, [cte.parent for cte in result.find(CTE, ordered=True)]
404+
return result, []
406405

407406

408407
# supplemental rewrites selectively used on a per-backend basis

ibis/backends/tests/test_numeric.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,16 +1382,14 @@ def test_histogram(con, alltypes):
13821382
hist = con.execute(alltypes.int_col.histogram(n).name("hist"))
13831383
vc = hist.value_counts().sort_index()
13841384
vc_np, _bin_edges = np.histogram(alltypes.int_col.execute(), bins=n)
1385-
assert vc.tolist() == vc_np.tolist()
1386-
assert (
1387-
con.execute(
1388-
ibis.memtable({"value": range(100)})
1389-
.select(bin=_.value.histogram(10))
1390-
.value_counts()
1391-
.bin_count.nunique()
1392-
)
1393-
== 1
1385+
expr = (
1386+
ibis.memtable({"value": range(100)})
1387+
.select(bin=_.value.histogram(10))
1388+
.value_counts()
1389+
.bin_count.nunique()
13941390
)
1391+
assert vc.tolist() == vc_np.tolist()
1392+
assert con.execute(expr) == 1
13951393

13961394

13971395
@pytest.mark.parametrize("const", ["pi", "e"])

ibis/common/graph.py

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Finder = Callable[["Node"], bool]
2323
FinderLike = Union[Finder, Pattern, _ClassInfo]
2424

25-
Replacer = Callable[["Node", dict["Node", Any]], "Node"]
25+
Replacer = Callable[["Node", dict["Node", Any] | None], "Node"]
2626
ReplacerLike = Union[Replacer, Pattern, Mapping]
2727

2828

@@ -127,6 +127,47 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any:
127127
return obj
128128

129129

130+
def _apply_replacements(obj: Any, replacements: dict) -> tuple[Any, bool]:
131+
"""Replace nodes in a possibly nested object.
132+
133+
Parameters
134+
----------
135+
obj
136+
The object to traverse.
137+
replacements
138+
A mapping of replacement values.
139+
140+
Returns
141+
-------
142+
tuple[Any, bool]
143+
A tuple of the replaced object and whether any replacements were made.
144+
"""
145+
if isinstance(obj, Node):
146+
val = replacements.get(obj)
147+
return (obj, False) if val is None else (val, True)
148+
typ = type(obj)
149+
if typ in (tuple, frozenset, list):
150+
changed = False
151+
items = []
152+
for i in obj:
153+
i, ichanged = _apply_replacements(i, replacements)
154+
changed |= ichanged
155+
items.append(i)
156+
return typ(items), changed
157+
elif isinstance(obj, dict):
158+
changed = False
159+
items = {}
160+
for k, v in obj.items():
161+
k, kchanged = _apply_replacements(k, replacements)
162+
v, vchanged = _apply_replacements(v, replacements)
163+
changed |= kchanged
164+
changed |= vchanged
165+
items[k] = v
166+
return items, changed
167+
else:
168+
return obj, False
169+
170+
130171
def _coerce_finder(obj: FinderLike, context: Optional[dict] = None) -> Finder:
131172
"""Coerce an object into a callable finder function.
132173
@@ -165,8 +206,7 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
165206
Parameters
166207
----------
167208
obj
168-
A Pattern, a Mapping or a callable which can be fed to `node.map()`
169-
to replace nodes.
209+
A Pattern, Mapping, or Callable.
170210
context
171211
Optional context to use if the replacer is a pattern.
172212
@@ -177,26 +217,26 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
177217
"""
178218
if isinstance(obj, Pattern):
179219

180-
def fn(node, _, **kwargs):
220+
def fn(node, kwargs):
181221
ctx = context or {}
182222
# need to first reconstruct the node from the possible rewritten
183223
# children, so we can match on the new node containing the rewritten
184224
# child arguments, this way we can propagate the rewritten nodes
185-
# upward in the hierarchy, using a specialized __recreate__ method
186-
# improves the performance by 17% compared node.__class__(**kwargs)
187-
recreated = node.__recreate__(kwargs)
225+
# upward in the hierarchy
226+
recreated = node.__recreate__(kwargs) if kwargs else node
188227
if (result := obj.match(recreated, ctx)) is NoMatch:
189228
return recreated
190-
else:
191-
return result
229+
return result
192230

193231
elif isinstance(obj, Mapping):
194232

195-
def fn(node, _, **kwargs):
233+
def fn(node, kwargs):
234+
# For a mapping we want to lookup the original node first, and
235+
# return a recreated one from the children if it's not present
196236
try:
197237
return obj[node]
198238
except KeyError:
199-
return node.__recreate__(kwargs)
239+
return node.__recreate__(kwargs) if kwargs else node
200240
elif callable(obj):
201241
fn = obj
202242
else:
@@ -313,7 +353,7 @@ def map_clear(self, fn: Callable, filter: Optional[Finder] = None) -> Any:
313353
if not dependents[dependency]:
314354
del results[dependency]
315355

316-
return results[self]
356+
return results.get(self, self)
317357

318358
@experimental
319359
def map_nodes(self, fn: Callable, filter: Optional[Finder] = None) -> Any:
@@ -451,8 +491,9 @@ def replace(
451491
Parameters
452492
----------
453493
replacer
454-
A `Pattern`, a `Mapping` or a callable which can be fed to
455-
`node.map()` directly to replace nodes.
494+
A `Pattern`, `Mapping` or Callable taking the original unrewritten
495+
node, and a mapping of attribute name to value of its rewritten
496+
children (or None if no children were rewritten).
456497
filter
457498
A type, tuple of types, a pattern or a callable to filter out nodes
458499
from the traversal. The traversal will only visit nodes that match
@@ -465,9 +506,28 @@ def replace(
465506
The root node of the graph with the replaced nodes.
466507
467508
"""
468-
replacer = _coerce_replacer(replacer, context)
469-
results = self.map(replacer, filter=filter)
470-
return results.get(self, self)
509+
replacements: dict[Node, Any] = {}
510+
511+
fn = _coerce_replacer(replacer, context)
512+
513+
graph, _ = Graph.from_bfs(self, filter=filter).toposort()
514+
for node in graph:
515+
kwargs = {}
516+
# Apply already rewritten nodes to the children of the node
517+
changed = False
518+
for k, v in zip(node.__argnames__, node.__args__):
519+
v, vchanged = _apply_replacements(v, replacements)
520+
changed |= vchanged
521+
kwargs[k] = v
522+
523+
# Call the replacer on the node with any rewritten nodes (or None
524+
# if unchanged).
525+
result = fn(node, kwargs if changed else None)
526+
if result is not node:
527+
# The node is changed, store it in the mapping of replacements
528+
replacements[node] = result
529+
530+
return replacements.get(self, self)
471531

472532

473533
class Graph(dict[Node, Sequence[Node]]):

ibis/common/tests/test_graph.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
traverse,
2020
)
2121
from ibis.common.grounds import Annotable, Concrete
22-
from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _
22+
from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _, pattern
2323

2424

2525
class MyNode(Node):
@@ -170,6 +170,36 @@ def test_replace_with_mapping():
170170
assert result == new_A
171171

172172

173+
@pytest.mark.parametrize("kind", ["pattern", "mapping", "function"])
174+
def test_replace_doesnt_recreate_unchanged_nodes(kind):
175+
A1 = MyNode(name="A1", children=[])
176+
A2 = MyNode(name="A2", children=[A1])
177+
B1 = MyNode(name="B1", children=[])
178+
B2 = MyNode(name="B2", children=[B1])
179+
C = MyNode(name="C", children=[A2, B2])
180+
181+
B3 = MyNode(name="B3", children=[])
182+
183+
if kind == "pattern":
184+
replacer = pattern(MyNode)(name="B2") >> B3
185+
elif kind == "mapping":
186+
replacer = {B2: B3}
187+
else:
188+
189+
def replacer(node, children):
190+
if node is B2:
191+
return B3
192+
return node.__recreate__(children) if children else node
193+
194+
res = C.replace(replacer)
195+
196+
assert res is not C
197+
assert res.name == "C"
198+
assert len(res.children) == 2
199+
assert res.children[0] is A2
200+
assert res.children[1] is B3
201+
202+
173203
def test_example():
174204
class Example(Annotable, Node):
175205
def __hash__(self):
@@ -343,17 +373,18 @@ def test_coerce_finder():
343373

344374

345375
def test_coerce_replacer():
346-
r = _coerce_replacer(lambda x, _, **kwargs: D)
347-
assert r(C, {}) == D
376+
r = _coerce_replacer(lambda x, children: D if children else C)
377+
assert r(C, {"children": []}) is D
378+
assert r(C, None) is C
348379

349380
r = _coerce_replacer({C: D, D: E})
350381
assert r(C, {}) == D
351382
assert r(D, {}) == E
352-
assert r(A, {}, name="A", children=[B, C]) == A
383+
assert r(A, {"name": "A", "children": [B, C]}) == A
353384

354385
r = _coerce_replacer(InstanceOf(MyNode) >> _.copy(name=_.name.lower()))
355-
assert r(C, {}, name="C", children=[]) == MyNode(name="c", children=[])
356-
assert r(D, {}, name="D", children=[]) == MyNode(name="d", children=[])
386+
assert r(C, {"name": "C", "children": []}) == MyNode(name="c", children=[])
387+
assert r(D, {"name": "D", "children": []}) == MyNode(name="d", children=[])
357388

358389

359390
def test_node_find_using_type():

0 commit comments

Comments
 (0)