22
22
Finder = Callable [["Node" ], bool ]
23
23
FinderLike = Union [Finder , Pattern , _ClassInfo ]
24
24
25
- Replacer = Callable [["Node" , dict ["Node" , Any ]], "Node" ]
25
+ Replacer = Callable [["Node" , dict ["Node" , Any ] | None ], "Node" ]
26
26
ReplacerLike = Union [Replacer , Pattern , Mapping ]
27
27
28
28
@@ -127,6 +127,47 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any:
127
127
return obj
128
128
129
129
130
+ def _apply_replacements (obj : Any , replacements : dict ) -> tuple [Any , bool ]:
131
+ """Replace nodes in a possibly nested object.
132
+
133
+ Parameters
134
+ ----------
135
+ obj
136
+ The object to traverse.
137
+ replacements
138
+ A mapping of replacement values.
139
+
140
+ Returns
141
+ -------
142
+ tuple[Any, bool]
143
+ A tuple of the replaced object and whether any replacements were made.
144
+ """
145
+ if isinstance (obj , Node ):
146
+ val = replacements .get (obj )
147
+ return (obj , False ) if val is None else (val , True )
148
+ typ = type (obj )
149
+ if typ in (tuple , frozenset , list ):
150
+ changed = False
151
+ items = []
152
+ for i in obj :
153
+ i , ichanged = _apply_replacements (i , replacements )
154
+ changed |= ichanged
155
+ items .append (i )
156
+ return typ (items ), changed
157
+ elif isinstance (obj , dict ):
158
+ changed = False
159
+ items = {}
160
+ for k , v in obj .items ():
161
+ k , kchanged = _apply_replacements (k , replacements )
162
+ v , vchanged = _apply_replacements (v , replacements )
163
+ changed |= kchanged
164
+ changed |= vchanged
165
+ items [k ] = v
166
+ return items , changed
167
+ else :
168
+ return obj , False
169
+
170
+
130
171
def _coerce_finder (obj : FinderLike , context : Optional [dict ] = None ) -> Finder :
131
172
"""Coerce an object into a callable finder function.
132
173
@@ -165,8 +206,7 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
165
206
Parameters
166
207
----------
167
208
obj
168
- A Pattern, a Mapping or a callable which can be fed to `node.map()`
169
- to replace nodes.
209
+ A Pattern, Mapping, or Callable.
170
210
context
171
211
Optional context to use if the replacer is a pattern.
172
212
@@ -177,26 +217,26 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
177
217
"""
178
218
if isinstance (obj , Pattern ):
179
219
180
- def fn (node , _ , ** kwargs ):
220
+ def fn (node , kwargs ):
181
221
ctx = context or {}
182
222
# need to first reconstruct the node from the possible rewritten
183
223
# children, so we can match on the new node containing the rewritten
184
224
# child arguments, this way we can propagate the rewritten nodes
185
- # upward in the hierarchy, using a specialized __recreate__ method
186
- # improves the performance by 17% compared node.__class__(**kwargs)
187
- recreated = node .__recreate__ (kwargs )
225
+ # upward in the hierarchy
226
+ recreated = node .__recreate__ (kwargs ) if kwargs else node
188
227
if (result := obj .match (recreated , ctx )) is NoMatch :
189
228
return recreated
190
- else :
191
- return result
229
+ return result
192
230
193
231
elif isinstance (obj , Mapping ):
194
232
195
- def fn (node , _ , ** kwargs ):
233
+ def fn (node , kwargs ):
234
+ # For a mapping we want to lookup the original node first, and
235
+ # return a recreated one from the children if it's not present
196
236
try :
197
237
return obj [node ]
198
238
except KeyError :
199
- return node .__recreate__ (kwargs )
239
+ return node .__recreate__ (kwargs ) if kwargs else node
200
240
elif callable (obj ):
201
241
fn = obj
202
242
else :
@@ -313,7 +353,7 @@ def map_clear(self, fn: Callable, filter: Optional[Finder] = None) -> Any:
313
353
if not dependents [dependency ]:
314
354
del results [dependency ]
315
355
316
- return results [ self ]
356
+ return results . get ( self , self )
317
357
318
358
@experimental
319
359
def map_nodes (self , fn : Callable , filter : Optional [Finder ] = None ) -> Any :
@@ -451,8 +491,9 @@ def replace(
451
491
Parameters
452
492
----------
453
493
replacer
454
- A `Pattern`, a `Mapping` or a callable which can be fed to
455
- `node.map()` directly to replace nodes.
494
+ A `Pattern`, `Mapping` or Callable taking the original unrewritten
495
+ node, and a mapping of attribute name to value of its rewritten
496
+ children (or None if no children were rewritten).
456
497
filter
457
498
A type, tuple of types, a pattern or a callable to filter out nodes
458
499
from the traversal. The traversal will only visit nodes that match
@@ -465,9 +506,28 @@ def replace(
465
506
The root node of the graph with the replaced nodes.
466
507
467
508
"""
468
- replacer = _coerce_replacer (replacer , context )
469
- results = self .map (replacer , filter = filter )
470
- return results .get (self , self )
509
+ replacements : dict [Node , Any ] = {}
510
+
511
+ fn = _coerce_replacer (replacer , context )
512
+
513
+ graph , _ = Graph .from_bfs (self , filter = filter ).toposort ()
514
+ for node in graph :
515
+ kwargs = {}
516
+ # Apply already rewritten nodes to the children of the node
517
+ changed = False
518
+ for k , v in zip (node .__argnames__ , node .__args__ ):
519
+ v , vchanged = _apply_replacements (v , replacements )
520
+ changed |= vchanged
521
+ kwargs [k ] = v
522
+
523
+ # Call the replacer on the node with any rewritten nodes (or None
524
+ # if unchanged).
525
+ result = fn (node , kwargs if changed else None )
526
+ if result is not node :
527
+ # The node is changed, store it in the mapping of replacements
528
+ replacements [node ] = result
529
+
530
+ return replacements .get (self , self )
471
531
472
532
473
533
class Graph (dict [Node , Sequence [Node ]]):
0 commit comments