# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import logging
import time
from typing import Any, Generator, Iterable, List, Mapping, Optional, Set

from airbyte_cdk import StreamSlice
from airbyte_cdk.logger import lazy_log
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

LOGGER = logging.getLogger("airbyte")

class AsyncPartition:
    """
    This bucket of api_jobs is a bit useless for this iteration, but it should become interesting once we are able to split jobs.
    """

    _MAX_NUMBER_OF_ATTEMPTS = 3

    def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
        self._attempts_per_job = {job: 0 for job in jobs}
        self._stream_slice = stream_slice

    def has_reached_max_attempt(self) -> bool:
        return any(attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS for attempt_count in self._attempts_per_job.values())

    def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None:
        current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
        if current_attempt_count is None:
            raise ValueError("Could not find job to replace")
        elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
            raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")

        new_attempt_count = current_attempt_count + 1
        for job in new_jobs:
            self._attempts_per_job[job] = new_attempt_count
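
    # For illustration: every job that replaces a failed one inherits the failed job's
    # attempt count plus one, so a slice whose job keeps failing is given up on once
    # `has_reached_max_attempt` sees a count of `_MAX_NUMBER_OF_ATTEMPTS`. A hedged
    # sketch of the bookkeeping, assuming `job_a`/`job_b` are AsyncJob instances:
    #
    #   partition = AsyncPartition([job_a], stream_slice)   # attempt counts: {job_a: 0}
    #   partition.replace_job(job_a, [job_b])               # attempt counts: {job_b: 1}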

    def should_split(self, job: AsyncJob) -> bool:
        """
        Not used right now, but once we support job splitting, we should split based on the number of attempts.
        """
        return False

    @property
    def jobs(self) -> Iterable[AsyncJob]:
        return self._attempts_per_job.keys()

    @property
    def stream_slice(self) -> StreamSlice:
        return self._stream_slice

    @property
    def status(self) -> AsyncJobStatus:
        """
        If all jobs are completed, the partition is COMPLETED. Otherwise, statuses are prioritized as FAILED, then TIMED_OUT, then RUNNING.
        """
        statuses = {job.status() for job in self.jobs}
        if statuses == {AsyncJobStatus.COMPLETED}:
            return AsyncJobStatus.COMPLETED
        elif AsyncJobStatus.FAILED in statuses:
            return AsyncJobStatus.FAILED
        elif AsyncJobStatus.TIMED_OUT in statuses:
            return AsyncJobStatus.TIMED_OUT
        else:
            return AsyncJobStatus.RUNNING

    def __repr__(self) -> str:
        return f"AsyncPartition(stream_slice={self._stream_slice}, attempts_per_job={self._attempts_per_job})"
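
    # A minimal sketch of how the priority above plays out, assuming hypothetical stub
    # jobs (`completed_job`, `failed_job`, `running_job`) whose status() is fixed:
    #
    #   partition = AsyncPartition([completed_job, failed_job], stream_slice)
    #   assert partition.status == AsyncJobStatus.FAILED   # one failure outweighs completions
    #
    #   partition = AsyncPartition([completed_job, running_job], stream_slice)
    #   assert partition.status == AsyncJobStatus.RUNNING  # partial completion is still RUNNING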


class AsyncJobOrchestrator:
    _WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS = 5

    def __init__(
        self,
        job_repository: AsyncJobRepository,
        slices: Iterable[StreamSlice],
        number_of_retries: Optional[int] = None,
    ) -> None:
        self._job_repository: AsyncJobRepository = job_repository
        self._slice_iterator = iter(slices)
        self._running_partitions: List[AsyncPartition] = []

    def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
        failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
        jobs_to_replace = [job for job in partition.jobs if job.status() in failed_status_jobs]
        for job in jobs_to_replace:
            new_job = self._job_repository.start(job.job_parameters())
            partition.replace_job(job, [new_job])

    def _start_jobs(self) -> None:
        """
        Retry failed jobs and start a job for each remaining slice in the slice iterator.
        The started jobs are added to the running partitions.

        Returns:
            None

        TODO Eventually, we need to cap the number of concurrent jobs.
        However, the first iteration is for sendgrid which only has one job.
        """
        for partition in self._running_partitions:
            self._replace_failed_jobs(partition)

        for _slice in self._slice_iterator:
            job = self._job_repository.start(_slice)
            self._running_partitions.append(AsyncPartition([job], _slice))

    def _get_running_jobs(self) -> Set[AsyncJob]:
        """
        Returns a set of running AsyncJob objects.

        Returns:
            Set[AsyncJob]: A set of AsyncJob objects that are currently running.
        """
        return {job for partition in self._running_partitions for job in partition.jobs if job.status() == AsyncJobStatus.RUNNING}

    def _update_jobs_status(self) -> None:
        """
        Update the status of all running jobs in the repository.
        """
        running_jobs = self._get_running_jobs()
        if running_jobs:
            # update the status only if there are RUNNING jobs
            self._job_repository.update_jobs_status(running_jobs)

    def _wait_on_status_update(self) -> None:
        """
        Waits for a specified amount of time between status updates.

        This method is used to introduce a delay between status updates in order to avoid excessive polling.
        The duration of the delay is determined by the value of `_WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS`.

        Returns:
            None
        """
        lazy_log(
            LOGGER,
            logging.DEBUG,
            lambda: f"Polling status in progress. There are currently {len(self._running_partitions)} running partitions.",
        )

        # wait only when there are running partitions
        if self._running_partitions:
            lazy_log(
                LOGGER,
                logging.DEBUG,
                lambda: f"Waiting for {self._WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS} seconds before next poll...",
            )
            time.sleep(self._WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS)

    def _process_completed_partition(self, partition: AsyncPartition) -> None:
        """
        Process a completed partition.

        Args:
            partition (AsyncPartition): The completed partition to process.
        """
        job_ids = [job.api_job_id() for job in partition.jobs]
        LOGGER.info(f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}.")

    def _process_running_partitions_and_yield_completed_ones(self) -> Generator[AsyncPartition, Any, None]:
        """
        Process the running partitions.

        Yields:
            AsyncPartition: The processed partition.

        Raises:
            AirbyteTracedException: If a partition has reached the maximum number of attempts without completing.
        """
        current_running_partitions: List[AsyncPartition] = []
        for partition in self._running_partitions:
            match partition.status:
                case AsyncJobStatus.COMPLETED:
                    self._process_completed_partition(partition)
                    yield partition
                case AsyncJobStatus.RUNNING:
                    current_running_partitions.append(partition)
                case _ if partition.has_reached_max_attempt():
                    self._process_partitions_with_errors(partition)
                case _:
                    # job will be restarted in `_start_jobs`
                    current_running_partitions.insert(0, partition)
        # update the referenced list with running partitions
        self._running_partitions = current_running_partitions
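
    # For example: a partition whose only job FAILED on its first attempt is inserted at
    # the front of the running list, so the next call to `_start_jobs` replaces its failed
    # job before any new slices are consumed from the slice iterator.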

    def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
        """
        Process a partition whose jobs have error statuses (FAILED or TIMED_OUT).

        Args:
            partition (AsyncPartition): The partition to process.

        Raises:
            AirbyteTracedException: If at least one job could not be completed.
        """
        status_by_job_id = {job.api_job_id(): job.status() for job in partition.jobs}
        raise AirbyteTracedException(
            message=f"At least one job could not be completed. Job statuses were: {status_by_job_id}",
            failure_type=FailureType.system_error,
        )

    def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]:
        """
        Creates and retrieves completed partitions.
        This method continuously starts jobs, updates job statuses, processes running partitions,
        and waits between polls. It yields completed partitions as they become available.

        Returns:
            An iterable of completed partitions, represented as AsyncPartition objects.
        """
        while True:
            self._start_jobs()
            if not self._running_partitions:
                break

            self._update_jobs_status()
            yield from self._process_running_partitions_and_yield_completed_ones()
            self._wait_on_status_update()

    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
        """
        Fetches records from the given partition's jobs.

        Args:
            partition (AsyncPartition): The partition containing the jobs.

        Yields:
            Mapping[str, Any]: The fetched records from the jobs.
        """
        for job in partition.jobs:
            yield from self._job_repository.fetch_records(job)
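

# --- Illustrative usage (commented out; not part of the module) --------------
# A minimal sketch of driving the orchestrator end to end. `_StubJob` and
# `_StubRepository` are hypothetical stand-ins that implement only the methods
# this module actually calls; the real AsyncJob and AsyncJobRepository
# interfaces may require more than this.
#
#   class _StubJob:
#       def __init__(self, stream_slice: StreamSlice) -> None:
#           self._stream_slice = stream_slice
#
#       def status(self) -> AsyncJobStatus:
#           return AsyncJobStatus.COMPLETED  # completes immediately in this stub
#
#       def api_job_id(self) -> str:
#           return "stub-job-id"
#
#       def job_parameters(self) -> StreamSlice:
#           return self._stream_slice
#
#   class _StubRepository:
#       def start(self, stream_slice: StreamSlice) -> _StubJob:
#           return _StubJob(stream_slice)
#
#       def update_jobs_status(self, jobs: Set[_StubJob]) -> None:
#           pass  # statuses never change in this stub
#
#       def fetch_records(self, job: _StubJob) -> Iterable[Mapping[str, Any]]:
#           yield {"id": job.api_job_id()}
#
#   orchestrator = AsyncJobOrchestrator(_StubRepository(), [StreamSlice(partition={}, cursor_slice={})])
#   for partition in orchestrator.create_and_get_completed_partitions():
#       print(list(orchestrator.fetch_records(partition)))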