Skip to content

Add itertoolz.flatten #547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions toolz/curried/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
drop = toolz.curry(toolz.drop)
excepts = toolz.curry(toolz.excepts)
filter = toolz.curry(toolz.filter)
flat = toolz.curry(toolz.flat)
get = toolz.curry(toolz.get)
get_in = toolz.curry(toolz.get_in)
groupby = toolz.curry(toolz.groupby)
Expand Down
14 changes: 13 additions & 1 deletion toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample')
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample',
'flat')


def remove(predicate, seq):
Expand Down Expand Up @@ -1055,3 +1056,14 @@ def random_sample(prob, seq, random_state=None):

random_state = Random(random_state)
return filter(lambda _: random_state.random() < prob, seq)


def flat(level, seq):
""" Flatten a possible nested sequence by n levels """
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fill out this docstring soon.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Don't forget to also point to concat(seq), which flattens a sequence one level.

if level < 0:
raise ValueError("level must be >= 0")
for item in seq:
if level == 0 or not hasattr(item, '__iter__'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably better to handle the `level == 0` case outside the for loop:

if level == 0:
    yield from seq
    return

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That works really well.

yield item
else:
yield from flat(level - 1, item)
17 changes: 16 additions & 1 deletion toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import itertools
from itertools import starmap
from toolz.utils import raises
Expand All @@ -13,7 +14,7 @@
reduceby, iterate, accumulate,
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
diff, topk, peek, peekn, random_sample, flat)
from operator import add, mul


Expand Down Expand Up @@ -547,3 +548,17 @@ def test_random_sample():
assert mk_rsample(b"a") == mk_rsample(u"a")

assert raises(TypeError, lambda: mk_rsample([]))


def test_flat():
    """ Check flat on un-nested input, nested input, and a negative level """
    # Input without nesting comes back unchanged at level 0 and level 1.
    plain = [1, 2, 3, 4]
    assert list(flat(0, plain)) == plain
    assert list(flat(1, plain)) == plain

    # Each extra level unwraps one more layer of nesting.
    nested = [1, [2, [3]]]
    expected_by_level = (
        (0, [1, [2, [3]]]),
        (1, [1, 2, [3]]),
        (2, [1, 2, 3]),
    )
    for depth, expected in expected_by_level:
        assert list(flat(depth, nested)) == expected

    # The result is consumed with list() inside the raises block, as the
    # ValueError is expected while iterating over a level of -1.
    with pytest.raises(ValueError):
        list(flat(-1, nested))