Skip to content

ensure threads are stopped in Progress().track #2934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [13.3.5] -

### Fixed

- Stop threads in `Progress().track` if used in generators and an exception is raised https://github.com/Textualize/rich/pull/2934

## [13.3.4] - 2023-04-12

### Fixed
Expand Down
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ The following people have contributed to the development of Rich:
- [Ke Sun](https://github.com/ksun212)
- [Qiming Xu](https://github.com/xqm32)
- [James Addison](https://github.com/jayaddison)
- [Pierro](https://github.com/xpierroz)
- [Pierro](https://github.com/xpierroz)
- [Nicolas Ganz](https://github.com/ThunderKey)
28 changes: 23 additions & 5 deletions rich/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NewType,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
Expand Down Expand Up @@ -86,6 +87,11 @@ def run(self) -> None:

self.progress.update(self.task_id, completed=self.completed, refresh=True)

def stop(self) -> None:
"""Stop this progress update thread."""
self.done.set()
self.join()

def __enter__(self) -> "_TrackThread":
self.start()
return self
Expand All @@ -96,8 +102,7 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.done.set()
self.join()
self.stop()


def track(
Expand Down Expand Up @@ -1098,6 +1103,7 @@ def __init__(
self.get_time = get_time or self.console.get_time
self.print = self.console.print
self.log = self.console.log
self._track_threads: Set[_TrackThread] = set()

@classmethod
def get_default_columns(cls) -> Tuple[ProgressColumn, ...]:
Expand Down Expand Up @@ -1161,6 +1167,9 @@ def start(self) -> None:

def stop(self) -> None:
"""Stop the progress display."""
for track_thread in self._track_threads:
track_thread.stop()
self._track_threads.clear()
self.live.stop()
if not self.console.is_interactive:
self.console.print()
Expand Down Expand Up @@ -1207,9 +1216,18 @@ def track(

if self.live.auto_refresh:
with _TrackThread(self, task_id, update_period) as track_thread:
for value in sequence:
yield value
track_thread.completed += 1
try:
self._track_threads.add(track_thread)
for value in sequence:
yield value
track_thread.completed += 1
finally:
try:
self._track_threads.remove(track_thread)
except KeyError:
# happens if the Progress was already stopped
# for example if there was an exception in a generator
pass
else:
advance = self.advance
refresh = self.refresh
Expand Down
31 changes: 31 additions & 0 deletions tests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,37 @@ def test_progress_track() -> None:
assert result == expected


def test_progress_track_with_error_in_generator() -> None:
console = Console(
file=io.StringIO(),
force_terminal=True,
width=60,
color_system="truecolor",
legacy_windows=False,
_environ={},
)
progress = Progress(console=console, get_time=MockClock(auto=True))
expected_error = ValueError("Some Error")

def raise_value() -> None:
raise expected_error

with pytest.raises(ValueError) as exc_info, progress:
tuple(raise_value() for _ in progress.track(range(10)))

assert exc_info.value == expected_error
result = console.file.getvalue()
print(repr(result))
expected = "\x1b[?25l\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\n\x1b[?25h"

print(expected)
print(repr(expected))
print(result)
print(repr(result))

assert result == expected


def test_columns() -> None:

console = Console(
Expand Down