15
15
"""Frozen Dictionary."""
16
16
17
17
import collections
18
- from typing import Any , TypeVar , Mapping , Dict , Tuple
18
+ from typing import Any , TypeVar , Mapping , Dict , Tuple , Union
19
19
20
20
from flax import serialization
21
21
import jax
@@ -189,7 +189,7 @@ def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
189
189
return FrozenDict (xs )
190
190
191
191
192
- def unfreeze (x : FrozenDict [ Any , Any ]) -> Dict [Any , Any ]:
192
+ def unfreeze (x : Union [ FrozenDict , Dict [ str , Any ] ]) -> Dict [Any , Any ]:
193
193
"""Unfreeze a FrozenDict.
194
194
195
195
Makes a mutable copy of a `FrozenDict` mutable by transforming
@@ -205,7 +205,7 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
205
205
# the dict branch would also work here but
206
206
# it is much less performant because jax.tree_util.tree_map
207
207
# uses an optimized C implementation.
208
- return jax .tree_util .tree_map (lambda y : y , x ._dict ) # pylint: disable=protected-access
208
+ return jax .tree_util .tree_map (lambda y : y , x ._dict ) # type: ignore
209
209
elif isinstance (x , dict ):
210
210
ys = {}
211
211
for key , value in x .items ():
@@ -215,9 +215,10 @@ def unfreeze(x: FrozenDict[Any, Any]) -> Dict[Any, Any]:
215
215
return x
216
216
217
217
218
- def copy (x : Dict [str , Any ], add_or_replace : Dict [str , Any ]) -> Dict [str , Any ]:
218
+ def copy (x : Union [ FrozenDict , Dict [str , Any ]] , add_or_replace : Union [ FrozenDict , Dict [str , Any ]] ) -> Union [ FrozenDict , Dict [str , Any ] ]:
219
219
"""Create a new dict with additional and/or replaced entries. This is a utility
220
- function for regular dicts that mimics the behavior of `FrozenDict.copy`.
220
+ function that can act on either a FrozenDict or regular dict and mimics the
221
+ behavior of `FrozenDict.copy`.
221
222
222
223
Example::
223
224
@@ -230,12 +231,16 @@ def copy(x: Dict[str, Any], add_or_replace: Dict[str, Any]) -> Dict[str, Any]:
230
231
A new dict with the additional and/or replaced entries.
231
232
"""
232
233
233
- new_dict = jax .tree_map (lambda x : x , x ) # make a deep copy of dict x
234
- new_dict .update (add_or_replace )
235
- return new_dict
234
+ if isinstance (x , FrozenDict ):
235
+ return x .copy (add_or_replace )
236
+ elif isinstance (x , dict ):
237
+ new_dict = jax .tree_map (lambda x : x , x ) # make a deep copy of dict x
238
+ new_dict .update (add_or_replace )
239
+ return new_dict
240
+ raise TypeError (f'Expected FrozenDict or dict, got { type (x )} ' )
236
241
237
242
238
- def pop (x : Dict [str , Any ], key : str ) -> Tuple [Dict [str , Any ], Any ]:
243
+ def pop (x : Union [ FrozenDict , Dict [str , Any ]] , key : str ) -> Tuple [Union [ FrozenDict , Dict [str , Any ] ], Any ]:
239
244
"""Create a new dict where one entry is removed. This is a utility
240
245
function for regular dicts that mimics the behavior of `FrozenDict.pop`.
241
246
@@ -250,9 +255,13 @@ def pop(x: Dict[str, Any], key: str) -> Tuple[Dict[str, Any], Any]:
250
255
A pair with the new dict and the removed value.
251
256
"""
252
257
253
- new_dict = jax .tree_map (lambda x : x , x ) # make a deep copy of dict x
254
- value = new_dict .pop (key )
255
- return new_dict , value
258
+ if isinstance (x , FrozenDict ):
259
+ return x .pop (key )
260
+ elif isinstance (x , dict ):
261
+ new_dict = jax .tree_map (lambda x : x , x ) # make a deep copy of dict x
262
+ value = new_dict .pop (key )
263
+ return new_dict , value
264
+ raise TypeError (f'Expected FrozenDict or dict, got { type (x )} ' )
256
265
257
266
258
267
def _frozen_dict_state_dict (xs ):
0 commit comments