Skip to content

Commit 779a122

Browse files
author
Flax Authors
committed
Merge pull request #2919 from chiamp:dict_utility_fns
PiperOrigin-RevId: 513830417
2 parents be3c846 + 6597e4c commit 779a122

File tree

2 files changed

+55
-24
lines changed

2 files changed

+55
-24
lines changed

flax/core/frozen_dict.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Frozen Dictionary."""
1616

1717
import collections
18-
from typing import Any, TypeVar, Mapping, Dict, Tuple
18+
from typing import Any, TypeVar, Mapping, Dict, Tuple, Union
1919

2020
from flax import serialization
2121
import jax
@@ -189,7 +189,7 @@ def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
189189
return FrozenDict(xs)
190190

191191

192-
def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
192+
def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]:
193193
"""Unfreeze a FrozenDict.
194194
195195
Makes a mutable copy of a `FrozenDict` mutable by transforming
@@ -205,7 +205,7 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
205205
# the dict branch would also work here but
206206
# it is much less performant because jax.tree_util.tree_map
207207
# uses an optimized C implementation.
208-
return jax.tree_util.tree_map(lambda y: y, x._dict) # pylint: disable=protected-access
208+
return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore
209209
elif isinstance(x, dict):
210210
ys = {}
211211
for key, value in x.items():
@@ -215,9 +215,10 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
215215
return x
216216

217217

218-
def copy(x: Dict[str, Any], add_or_replace: Dict[str, Any]) -> Dict[str, Any]:
218+
def copy(x: Union[FrozenDict, Dict[str, Any]], add_or_replace: Union[FrozenDict, Dict[str, Any]]) -> Union[FrozenDict, Dict[str, Any]]:
219219
"""Create a new dict with additional and/or replaced entries. This is a utility
220-
function for regular dicts that mimics the behavior of `FrozenDict.copy`.
220+
function that can act on either a FrozenDict or regular dict and mimics the
221+
behavior of `FrozenDict.copy`.
221222
222223
Example::
223224
@@ -230,12 +231,16 @@ def copy(x: Dict[str, Any], add_or_replace: Dict[str, Any]) -> Dict[str, Any]:
230231
A new dict with the additional and/or replaced entries.
231232
"""
232233

233-
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
234-
new_dict.update(add_or_replace)
235-
return new_dict
234+
if isinstance(x, FrozenDict):
235+
return x.copy(add_or_replace)
236+
elif isinstance(x, dict):
237+
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
238+
new_dict.update(add_or_replace)
239+
return new_dict
240+
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
236241

237242

238-
def pop(x: Dict[str, Any], key: str) -> Tuple[Dict[str, Any], Any]:
243+
def pop(x: Union[FrozenDict, Dict[str, Any]], key: str) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]:
239244
"""Create a new dict where one entry is removed. This is a utility
240245
function for regular dicts that mimics the behavior of `FrozenDict.pop`.
241246
@@ -250,9 +255,13 @@ def pop(x: Dict[str, Any], key: str) -> Tuple[Dict[str, Any], Any]:
250255
A pair with the new dict and the removed value.
251256
"""
252257

253-
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
254-
value = new_dict.pop(key)
255-
return new_dict, value
258+
if isinstance(x, FrozenDict):
259+
return x.pop(key)
260+
elif isinstance(x, dict):
261+
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
262+
value = new_dict.pop(key)
263+
return new_dict, value
264+
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
256265

257266

258267
def _frozen_dict_state_dict(xs):

tests/core/core_frozen_dict_test.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import jax
1818

1919

20-
from absl.testing import absltest
20+
from absl.testing import absltest, parameterized
2121

2222

23-
class FrozenDictTest(absltest.TestCase):
23+
class FrozenDictTest(parameterized.TestCase):
2424

2525
def test_frozen_dict_copies(self):
2626
xs = {'a': 1, 'b': {'c': 2}}
@@ -84,16 +84,38 @@ def test_frozen_dict_copy_reserved_name(self):
8484
result = FrozenDict({'a': 1}).copy({'cls': 2})
8585
self.assertEqual(result, {'a': 1, 'cls': 2})
8686

87-
def test_utility_pop(self):
88-
x = {'a': 1, 'b': {'c': 2}}
89-
new_x, value = pop(x, 'b')
90-
self.assertEqual(new_x, {'a': 1})
91-
self.assertEqual(value, {'c': 2})
92-
93-
def test_utility_copy(self):
94-
x = {'a': 1, 'b': {'c': 2}}
95-
new_x = copy(x, add_or_replace={'b': {'c': -1, 'd': 3}})
96-
self.assertEqual(new_x, {'a': 1, 'b': {'c': -1, 'd': 3}})
87+
@parameterized.parameters(
88+
{
89+
'x': {'a': 1, 'b': {'c': 2}},
90+
'key': 'b',
91+
'actual_new_x': {'a': 1},
92+
'actual_value': {'c': 2}
93+
}, {
94+
'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
95+
'key': 'b',
96+
'actual_new_x': FrozenDict({'a': 1}),
97+
'actual_value': FrozenDict({'c': 2})
98+
},
99+
)
100+
def test_utility_pop(self, x, key, actual_new_x, actual_value):
101+
new_x, value = pop(x, key)
102+
self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x)))
103+
self.assertTrue(value == actual_value and isinstance(value, type(actual_value)))
104+
105+
@parameterized.parameters(
106+
{
107+
'x': {'a': 1, 'b': {'c': 2}},
108+
'add_or_replace': {'b': {'c': -1, 'd': 3}},
109+
'actual_new_x': {'a': 1, 'b': {'c': -1, 'd': 3}},
110+
}, {
111+
'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
112+
'add_or_replace': FrozenDict({'b': {'c': -1, 'd': 3}}),
113+
'actual_new_x': FrozenDict({'a': 1, 'b': {'c': -1, 'd': 3}}),
114+
},
115+
)
116+
def test_utility_copy(self, x, add_or_replace, actual_new_x):
117+
new_x = copy(x, add_or_replace=add_or_replace)
118+
self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x)))
97119

98120

99121
if __name__ == '__main__':

0 commit comments

Comments
 (0)