Skip to content

make utility functions work on FrozenDicts and regular dicts #2919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Frozen Dictionary."""

import collections
from typing import Any, TypeVar, Mapping, Dict, Tuple
from typing import Any, TypeVar, Mapping, Dict, Tuple, Union

from flax import serialization
import jax
Expand Down Expand Up @@ -189,7 +189,7 @@ def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
return FrozenDict(xs)


def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]:
"""Unfreeze a FrozenDict.

Makes a mutable copy of a `FrozenDict` mutable by transforming
Expand All @@ -205,7 +205,7 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
# the dict branch would also work here but
# it is much less performant because jax.tree_util.tree_map
# uses an optimized C implementation.
return jax.tree_util.tree_map(lambda y: y, x._dict) # pylint: disable=protected-access
return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore
elif isinstance(x, dict):
ys = {}
for key, value in x.items():
Expand All @@ -215,9 +215,10 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
return x


def copy(x: Dict[str, Any], add_or_replace: Dict[str, Any]) -> Dict[str, Any]:
def copy(x: Union[FrozenDict, Dict[str, Any]], add_or_replace: Union[FrozenDict, Dict[str, Any]]) -> Union[FrozenDict, Dict[str, Any]]:
"""Create a new dict with additional and/or replaced entries. This is a utility
function for regular dicts that mimics the behavior of `FrozenDict.copy`.
function that can act on either a FrozenDict or regular dict and mimics the
behavior of `FrozenDict.copy`.

Example::

Expand All @@ -230,12 +231,16 @@ def copy(x: Dict[str, Any], add_or_replace: Dict[str, Any]) -> Dict[str, Any]:
A new dict with the additional and/or replaced entries.
"""

new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict.update(add_or_replace)
return new_dict
if isinstance(x, FrozenDict):
return x.copy(add_or_replace)
elif isinstance(x, dict):
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict.update(add_or_replace)
return new_dict
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')


def pop(x: Dict[str, Any], key: str) -> Tuple[Dict[str, Any], Any]:
def pop(x: Union[FrozenDict, Dict[str, Any]], key: str) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]:
"""Create a new dict where one entry is removed. This is a utility
function for regular dicts that mimics the behavior of `FrozenDict.pop`.

Expand All @@ -250,9 +255,13 @@ def pop(x: Dict[str, Any], key: str) -> Tuple[Dict[str, Any], Any]:
A pair with the new dict and the removed value.
"""

new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
value = new_dict.pop(key)
return new_dict, value
if isinstance(x, FrozenDict):
return x.pop(key)
elif isinstance(x, dict):
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
value = new_dict.pop(key)
return new_dict, value
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')


def _frozen_dict_state_dict(xs):
Expand Down
46 changes: 34 additions & 12 deletions tests/core/core_frozen_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import jax


from absl.testing import absltest
from absl.testing import absltest, parameterized


class FrozenDictTest(absltest.TestCase):
class FrozenDictTest(parameterized.TestCase):

def test_frozen_dict_copies(self):
xs = {'a': 1, 'b': {'c': 2}}
Expand Down Expand Up @@ -84,16 +84,38 @@ def test_frozen_dict_copy_reserved_name(self):
result = FrozenDict({'a': 1}).copy({'cls': 2})
self.assertEqual(result, {'a': 1, 'cls': 2})

def test_utility_pop(self):
x = {'a': 1, 'b': {'c': 2}}
new_x, value = pop(x, 'b')
self.assertEqual(new_x, {'a': 1})
self.assertEqual(value, {'c': 2})

def test_utility_copy(self):
x = {'a': 1, 'b': {'c': 2}}
new_x = copy(x, add_or_replace={'b': {'c': -1, 'd': 3}})
self.assertEqual(new_x, {'a': 1, 'b': {'c': -1, 'd': 3}})
@parameterized.parameters(
{
'x': {'a': 1, 'b': {'c': 2}},
'key': 'b',
'actual_new_x': {'a': 1},
'actual_value': {'c': 2}
}, {
'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
'key': 'b',
'actual_new_x': FrozenDict({'a': 1}),
'actual_value': FrozenDict({'c': 2})
},
)
def test_utility_pop(self, x, key, actual_new_x, actual_value):
new_x, value = pop(x, key)
self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x)))
self.assertTrue(value == actual_value and isinstance(value, type(actual_value)))

@parameterized.parameters(
{
'x': {'a': 1, 'b': {'c': 2}},
'add_or_replace': {'b': {'c': -1, 'd': 3}},
'actual_new_x': {'a': 1, 'b': {'c': -1, 'd': 3}},
}, {
'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
'add_or_replace': FrozenDict({'b': {'c': -1, 'd': 3}}),
'actual_new_x': FrozenDict({'a': 1, 'b': {'c': -1, 'd': 3}}),
},
)
def test_utility_copy(self, x, add_or_replace, actual_new_x):
new_x = copy(x, add_or_replace=add_or_replace)
self.assertTrue(new_x == actual_new_x and isinstance(new_x, type(actual_new_x)))


if __name__ == '__main__':
Expand Down