diff --git a/CHANGELOG.md b/CHANGELOG.md index cb1476956..490857b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed `typing_extensions` from runtime dependencies https://github.com/Textualize/rich/pull/3763 - Live objects (including Progress) may now be nested https://github.com/Textualize/rich/pull/3768 +### Fixed + +- Fixed extraction of recursive exceptions https://github.com/Textualize/rich/pull/3772 + ## [14.0.0] - 2025-03-30 ### Added diff --git a/rich/traceback.py b/rich/traceback.py index b2cc63040..25d399a7a 100644 --- a/rich/traceback.py +++ b/rich/traceback.py @@ -14,6 +14,7 @@ List, Optional, Sequence, + Set, Tuple, Type, Union, @@ -418,6 +419,7 @@ def extract( locals_max_string: int = LOCALS_MAX_STRING, locals_hide_dunder: bool = True, locals_hide_sunder: bool = False, + _visited_exceptions: Optional[Set[BaseException]] = None, ) -> Trace: """Extract traceback information. @@ -443,6 +445,10 @@ def extract( notes: List[str] = getattr(exc_value, "__notes__", None) or [] + grouped_exceptions: Set[BaseException] = ( + set() if _visited_exceptions is None else _visited_exceptions + ) + def safe_str(_object: Any) -> str: """Don't allow exceptions from __str__ to propagate.""" try: @@ -462,6 +468,9 @@ def safe_str(_object: Any) -> str: if isinstance(exc_value, (BaseExceptionGroup, ExceptionGroup)): stack.is_group = True for exception in exc_value.exceptions: + if exception in grouped_exceptions: + continue + grouped_exceptions.add(exception) stack.exceptions.append( Traceback.extract( type(exception), @@ -471,6 +480,7 @@ def safe_str(_object: Any) -> str: locals_max_length=locals_max_length, locals_hide_dunder=locals_hide_dunder, locals_hide_sunder=locals_hide_sunder, + _visited_exceptions=grouped_exceptions, ) ) @@ -561,23 +571,26 @@ def get_locals( if frame_summary.f_locals.get("_rich_traceback_guard", False): del stack.frames[:] - cause = getattr(exc_value, "__cause__", None) - if cause: - exc_type = cause.__class__ - exc_value = cause - # __traceback__ can be None, e.g. for exceptions raised by the - # 'multiprocessing' module - traceback = cause.__traceback__ - is_cause = True - continue + if not grouped_exceptions: + cause = getattr(exc_value, "__cause__", None) + if cause is not None and cause is not exc_value: + exc_type = cause.__class__ + exc_value = cause + # __traceback__ can be None, e.g. for exceptions raised by the + # 'multiprocessing' module + traceback = cause.__traceback__ + is_cause = True + continue - cause = exc_value.__context__ - if cause and not getattr(exc_value, "__suppress_context__", False): - exc_type = cause.__class__ - exc_value = cause - traceback = cause.__traceback__ - is_cause = False - continue + cause = exc_value.__context__ + if cause is not None and not getattr( + exc_value, "__suppress_context__", False + ): + exc_type = cause.__class__ + exc_value = cause + traceback = cause.__traceback__ + is_cause = False + continue # No cover, code is reached but coverage doesn't recognize it. break # pragma: no cover diff --git a/tests/test_traceback.py b/tests/test_traceback.py index bc9bc91e9..bcae6920b 100644 --- a/tests/test_traceback.py +++ b/tests/test_traceback.py @@ -373,3 +373,27 @@ def test_notes() -> None: traceback = Traceback() assert traceback.trace.stacks[0].notes == ["Hello", "World"] + + +def test_recursive_exception() -> None: + """Regression test for https://github.com/Textualize/rich/issues/3708 + + Test this doesn't create an infinite loop. + + """ + console = Console() + + def foo() -> None: + try: + raise RuntimeError("Hello") + except Exception as e: + raise e from e + + def bar() -> None: + try: + foo() + except Exception as e: + assert e is e.__cause__ + console.print_exception(show_locals=True) + + bar()