Skip to content

Commit 53e766a

Browse files
committed
ensure threads are stopped in Progress().track
If Progress().track is used with a generator and an exception is raised within this generator the _TrackThread is not stopped and the program does not finish. This change ensures that the _TrackThreads are stopped correctly
1 parent 86418df commit 53e766a

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [13.3.5] -
9+
10+
### Fixed
11+
12+
- Stop threads in `Progress().track` if used in generators and an exception is raised https://github.com/Textualize/rich/pull/2934
13+
814
## [13.3.4] - 2023-04-12
915

1016
### Fixed

CONTRIBUTORS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,5 @@ The following people have contributed to the development of Rich:
6868
- [Ke Sun](https://github.com/ksun212)
6969
- [Qiming Xu](https://github.com/xqm32)
7070
- [James Addison](https://github.com/jayaddison)
71-
- [Pierro](https://github.com/xpierroz)
71+
- [Pierro](https://github.com/xpierroz)
72+
- [Nicolas Ganz](https://github.com/ThunderKey)

rich/progress.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
NewType,
2828
Optional,
2929
Sequence,
30+
Set,
3031
TextIO,
3132
Tuple,
3233
Type,
@@ -86,6 +87,11 @@ def run(self) -> None:
8687

8788
self.progress.update(self.task_id, completed=self.completed, refresh=True)
8889

90+
def stop(self) -> None:
91+
"""Stop this progress update thread."""
92+
self.done.set()
93+
self.join()
94+
8995
def __enter__(self) -> "_TrackThread":
9096
self.start()
9197
return self
@@ -96,8 +102,7 @@ def __exit__(
96102
exc_val: Optional[BaseException],
97103
exc_tb: Optional[TracebackType],
98104
) -> None:
99-
self.done.set()
100-
self.join()
105+
self.stop()
101106

102107

103108
def track(
@@ -1098,6 +1103,7 @@ def __init__(
10981103
self.get_time = get_time or self.console.get_time
10991104
self.print = self.console.print
11001105
self.log = self.console.log
1106+
self._track_threads: Set[_TrackThread] = set()
11011107

11021108
@classmethod
11031109
def get_default_columns(cls) -> Tuple[ProgressColumn, ...]:
@@ -1161,6 +1167,9 @@ def start(self) -> None:
11611167

11621168
def stop(self) -> None:
11631169
"""Stop the progress display."""
1170+
for track_thread in self._track_threads:
1171+
track_thread.stop()
1172+
self._track_threads.clear()
11641173
self.live.stop()
11651174
if not self.console.is_interactive:
11661175
self.console.print()
@@ -1207,9 +1216,18 @@ def track(
12071216

12081217
if self.live.auto_refresh:
12091218
with _TrackThread(self, task_id, update_period) as track_thread:
1210-
for value in sequence:
1211-
yield value
1212-
track_thread.completed += 1
1219+
try:
1220+
self._track_threads.add(track_thread)
1221+
for value in sequence:
1222+
yield value
1223+
track_thread.completed += 1
1224+
finally:
1225+
try:
1226+
self._track_threads.remove(track_thread)
1227+
except KeyError:
1228+
# happens if the Progress was already stopped
1229+
# for example if there was an exception in a generator
1230+
pass
12131231
else:
12141232
advance = self.advance
12151233
refresh = self.refresh

tests/test_progress.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,37 @@ def test_progress_track() -> None:
338338
assert result == expected
339339

340340

341+
def test_progress_track_with_error_in_generator() -> None:
342+
console = Console(
343+
file=io.StringIO(),
344+
force_terminal=True,
345+
width=60,
346+
color_system="truecolor",
347+
legacy_windows=False,
348+
_environ={},
349+
)
350+
progress = Progress(console=console, get_time=MockClock(auto=True))
351+
expected_error = ValueError("Some Error")
352+
353+
def raise_value() -> None:
354+
raise expected_error
355+
356+
with pytest.raises(ValueError) as exc_info, progress:
357+
tuple(raise_value() for _ in progress.track(range(10)))
358+
359+
assert exc_info.value == expected_error
360+
result = console.file.getvalue()
361+
print(repr(result))
362+
expected = "\x1b[?25l\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\r\x1b[2KWorking... \x1b[38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[35m 0%\x1b[0m \x1b[36m-:--:--\x1b[0m\n\x1b[?25h"
363+
364+
print(expected)
365+
print(repr(expected))
366+
print(result)
367+
print(repr(result))
368+
369+
assert result == expected
370+
371+
341372
def test_columns() -> None:
342373

343374
console = Console(

0 commit comments

Comments
 (0)