|
| 1 | +import time |
| 2 | +import os |
| 3 | +import datetime |
| 4 | +import pytz |
| 5 | +import threading |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +from typing import Iterable |
| 9 | + |
| 10 | +import aimrocks.errors |
| 11 | + |
| 12 | +from aim import Repo |
| 13 | +from aim.sdk.run_status_watcher import Event |
| 14 | + |
| 15 | + |
class RunStatusManager:
    """Background monitor that sets an end time on runs that stopped reporting.

    A run is considered "in progress" while an entry named after its hash
    exists in ``<repo>/meta/progress``. On every scan each such run is checked
    against the heartbeat files in ``<repo>/check_ins``; runs found stalled get
    ``end_time`` written to their meta tree and their progress entry removed.
    """

    # Extra seconds tolerated on top of a heartbeat's own ``next_event_in``
    # before the run is declared stalled.
    INDEXING_GRACE_PERIOD = 10

    def __init__(self, repo: 'Repo', scan_interval: int = 60):
        """Prepare the monitor; no thread is started until :meth:`start`.

        Args:
            repo: Repository whose runs are monitored.
            scan_interval: Seconds to wait between two consecutive scans.
        """
        self.repo = repo
        self.scan_interval = scan_interval

        self.progress_dir = Path(self.repo.path) / 'meta' / 'progress'
        self.progress_dir.mkdir(parents=True, exist_ok=True)

        self.heartbeat_dir = Path(self.repo.path) / 'check_ins'
        # run hash -> last heartbeat Event observed for that run
        self.run_heartbeat_cache = {}

        self._stop_event = threading.Event()
        self._monitor_thread = None
        # Runs whose meta tree could not be opened; excluded from later scans.
        self._corrupted_runs = set()

    def start(self):
        """Start the monitoring thread unless one is already alive."""
        if not self._monitor_thread or not self._monitor_thread.is_alive():
            self._stop_event.clear()
            self._monitor_thread = threading.Thread(target=self._run_forever, daemon=True)
            self._monitor_thread.start()

    def stop(self):
        """Signal the monitoring thread to exit and wait for it to finish."""
        self._stop_event.set()
        if self._monitor_thread:
            self._monitor_thread.join()

    def _run_forever(self):
        """Scan for stalled runs every ``scan_interval`` seconds until stopped."""
        while not self._stop_event.is_set():
            self.check_and_terminate_stalled_runs()
            # Wait on the stop event rather than sleeping, so that stop() is
            # honoured immediately instead of after a full scan interval.
            self._stop_event.wait(self.scan_interval)

    def _runs_with_progress(self) -> Iterable[str]:
        """Return hashes of in-progress runs, least recently modified first.

        Runs already recorded as corrupted are excluded. Entries that vanish
        between the directory listing and the mtime lookup are skipped.
        """
        timestamped = []
        for run_hash in os.listdir(self.progress_dir):
            if run_hash in self._corrupted_runs:
                continue
            try:
                modified_at = os.path.getmtime(os.path.join(self.progress_dir, run_hash))
            except FileNotFoundError:
                # The entry was removed after listdir() returned; there is
                # nothing left to track for it.
                continue
            timestamped.append((modified_at, run_hash))

        # Sort on mtime only, keeping listing order for equal timestamps.
        timestamped.sort(key=lambda entry: entry[0])
        return [run_hash for _, run_hash in timestamped]

    def check_and_terminate_stalled_runs(self):
        """Run one scan: mark every stalled in-progress run as terminated."""
        for run_hash in self._runs_with_progress():
            if self._is_run_stalled(run_hash):
                self._mark_run_as_terminated(run_hash)

    def _is_run_stalled(self, run_hash: str) -> bool:
        """Decide whether ``run_hash`` has stopped sending heartbeats.

        A run with no heartbeat file at all is reported as stalled. Otherwise
        the newest heartbeat (greatest path in sort order) is compared with
        the one cached by the previous scan.

        NOTE(review): ``Event`` comes from ``aim.sdk.run_status_watcher`` and
        is not visible here; it is assumed to parse the file name and expose
        ``idx``, ``next_event_in`` and ``detected_epoch_time`` -- confirm.
        """
        latest_path = max(self.heartbeat_dir.glob(f'{run_hash}-*-progress-*-*'), default=None)
        if latest_path is None:
            return True

        last_heartbeat = Event(latest_path.name)
        last_recorded_heartbeat = self.run_heartbeat_cache.get(run_hash)

        # First heartbeat seen for this run, or a newer one than last scan:
        # remember it and treat the run as alive.
        if last_recorded_heartbeat is None or last_heartbeat.idx > last_recorded_heartbeat.idx:
            self.run_heartbeat_cache[run_hash] = last_heartbeat
            return False

        # No new heartbeat since the last scan: stalled only once the expected
        # interval plus the grace period has elapsed.
        time_passed = time.time() - last_recorded_heartbeat.detected_epoch_time
        allowed_silence = last_recorded_heartbeat.next_event_in + RunStatusManager.INDEXING_GRACE_PERIOD
        return allowed_silence < time_passed

    def _mark_run_as_terminated(self, run_hash: str):
        """Record an end time for ``run_hash`` and remove its progress entry.

        ``end_time`` is only written when it is not already set. If the meta
        tree cannot be opened because of a RocksDB I/O or corruption error,
        the run is added to the corrupted set and skipped by later scans.
        """
        # TODO [AT]: Add run state handling once decided on terms (finished, terminated, aborted, etc.)
        try:
            meta_run_tree = self.repo.request_tree('meta', run_hash, read_only=False).subtree(
                ('meta', 'chunks', run_hash)
            )
            if meta_run_tree.get('end_time') is None:
                meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp()
            progress_path = self.progress_dir / run_hash
            progress_path.unlink(missing_ok=True)
        except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption):
            self._corrupted_runs.add(run_hash)

        # This run will not be scanned again in its current session, so drop
        # its cached heartbeat: the cache stays bounded, and a later session
        # of the same run starts from a clean state.
        self.run_heartbeat_cache.pop(run_hash, None)
0 commit comments