Skip to content

Commit 1558848

Browse files
committed
Add `enter` parameter to `research()` to allow traversing custom data types
1 parent d9a927b commit 1558848

File tree

2 files changed

+97
-34
lines changed

2 files changed

+97
-34
lines changed

boltons/iterutils.py

+75-34
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,15 @@ def split_iter(src, sep=None, maxsplit=None):
165165
sep_func = sep
166166
elif not is_scalar(sep):
167167
sep = frozenset(sep)
168-
sep_func = lambda x: x in sep
168+
def sep_func(x): return x in sep
169169
else:
170-
sep_func = lambda x: x == sep
170+
def sep_func(x): return x == sep
171171

172172
cur_group = []
173173
split_count = 0
174174
for s in src:
175175
if maxsplit is not None and split_count >= maxsplit:
176-
sep_func = lambda x: False
176+
def sep_func(x): return False
177177
if sep_func(s):
178178
if sep is None and not cur_group:
179179
# If sep is none, str.split() "groups" separators
@@ -229,7 +229,7 @@ def rstrip(iterable, strip_value=None):
229229
['Foo', 'Bar']
230230
231231
"""
232-
return list(rstrip_iter(iterable,strip_value))
232+
return list(rstrip_iter(iterable, strip_value))
233233

234234

235235
def rstrip_iter(iterable, strip_value=None):
@@ -253,7 +253,7 @@ def rstrip_iter(iterable, strip_value=None):
253253
else:
254254
broken = True
255255
break
256-
if not broken: # Return to caller here because the end of the
256+
if not broken: # Return to caller here because the end of the
257257
return # iterator has been reached
258258
yield from cache
259259
yield i
@@ -268,10 +268,10 @@ def strip(iterable, strip_value=None):
268268
['Foo', 'Bar', 'Bam']
269269
270270
"""
271-
return list(strip_iter(iterable,strip_value))
271+
return list(strip_iter(iterable, strip_value))
272272

273273

274-
def strip_iter(iterable,strip_value=None):
274+
def strip_iter(iterable, strip_value=None):
275275
"""Strips values from the beginning and end of an iterable. Stripped items
276276
will match the value of the argument strip_value. Functionality is
277277
analogous to that of the method str.strip. Returns a generator.
@@ -280,7 +280,7 @@ def strip_iter(iterable,strip_value=None):
280280
['Foo', 'Bar', 'Bam']
281281
282282
"""
283-
return rstrip_iter(lstrip_iter(iterable,strip_value),strip_value)
283+
return rstrip_iter(lstrip_iter(iterable, strip_value), strip_value)
284284

285285

286286
def chunked(src, size, count=None, **kw):
@@ -340,11 +340,12 @@ def chunked_iter(src, size, **kw):
340340
raise ValueError('got unexpected keyword arguments: %r' % kw.keys())
341341
if not src:
342342
return
343-
postprocess = lambda chk: chk
343+
344+
def postprocess(chk): return chk
344345
if isinstance(src, (str, bytes)):
345-
postprocess = lambda chk, _sep=type(src)(): _sep.join(chk)
346+
def postprocess(chk, _sep=type(src)()): return _sep.join(chk)
346347
if isinstance(src, bytes):
347-
postprocess = lambda chk: bytes(chk)
348+
def postprocess(chk): return bytes(chk)
348349
src_iter = iter(src)
349350
while True:
350351
cur_chunk = list(itertools.islice(src_iter, size))
@@ -385,15 +386,19 @@ def chunk_ranges(input_size, chunk_size, input_offset=0, overlap_size=0, align=F
385386
>>> list(chunk_ranges(input_offset=3, input_size=15, chunk_size=5, overlap_size=1, align=True))
386387
[(3, 5), (4, 9), (8, 13), (12, 17), (16, 18)]
387388
"""
388-
input_size = _validate_positive_int(input_size, 'input_size', strictly_positive=False)
389+
input_size = _validate_positive_int(
390+
input_size, 'input_size', strictly_positive=False)
389391
chunk_size = _validate_positive_int(chunk_size, 'chunk_size')
390-
input_offset = _validate_positive_int(input_offset, 'input_offset', strictly_positive=False)
391-
overlap_size = _validate_positive_int(overlap_size, 'overlap_size', strictly_positive=False)
392+
input_offset = _validate_positive_int(
393+
input_offset, 'input_offset', strictly_positive=False)
394+
overlap_size = _validate_positive_int(
395+
overlap_size, 'overlap_size', strictly_positive=False)
392396

393397
input_stop = input_offset + input_size
394398

395399
if align:
396-
initial_chunk_len = chunk_size - input_offset % (chunk_size - overlap_size)
400+
initial_chunk_len = chunk_size - \
401+
input_offset % (chunk_size - overlap_size)
397402
if initial_chunk_len != overlap_size:
398403
yield (input_offset, min(input_offset + initial_chunk_len, input_stop))
399404
if input_offset + initial_chunk_len >= input_stop:
@@ -479,7 +484,7 @@ def windowed_iter(src, size, fill=_UNSET):
479484
480485
With *fill* set, the iterator always yields a number of windows
481486
equal to the length of the *src* iterable.
482-
487+
483488
>>> windowed(range(4), 3, fill=None)
484489
[(0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
485490
@@ -495,17 +500,16 @@ def windowed_iter(src, size, fill=_UNSET):
495500
except StopIteration:
496501
return zip([])
497502
return zip(*tees)
498-
503+
499504
for i, t in enumerate(tees):
500-
for _ in range(i):
505+
for _ in range(i):
501506
try:
502507
next(t)
503508
except StopIteration:
504509
continue
505510
return zip_longest(*tees, fillvalue=fill)
506511

507512

508-
509513
def xfrange(stop, start=None, step=1.0):
510514
"""Same as :func:`frange`, but generator-based instead of returning a
511515
list.
@@ -726,21 +730,21 @@ def bucketize(src, key=bool, value_transform=None, key_filter=None):
726730
src = zip(key, src)
727731

728732
if isinstance(key, str):
729-
key_func = lambda x: getattr(x, key, x)
733+
def key_func(x): return getattr(x, key, x)
730734
elif callable(key):
731735
key_func = key
732736
elif isinstance(key, list):
733-
key_func = lambda x: x[0]
737+
def key_func(x): return x[0]
734738
else:
735739
raise TypeError('expected key to be callable or a string or a list')
736740

737741
if value_transform is None:
738-
value_transform = lambda x: x
742+
def value_transform(x): return x
739743
if not callable(value_transform):
740744
raise TypeError('expected callable value transform function')
741745
if isinstance(key, list):
742746
f = value_transform
743-
value_transform=lambda x: f(x[1])
747+
def value_transform(x): return f(x[1])
744748

745749
ret = {}
746750
for val in src:
@@ -807,11 +811,11 @@ def unique_iter(src, key=None):
807811
if not is_iterable(src):
808812
raise TypeError('expected an iterable, not %r' % type(src))
809813
if key is None:
810-
key_func = lambda x: x
814+
def key_func(x): return x
811815
elif callable(key):
812816
key_func = key
813817
elif isinstance(key, str):
814-
key_func = lambda x: getattr(x, key, x)
818+
def key_func(x): return getattr(x, key, x)
815819
else:
816820
raise TypeError('"key" expected a string or callable, not %r' % key)
817821
seen = set()
@@ -862,7 +866,7 @@ def redundant(src, key=None, groups=False):
862866
elif callable(key):
863867
key_func = key
864868
elif isinstance(key, (str, bytes)):
865-
key_func = lambda x: getattr(x, key, x)
869+
def key_func(x): return getattr(x, key, x)
866870
else:
867871
raise TypeError('"key" expected a string or callable, not %r' % key)
868872
seen = {} # key to first seen item
@@ -964,6 +968,7 @@ def flatten_iter(iterable):
964968
else:
965969
yield item
966970

971+
967972
def flatten(iterable):
968973
"""``flatten()`` returns a collapsed list of all the elements from
969974
*iterable* while collapsing any nested iterables.
@@ -1006,6 +1011,7 @@ def default_visit(path, key, value):
10061011
# print('visit(%r, %r, %r)' % (path, key, value))
10071012
return key, value
10081013

1014+
10091015
# enable the extreme: monkeypatching iterutils with a different default_visit
10101016
_orig_default_visit = default_visit
10111017

@@ -1128,6 +1134,9 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
11281134
callable. When set to ``False``, remap ignores any errors
11291135
raised by the *visit* callback. Items causing exceptions
11301136
are kept. See examples for more details.
1137+
trace (bool or tuple): Pass ``trace=True`` to print out the entire
1138+
traversal. Or pass a tuple of ``'visit'``, ``'enter'``,
1139+
or ``'exit'`` to print only the selected events.
11311140
11321141
remap is designed to cover the majority of cases with just the
11331142
*visit* callable. While passing in multiple callables is very
@@ -1156,6 +1165,15 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
11561165
if not callable(exit):
11571166
raise TypeError('exit expected callable, not: %r' % exit)
11581167
reraise_visit = kwargs.pop('reraise_visit', True)
1168+
trace = kwargs.pop('trace', ())
1169+
if trace is True:
1170+
trace = ('visit', 'enter', 'exit')
1171+
elif isinstance(trace, str):
1172+
trace = (trace,)
1173+
if not isinstance(trace, (tuple, list, set)):
1174+
raise TypeError('trace expected tuple of event names, not: %r' % trace)
1175+
trace_enter, trace_exit, trace_visit = 'enter' in trace, 'exit' in trace, 'visit' in trace
1176+
11591177
if kwargs:
11601178
raise TypeError('unexpected keyword arguments: %r' % kwargs.keys())
11611179

@@ -1168,14 +1186,23 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
11681186
key, new_parent, old_parent = value
11691187
id_value = id(old_parent)
11701188
path, new_items = new_items_stack.pop()
1189+
if trace_exit:
1190+
print(' .. remap exit:', path, '-', key, '-',
1191+
old_parent, '-', new_parent, '-', new_items)
11711192
value = exit(path, key, old_parent, new_parent, new_items)
1193+
if trace_exit:
1194+
print(' .. remap exit result:', value)
11721195
registry[id_value] = value
11731196
if not new_items_stack:
11741197
continue
11751198
elif id_value in registry:
11761199
value = registry[id_value]
11771200
else:
1201+
if trace_enter:
1202+
print(' .. remap enter:', path, '-', key, '-', value)
11781203
res = enter(path, key, value)
1204+
if trace_enter:
1205+
print(' .. remap enter result:', res)
11791206
try:
11801207
new_parent, new_items = res
11811208
except TypeError:
@@ -1191,21 +1218,29 @@ def remap(root, visit=default_visit, enter=default_enter, exit=default_exit,
11911218
stack.append((_REMAP_EXIT, (key, new_parent, value)))
11921219
if new_items:
11931220
stack.extend(reversed(list(new_items)))
1221+
if trace_enter:
1222+
print(' .. remap stack size now:', len(stack))
11941223
continue
11951224
if visit is _orig_default_visit:
11961225
# avoid function call overhead by inlining identity operation
11971226
visited_item = (key, value)
11981227
else:
11991228
try:
1229+
if trace_visit:
1230+
print(' .. remap visit:', path, '-', key, '-', value)
12001231
visited_item = visit(path, key, value)
12011232
except Exception:
12021233
if reraise_visit:
12031234
raise
12041235
visited_item = True
12051236
if visited_item is False:
1237+
if trace_visit:
1238+
print(' .. remap visit result: <drop>')
12061239
continue # drop
12071240
elif visited_item is True:
12081241
visited_item = (key, value)
1242+
if trace_visit:
1243+
print(' .. remap visit result:', visited_item)
12091244
# TODO: typecheck?
12101245
# raise TypeError('expected (key, value) from visit(),'
12111246
# ' not: %r' % visited_item)
@@ -1221,6 +1256,7 @@ class PathAccessError(KeyError, IndexError, TypeError):
12211256
representing what can occur when looking up a path in a nested
12221257
object.
12231258
"""
1259+
12241260
def __init__(self, exc, seg, path):
12251261
self.exc = exc
12261262
self.seg = seg
@@ -1296,7 +1332,7 @@ def get_path(root, path, default=_UNSET):
12961332
return cur
12971333

12981334

1299-
def research(root, query=lambda p, k, v: True, reraise=False):
1335+
def research(root, query=lambda p, k, v: True, reraise=False, enter=default_enter):
13001336
"""The :func:`research` function uses :func:`remap` to recurse over
13011337
any data nested in *root*, and find values which match a given
13021338
criterion, specified by the *query* callable.
@@ -1343,16 +1379,16 @@ def research(root, query=lambda p, k, v: True, reraise=False):
13431379
if not callable(query):
13441380
raise TypeError('query expected callable, not: %r' % query)
13451381

1346-
def enter(path, key, value):
1382+
def _enter(path, key, value):
13471383
try:
13481384
if query(path, key, value):
13491385
ret.append((path + (key,), value))
13501386
except Exception:
13511387
if reraise:
13521388
raise
1353-
return default_enter(path, key, value)
1389+
return enter(path, key, value)
13541390

1355-
remap(root, enter=enter)
1391+
remap(root, enter=_enter)
13561392
return ret
13571393

13581394

@@ -1383,6 +1419,7 @@ class GUIDerator:
13831419
detect a fork on next iteration and reseed accordingly.
13841420
13851421
"""
1422+
13861423
def __init__(self, size=24):
13871424
self.size = size
13881425
if size < 20 or size > 36:
@@ -1495,13 +1532,16 @@ def soft_sorted(iterable, first=None, last=None, key=None, reverse=False):
14951532
last = last or []
14961533
key = key or (lambda x: x)
14971534
seq = list(iterable)
1498-
other = [x for x in seq if not ((first and key(x) in first) or (last and key(x) in last))]
1535+
other = [x for x in seq if not (
1536+
(first and key(x) in first) or (last and key(x) in last))]
14991537
other.sort(key=key, reverse=reverse)
15001538

15011539
if first:
1502-
first = sorted([x for x in seq if key(x) in first], key=lambda x: first.index(key(x)))
1540+
first = sorted([x for x in seq if key(x) in first],
1541+
key=lambda x: first.index(key(x)))
15031542
if last:
1504-
last = sorted([x for x in seq if key(x) in last], key=lambda x: last.index(key(x)))
1543+
last = sorted([x for x in seq if key(x) in last],
1544+
key=lambda x: last.index(key(x)))
15051545
return first + other + last
15061546

15071547

@@ -1536,7 +1576,7 @@ def __lt__(self, other):
15361576
ret = obj < other
15371577
except TypeError:
15381578
ret = ((type(obj).__name__, id(type(obj)), obj)
1539-
< (type(other).__name__, id(type(other)), other))
1579+
< (type(other).__name__, id(type(other)), other))
15401580
return ret
15411581

15421582
if key is not None and not callable(key):
@@ -1545,6 +1585,7 @@ def __lt__(self, other):
15451585

15461586
return sorted(iterable, key=_Wrapper, reverse=reverse)
15471587

1588+
15481589
"""
15491590
May actually be faster to do an isinstance check for a str path
15501591

tests/test_iterutils.py

+22
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,28 @@ def broken_query(p, k, v):
396396
assert research(root, broken_query) == []
397397

398398

399+
def test_research_custom_enter():
400+
# see #368
401+
from types import SimpleNamespace as NS
402+
root = NS(
403+
a='a',
404+
b='b',
405+
c=NS(aa='aa') )
406+
407+
def query(path, key, value):
408+
return value.startswith('a')
409+
410+
def custom_enter(path, key, value):
411+
if isinstance(value, NS):
412+
return [], value.__dict__.items()
413+
return default_enter(path, key, value)
414+
415+
with pytest.raises(TypeError):
416+
research(root, query)
417+
assert research(root, query, enter=custom_enter) == [(('a',), 'a'), (('c', 'aa'), 'aa')]
418+
419+
420+
399421
def test_backoff_basic():
400422
from boltons.iterutils import backoff
401423

0 commit comments

Comments
 (0)