@@ -523,6 +523,27 @@ def reimport(self, other: '_ModuleInternalState') -> None:
523
523
_ParentType = Union [Type ['Module' ], Type [Scope ], Type [_Sentinel ], None ]
524
524
525
525
526
+ class ParentDescriptor :
527
+ """Wraps parent module references in weak refs.
528
+
529
+ This prevents reference cycles from forming via parent links which can lead
530
+ to accidental OOMs in eager mode due to slow garbage collection as well as
531
+ spurious tracer leaks during jit compilation.
532
+
533
+ Note: "descriptors" are the underlying python mechanism for implementing
534
+ dynamic @property decorators. We need to use a raw descriptor instead of the
535
+ more common decorator in order to force that the appropriate getter/setter
536
+ logic applies in subclasses even after various dataclass transforms.
537
+ """
538
+ def __get__ (self , obj , objtype = None ):
539
+ parent = object .__getattribute__ (obj , "_parent_ref" )
540
+ return parent () if isinstance (parent , weakref .ReferenceType ) else parent
541
+
542
+ def __set__ (self , obj , value ):
543
+ maybe_weak = weakref .ref (value ) if isinstance (value , Module ) else value
544
+ object .__setattr__ (obj , "_parent_ref" , maybe_weak )
545
+
546
+
526
547
# Base Module definition.
527
548
# -----------------------------------------------------------------------------
528
549
@@ -588,6 +609,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
588
609
# Set empty class defaults.
589
610
cls ._state = _uninitialized_module_internal_state
590
611
cls .scope : Optional [Scope ] = None
612
+ # Handles weak referencing of parent Modules to prevent reference cycles.
613
+ cls .parent = ParentDescriptor ()
591
614
592
615
@classmethod
593
616
def _customized_dataclass_transform (cls ):
0 commit comments