diff --git a/flax/core/lift.py b/flax/core/lift.py index f99c9cb64..649f1da29 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -1491,6 +1491,8 @@ def rematted(variable_groups, rng_groups, *args, **kwargs): def _hashable_filter(x): """Hashable version of CollectionFilter.""" + if isinstance(x, str): + return (x,) if isinstance(x, Iterable): return tuple(x) # convert un-hashable list & sets to tuple if isinstance(x, DenyList):