Skip to content

Commit 8a9f4b9

Browse files
committed
✨ Add map_leaves
1 parent 6d900eb commit 8a9f4b9

File tree

3 files changed

+204
-1
lines changed

3 files changed

+204
-1
lines changed

src/nested_dict_tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
NestedMutableMappingNode,
1111
flatten_dict,
1212
get_deep,
13+
map_leaves,
1314
set_deep,
1415
unflatten_dict,
1516
)
@@ -24,6 +25,7 @@
2425
"NestedMutableMappingNode",
2526
"flatten_dict",
2627
"get_deep",
28+
"map_leaves",
2729
"set_deep",
2830
"unflatten_dict",
2931
]

src/nested_dict_tools/core.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
This code is licensed under the terms of the MIT license.
2020
"""
2121

22-
from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
22+
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence
2323
from typing import Any, Literal, cast, overload
2424

2525
type NestedMapping[K, V] = Mapping[K, NestedMappingNode[K, V]]
@@ -205,3 +205,98 @@ def set_deep[K, V](d: NestedMutableMapping[K, Any], keys: Sequence[K], value: An
205205
sub_dict[key] = sub_dict = {}
206206

207207
sub_dict[keys[-1]] = value
208+
209+
210+
@overload
211+
def map_leaves[K, V, W](
212+
func: Callable[[V], W], nested_dict1: NestedMapping[K, V], /
213+
) -> NestedMapping[K, W]: ...
214+
215+
216+
@overload
217+
def map_leaves[K, V1, V2, W](
218+
func: Callable[[V1, V2], W],
219+
nested_dict1: NestedMapping[K, V1],
220+
nested_dict2: NestedMapping[K, V2],
221+
/,
222+
) -> NestedMapping[K, W]: ...
223+
224+
225+
@overload
226+
def map_leaves[K, V1, V2, V3, W](
227+
func: Callable[[V1, V2, V3], W],
228+
nested_dict1: NestedMapping[K, V1],
229+
nested_dict2: NestedMapping[K, V2],
230+
nested_dict3: NestedMapping[K, V3],
231+
/,
232+
) -> NestedMapping[K, W]: ...
233+
234+
235+
@overload
236+
def map_leaves[K, V1, V2, V3, V4, W](
237+
func: Callable[[V1, V2, V3, V4], W],
238+
nested_dict1: NestedMapping[K, V1],
239+
nested_dict2: NestedMapping[K, V2],
240+
nested_dict3: NestedMapping[K, V3],
241+
nested_dict4: NestedMapping[K, V4],
242+
/,
243+
) -> NestedMapping[K, W]: ...
244+
245+
246+
@overload
247+
def map_leaves[K, V1, V2, V3, V4, V5, W](
248+
func: Callable[[V1, V2, V3, V4, V5], W],
249+
nested_dict1: NestedMapping[K, V1],
250+
nested_dict2: NestedMapping[K, V2],
251+
nested_dict3: NestedMapping[K, V3],
252+
nested_dict4: NestedMapping[K, V4],
253+
nested_dict5: NestedMapping[K, V5],
254+
/,
255+
) -> NestedMapping[K, W]: ...
256+
257+
258+
@overload
259+
def map_leaves[K, W](
260+
func: Callable[..., W],
261+
nested_dict1: NestedMapping[K, Any],
262+
nested_dict2: NestedMapping[K, Any],
263+
nested_dict3: NestedMapping[K, Any],
264+
nested_dict4: NestedMapping[K, Any],
265+
nested_dict5: NestedMapping[K, Any],
266+
/,
267+
*nested_dicts: NestedMapping[K, Any],
268+
) -> NestedMapping[K, W]: ...
269+
270+
271+
def map_leaves[K, V, W](
272+
func: Callable[..., W],
273+
*nested_dicts: NestedMapping[K, V],
274+
) -> NestedMapping[K, W]:
275+
"""
276+
Apply the function to every leaf (non-mapping values) of the nested dictionaries.
277+
278+
If multiple nested dictionaries are passed, performs element-wise operations on their corresponding values at each key.
279+
280+
Args:
281+
func: Function to apply on the leaves.
282+
*nested_dicts: Nested dictionaries on which to apply the function.
283+
284+
Return:
285+
The result nested dictionary with mapped leaves.
286+
287+
>>> map_leaves(lambda x: x * 2, {"a": 1, "b": 2, "c": 3})
288+
{'a': 2, 'b': 4, 'c': 6}
289+
290+
>>> map_leaves(lambda x, y: x + y, {"a": 1, "b": 2}, {"a": 3, "b": 4})
291+
{'a': 4, 'b': 6}
292+
"""
293+
dict_res: NestedMapping[K, W] = {}
294+
dict1 = nested_dicts[0]
295+
for key in dict1:
296+
args = (d[key] for d in nested_dicts)
297+
if isinstance(dict1[key], Mapping):
298+
dict_res[key] = map_leaves(func, *cast(Iterator[NestedMapping[K, V]], args))
299+
else:
300+
dict_res[key] = func(*cast(Iterator[V], args))
301+
302+
return dict_res

tests/test_core.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for core module."""
22

33
import math
4+
import operator
45
import time
56

67
import pytest
@@ -9,6 +10,7 @@
910
KeySeparatorCollisionError,
1011
flatten_dict,
1112
get_deep,
13+
map_leaves,
1214
set_deep,
1315
unflatten_dict,
1416
)
@@ -393,3 +395,107 @@ def test_set_deep_with_large_dictionary(self):
393395
for key in keys[:-1]:
394396
sub_dict = sub_dict[key]
395397
assert sub_dict[keys[-1]] == value
398+
399+
400+
class TestMapLeaves:
401+
# Function correctly maps single-level dictionary with single input dictionary
402+
def test_map_single_level_dict(self):
403+
input_dict = {"a": 1, "b": 2, "c": 3}
404+
expected = {"a": 2, "b": 4, "c": 6}
405+
result = map_leaves(lambda x: x * 2, input_dict)
406+
assert result == expected
407+
408+
# Empty input dictionary
409+
def test_map_empty_dict(self):
410+
input_dict = {}
411+
expected = {}
412+
result = map_leaves(lambda x: x * 2, input_dict)
413+
assert result == expected
414+
415+
# Function correctly maps nested dictionary with multiple input dictionaries
416+
def test_map_leaves_with_multiple_dicts(self):
417+
# Define the input dictionaries
418+
dict1 = {"a": 1, "b": {"c": 2, "d": 3}}
419+
dict2 = {"a": 4, "b": {"c": 5, "d": 6}}
420+
421+
# Expected output after applying the function
422+
expected_output = {"a": 5, "b": {"c": 7, "d": 9}}
423+
424+
# Call the map_leaves function and assert the result
425+
result = map_leaves(operator.add, dict1, dict2)
426+
assert result == expected_output
427+
428+
# Dictionary with mixed types (mappings and non-mappings) at same level
429+
def test_map_leaves_with_mixed_types(self):
430+
input_dict = {
431+
"a": {"b": 1, "c": {"d": 2}},
432+
"e": 3,
433+
"f": {"g": 4, "h": {"i": 5}},
434+
}
435+
expected_output = {
436+
"a": {"b": 2, "c": {"d": 4}},
437+
"e": 6,
438+
"f": {"g": 8, "h": {"i": 10}},
439+
}
440+
441+
result = map_leaves(lambda x: x * 2, input_dict)
442+
assert result == expected_output
443+
444+
# Dictionaries with different structures/missing keys
445+
def test_map_leaves_with_different_structures(self):
446+
dict1 = {"a": 1, "b": {"c": 2}}
447+
dict2 = {"a": 3, "b": {"d": 4}}
448+
449+
with pytest.raises(KeyError):
450+
map_leaves(operator.add, dict1, dict2)
451+
452+
# Deep recursion with many nested levels
453+
def test_map_leaves_deep_recursion(self):
454+
# Create a deeply nested dictionary
455+
depth = 100
456+
nested_dict = current_level = {}
457+
for i in range(depth):
458+
current_level[f"level_{i}"] = {}
459+
current_level = current_level[f"level_{i}"]
460+
current_level["value"] = 1
461+
462+
# Define a simple function to apply
463+
def increment(x):
464+
return x + 1
465+
466+
# Apply map_leaves with deep recursion
467+
result = map_leaves(increment, nested_dict)
468+
469+
# Verify the result
470+
current_level = result
471+
for i in range(depth):
472+
current_level = current_level[f"level_{i}"]
473+
assert current_level["value"] == 2
474+
475+
# Non-commutative operations with multiple dictionaries
476+
def test_non_commutative_operations(self):
477+
dict1 = {"a": 1, "b": 2}
478+
dict2 = {"a": 3, "b": 4}
479+
480+
expected_output = {"a": -2, "b": -2}
481+
result = map_leaves(operator.sub, dict1, dict2)
482+
483+
assert result == expected_output
484+
485+
# Function returning different type than input values
486+
def test_map_leaves_with_type_conversion(self):
487+
# Define a function that changes the type of the input value
488+
def to_string(x):
489+
return str(x)
490+
491+
# Input nested dictionary with integer values
492+
input_dict = {"a": 1, "b": {"c": 2, "d": 3}}
493+
494+
# Expected output where all integer values are converted to strings
495+
expected_output = {"a": "1", "b": {"c": "2", "d": "3"}}
496+
497+
# Apply map_leaves with the to_string function
498+
result = map_leaves(to_string, input_dict)
499+
500+
# Assert that the result matches the expected output
501+
assert result == expected_output

0 commit comments

Comments
 (0)