|
3 | 3 |
|
4 | 4 | from abc import abstractmethod
|
5 | 5 | from collections import deque
|
6 |
| -from collections.abc import Iterable, Iterator, KeysView, Sequence |
| 6 | +from collections.abc import Iterable, Iterator, KeysView, Mapping, Sequence |
7 | 7 | from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
|
8 | 8 |
|
9 | 9 | from ibis.common.bases import Hashable
|
@@ -233,24 +233,33 @@ def replace(
|
233 | 233 | -------
|
234 | 234 | The root node of the graph with the replaced nodes.
|
235 | 235 | """
|
236 |
| - pat = pattern(pat) |
237 |
| - ctx = context or {} |
238 | 236 |
|
239 |
| - def fn(node, _, **kwargs): |
240 |
| - # need to first reconstruct the node from the possible rewritten |
241 |
| - # children, so we can match on the new node containing the rewritten |
242 |
| - # child arguments, this way we can propagate the rewritten nodes |
243 |
| - # upward in the hierarchy |
244 |
| - # TODO(kszucs): add a __recreate__() method to the Node interface |
245 |
| - # with a default implementation that uses the __class__ constructor |
246 |
| - # which is supposed to provide an implementation for quick object |
247 |
| - # reconstruction (the __recreate__ implementation in grounds.py |
248 |
| - # should be sped up as well by totally avoiding the validation) |
249 |
| - recreated = node.__class__(**kwargs) |
250 |
| - if (result := pat.match(recreated, ctx)) is NoMatch: |
251 |
| - return recreated |
252 |
| - else: |
253 |
| - return result |
| 237 | + if isinstance(pat, Mapping): |
| 238 | + |
| 239 | + def fn(node, _, **kwargs): |
| 240 | + try: |
| 241 | + return pat[node] |
| 242 | + except KeyError: |
| 243 | + return node.__class__(**kwargs) |
| 244 | + else: |
| 245 | + pat = pattern(pat) |
| 246 | + ctx = context or {} |
| 247 | + |
| 248 | + def fn(node, _, **kwargs): |
| 249 | + # need to first reconstruct the node from the possible rewritten |
| 250 | + # children, so we can match on the new node containing the rewritten |
| 251 | + # child arguments, this way we can propagate the rewritten nodes |
| 252 | + # upward in the hierarchy |
| 253 | + # TODO(kszucs): add a __recreate__() method to the Node interface |
| 254 | + # with a default implementation that uses the __class__ constructor |
| 255 | + # which is supposed to provide an implementation for quick object |
| 256 | + # reconstruction (the __recreate__ implementation in grounds.py |
| 257 | + # should be sped up as well by totally avoiding the validation) |
| 258 | + recreated = node.__class__(**kwargs) |
| 259 | + if (result := pat.match(recreated, ctx)) is NoMatch: |
| 260 | + return recreated |
| 261 | + else: |
| 262 | + return result |
254 | 263 |
|
255 | 264 | results = self.map(fn, filter=filter)
|
256 | 265 | return results.get(self, self)
|
|
0 commit comments