diff --git a/CHANGELOG.md b/CHANGELOG.md index 21386b2a1..6faf8b5b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [13.3.5] - + +### Fixed + +- Stop threads in `Progress().track` if used in generators and an exception is raised https://github.com/Textualize/rich/pull/2934 + ## [13.3.4] - 2023-04-12 ### Fixed diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index fe88b7893..f93b0b152 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -68,4 +68,5 @@ The following people have contributed to the development of Rich: - [Ke Sun](https://github.com/ksun212) - [Qiming Xu](https://github.com/xqm32) - [James Addison](https://github.com/jayaddison) -- [Pierro](https://github.com/xpierroz) \ No newline at end of file +- [Pierro](https://github.com/xpierroz) +- [Nicolas Ganz](https://github.com/ThunderKey) diff --git a/rich/progress.py b/rich/progress.py index 43c47eb98..b70c759a7 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -27,6 +27,7 @@ NewType, Optional, Sequence, + Set, TextIO, Tuple, Type, @@ -86,6 +87,11 @@ def run(self) -> None: self.progress.update(self.task_id, completed=self.completed, refresh=True) + def stop(self) -> None: + """Stop this progress update thread.""" + self.done.set() + self.join() + def __enter__(self) -> "_TrackThread": self.start() return self @@ -96,8 +102,7 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - self.done.set() - self.join() + self.stop() def track( @@ -1098,6 +1103,7 @@ def __init__( self.get_time = get_time or self.console.get_time self.print = self.console.print self.log = self.console.log + self._track_threads: Set[_TrackThread] = set() @classmethod def get_default_columns(cls) -> Tuple[ProgressColumn, ...]: @@ -1161,6 +1167,9 @@ def start(self) -> None: def stop(self) -> None: """Stop the progress display.""" + for track_thread in self._track_threads: + track_thread.stop() + self._track_threads.clear() self.live.stop() if not self.console.is_interactive: self.console.print() @@ -1207,9 +1216,18 @@ def track( if self.live.auto_refresh: with _TrackThread(self, task_id, update_period) as track_thread: - for value in sequence: - yield value - track_thread.completed += 1 + try: + self._track_threads.add(track_thread) + for value in sequence: + yield value + track_thread.completed += 1 + finally: + try: + self._track_threads.remove(track_thread) + except KeyError: + # happens if the Progress was already stopped + # for example if there was an exception in a generator + pass else: advance = self.advance refresh = self.refresh diff --git a/tests/test_progress.py b/tests/test_progress.py index 6a336d347..6e73fa0f8 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -338,6 +338,37 @@ def test_progress_track() -> None: assert result == expected +def test_progress_track_with_error_in_generator() -> None: + console = Console( + file=io.StringIO(), + force_terminal=True, + width=60, + color_system="truecolor", + legacy_windows=False, + _environ={}, + ) + progress = Progress(console=console, get_time=MockClock(auto=True)) + expected_error = ValueError("Some Error") + + def raise_value() -> None: + raise expected_error + + with pytest.raises(ValueError) as exc_info, progress: + tuple(raise_value() for _ in progress.track(range(10))) + + assert exc_info.value == expected_error + result = console.file.getvalue() + print(repr(result)) + expected = "\x1b[?25l\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\n\x1b[?25h" + + print(expected) + print(repr(expected)) + print(result) + print(repr(result)) + + assert result == expected + + def test_columns() -> None: console = Console(