Skip to content

Commit f3c255e

Browse files
committed
add default parent_ref
1 parent 5d4324b commit f3c255e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

flax/linen/module.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,9 @@ class ParentDescriptor:
536536
logic applies in subclasses even after various dataclass transforms.
537537
"""
538538
def __get__(self, obj, objtype=None):
539+
# check if obj is None, happens during %autoreload
540+
if obj is None:
541+
return None
539542
parent = object.__getattribute__(obj, "_parent_ref")
540543
return parent() if isinstance(parent, weakref.ReferenceType) else parent
541544

@@ -610,8 +613,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
610613
cls._state = _uninitialized_module_internal_state
611614
cls.scope: Optional[Scope] = None
612615
# Handles weak referencing of parent Modules to prevent reference cycles.
613-
cls.parent = ParentDescriptor()
614616
cls._parent_ref = None
617+
cls.parent = ParentDescriptor()
615618

616619
@classmethod
617620
def _customized_dataclass_transform(cls):

0 commit comments

Comments
 (0)