
Commit 6baf254

feat(cdk): add async job components (#45178)
1 parent 254f34a commit 6baf254

35 files changed: +1739 −39 lines changed

airbyte-cdk/python/airbyte_cdk/logger.py

+9 −1

@@ -5,7 +5,7 @@
 import json
 import logging
 import logging.config
-from typing import Any, Mapping, Optional, Tuple
+from typing import Any, Callable, Mapping, Optional, Tuple

 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteMessageSerializer, Level, Type
 from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
@@ -38,6 +38,14 @@ def init_logger(name: Optional[str] = None) -> logging.Logger:
     return logger


+def lazy_log(logger: logging.Logger, level: int, lazy_log_provider: Callable[[], str]) -> None:
+    """
+    This method ensures that the processing of the log message is only done if the logger is enabled for the log level.
+    """
+    if logger.isEnabledFor(level):
+        logger.log(level, lazy_log_provider())
+
+
 class AirbyteLogFormatter(logging.Formatter):
     """Output log records using AirbyteMessage"""
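
As a quick illustration of how lazy_log can be used, here is a minimal sketch; the logger name reuses the CDK's "airbyte" logger, while the payload and message are hypothetical. The point is that the callable is only evaluated when the logger is enabled for the given level, so an expensive message build is skipped otherwise.

import json
import logging

from airbyte_cdk.logger import lazy_log

logger = logging.getLogger("airbyte")
payload = {"record_count": 1_000_000}  # hypothetical data that would be costly to serialize

# The lambda runs only if DEBUG is enabled, so json.dumps is never called on INFO-level runs.
lazy_log(logger, logging.DEBUG, lambda: f"Fetched payload: {json.dumps(payload)}")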
airbyte-cdk/python/airbyte_cdk/sources/declarative/async_job/__init__.py

Whitespace-only changes.
airbyte-cdk/python/airbyte_cdk/sources/declarative/async_job/job.py

@@ -0,0 +1,50 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.


from datetime import timedelta
from typing import Optional

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.timer import Timer

from .status import AsyncJobStatus


class AsyncJob:
    """
    Description of an API job.

    Note that the timer will only stop once `update_status` is called, so the job might be completed on the API side, but until we
    query for it and call `AsyncJob.update_status`, `AsyncJob.status` will not reflect the actual API-side status.
    """

    def __init__(self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None) -> None:
        self._api_job_id = api_job_id
        self._job_parameters = job_parameters
        self._status = AsyncJobStatus.RUNNING

        timeout = timeout if timeout else timedelta(minutes=60)
        self._timer = Timer(timeout)
        self._timer.start()

    def api_job_id(self) -> str:
        return self._api_job_id

    def status(self) -> AsyncJobStatus:
        if self._timer.has_timed_out():
            return AsyncJobStatus.TIMED_OUT
        return self._status

    def job_parameters(self) -> StreamSlice:
        return self._job_parameters

    def update_status(self, status: AsyncJobStatus) -> None:
        if self._status != AsyncJobStatus.RUNNING and status == AsyncJobStatus.RUNNING:
            self._timer.start()
        elif status.is_terminal():
            self._timer.stop()

        self._status = status

    def __repr__(self) -> str:
        return f"AsyncJob(data={self.api_job_id()}, job_parameters={self.job_parameters()}, status={self.status()})"
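
A short sketch of the AsyncJob lifecycle described above; the job id, slice values, and timeout are made up, and StreamSlice is assumed to accept partition and cursor_slice keyword arguments as elsewhere in the CDK.

from datetime import timedelta

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus

# Hypothetical slice and job id; a real AsyncJobRepository would create these from an API response.
_slice = StreamSlice(partition={"parent_id": "123"}, cursor_slice={})
job = AsyncJob(api_job_id="export-42", job_parameters=_slice, timeout=timedelta(minutes=5))

assert job.status() == AsyncJobStatus.RUNNING  # jobs start RUNNING and the timer is already ticking

# Once the API reports completion, update the status; a terminal status also stops the internal timer.
job.update_status(AsyncJobStatus.COMPLETED)
assert job.status().is_terminal()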
airbyte-cdk/python/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py

@@ -0,0 +1,241 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import logging
import time
from typing import Any, Generator, Iterable, List, Mapping, Optional, Set

from airbyte_cdk import StreamSlice
from airbyte_cdk.logger import lazy_log
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

LOGGER = logging.getLogger("airbyte")


class AsyncPartition:
    """
    This bucket of api_jobs is not very useful for this iteration, but it should become interesting once we are able to split jobs.
    """

    _MAX_NUMBER_OF_ATTEMPTS = 3

    def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
        self._attempts_per_job = {job: 0 for job in jobs}
        self._stream_slice = stream_slice

    def has_reached_max_attempt(self) -> bool:
        return any(map(lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS, self._attempts_per_job.values()))

    def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None:
        current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
        if current_attempt_count is None:
            raise ValueError("Could not find job to replace")
        elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
            raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")

        new_attempt_count = current_attempt_count + 1
        for job in new_jobs:
            self._attempts_per_job[job] = new_attempt_count

    def should_split(self, job: AsyncJob) -> bool:
        """
        Not used right now, but once we support job splitting, we should split based on the number of attempts.
        """
        return False

    @property
    def jobs(self) -> Iterable[AsyncJob]:
        return self._attempts_per_job.keys()

    @property
    def stream_slice(self) -> StreamSlice:
        return self._stream_slice

    @property
    def status(self) -> AsyncJobStatus:
        """
        Given different job statuses, the priority is: FAILED, TIMED_OUT, RUNNING. Otherwise, it means everything is completed.
        """
        statuses = set(map(lambda job: job.status(), self.jobs))
        if statuses == {AsyncJobStatus.COMPLETED}:
            return AsyncJobStatus.COMPLETED
        elif AsyncJobStatus.FAILED in statuses:
            return AsyncJobStatus.FAILED
        elif AsyncJobStatus.TIMED_OUT in statuses:
            return AsyncJobStatus.TIMED_OUT
        else:
            return AsyncJobStatus.RUNNING

    def __repr__(self) -> str:
        return f"AsyncPartition(stream_slice={self._stream_slice}, attempt_per_job={self._attempts_per_job})"


class AsyncJobOrchestrator:
    _WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS = 5

    def __init__(
        self,
        job_repository: AsyncJobRepository,
        slices: Iterable[StreamSlice],
        number_of_retries: Optional[int] = None,
    ) -> None:
        self._job_repository: AsyncJobRepository = job_repository
        self._slice_iterator = iter(slices)
        self._running_partitions: List[AsyncPartition] = []

    def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
        failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
        jobs_to_replace = [job for job in partition.jobs if job.status() in failed_status_jobs]
        for job in jobs_to_replace:
            new_job = self._job_repository.start(job.job_parameters())
            partition.replace_job(job, [new_job])

    def _start_jobs(self) -> None:
        """
        Retry failed jobs and start jobs for each slice in the slice iterator.
        This method iterates over the running partitions and the slice iterator and starts a job for each slice.
        The started jobs are added to the running partitions.
        Returns:
            None

        TODO Eventually, we need to cap the number of concurrent jobs.
        However, the first iteration is for Sendgrid, which only has one job.
        """
        for partition in self._running_partitions:
            self._replace_failed_jobs(partition)

        for _slice in self._slice_iterator:
            job = self._job_repository.start(_slice)
            self._running_partitions.append(AsyncPartition([job], _slice))

    def _get_running_jobs(self) -> Set[AsyncJob]:
        """
        Returns a set of running AsyncJob objects.

        Returns:
            Set[AsyncJob]: A set of AsyncJob objects that are currently running.
        """
        return {job for partition in self._running_partitions for job in partition.jobs if job.status() == AsyncJobStatus.RUNNING}

    def _update_jobs_status(self) -> None:
        """
        Update the status of all running jobs in the repository.
        """
        running_jobs = self._get_running_jobs()
        if running_jobs:
            # update the status only if there are RUNNING jobs
            self._job_repository.update_jobs_status(running_jobs)

    def _wait_on_status_update(self) -> None:
        """
        Waits for a specified amount of time between status updates.

        This method is used to introduce a delay between status updates in order to avoid excessive polling.
        The duration of the delay is determined by the value of `_WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS`.

        Returns:
            None
        """
        lazy_log(
            LOGGER,
            logging.DEBUG,
            lambda: f"Polling status in progress. There are currently {len(self._running_partitions)} running partitions.",
        )

        # wait only when there are running partitions
        if self._running_partitions:
            lazy_log(
                LOGGER,
                logging.DEBUG,
                lambda: f"Waiting for {self._WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS} seconds before next poll...",
            )
            time.sleep(self._WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS)

    def _process_completed_partition(self, partition: AsyncPartition) -> None:
        """
        Process a completed partition.
        Args:
            partition (AsyncPartition): The completed partition to process.
        """
        job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs}))
        LOGGER.info(f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}.")

    def _process_running_partitions_and_yield_completed_ones(self) -> Generator[AsyncPartition, Any, None]:
        """
        Process the running partitions.

        Yields:
            AsyncPartition: The processed partition.

        Raises:
            Any: Any exception raised during processing.
        """
        current_running_partitions: List[AsyncPartition] = []
        for partition in self._running_partitions:
            match partition.status:
                case AsyncJobStatus.COMPLETED:
                    self._process_completed_partition(partition)
                    yield partition
                case AsyncJobStatus.RUNNING:
                    current_running_partitions.append(partition)
                case _ if partition.has_reached_max_attempt():
                    self._process_partitions_with_errors(partition)
                case _:
                    # job will be restarted in `_start_jobs`
                    current_running_partitions.insert(0, partition)
        # update the referenced list with running partitions
        self._running_partitions = current_running_partitions

    def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
        """
        Process a partition with status errors (FAILED and TIMED_OUT).

        Args:
            partition (AsyncPartition): The partition to process.
        Raises:
            AirbyteTracedException: If at least one job could not be completed.
        """
        status_by_job_id = {job.api_job_id(): job.status() for job in partition.jobs}
        raise AirbyteTracedException(
            message=f"At least one job could not be completed. Job statuses were: {status_by_job_id}",
            failure_type=FailureType.system_error,
        )

    def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]:
        """
        Creates and retrieves completed partitions.
        This method continuously starts jobs, updates job status, processes running partitions,
        logs polling partitions, and waits for status updates. It yields completed partitions
        as they become available.

        Returns:
            An iterable of completed partitions, represented as AsyncPartition objects.
        """
        while True:
            self._start_jobs()
            if not self._running_partitions:
                break

            self._update_jobs_status()
            yield from self._process_running_partitions_and_yield_completed_ones()
            self._wait_on_status_update()

    def fetch_records(self, partition: AsyncPartition) -> Iterable[Mapping[str, Any]]:
        """
        Fetches records from the given partition's jobs.

        Args:
            partition (AsyncPartition): The partition containing the jobs.

        Yields:
            Iterable[Mapping[str, Any]]: The fetched records from the jobs.
        """
        for job in partition.jobs:
            yield from self._job_repository.fetch_records(job)
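
To show how these pieces fit together, here is a hedged end-to-end sketch with an invented in-memory repository: SingleShotRepository is hypothetical and marks every job COMPLETED on the first poll, and the module path for AsyncJobOrchestrator is assumed to match the file shown above.

from typing import Any, Iterable, Mapping, Set

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus


class SingleShotRepository(AsyncJobRepository):
    """Hypothetical repository that 'completes' every job on the first status poll."""

    def start(self, stream_slice: StreamSlice) -> AsyncJob:
        return AsyncJob(f"job-for-{stream_slice}", stream_slice)

    def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
        for job in jobs:
            job.update_status(AsyncJobStatus.COMPLETED)

    def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]:
        yield {"id": job.api_job_id()}  # placeholder record


slices = [StreamSlice(partition={"page": page}, cursor_slice={}) for page in range(2)]
orchestrator = AsyncJobOrchestrator(SingleShotRepository(), slices)

for partition in orchestrator.create_and_get_completed_partitions():
    for record in orchestrator.fetch_records(partition):
        print(partition.stream_slice, record)

Because every job completes on the first status update, the loop yields one partition per slice and exits without ever sleeping between polls.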
airbyte-cdk/python/airbyte_cdk/sources/declarative/async_job/repository.py

@@ -0,0 +1,21 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from abc import abstractmethod
from typing import Any, Iterable, Mapping, Set

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob


class AsyncJobRepository:
    @abstractmethod
    def start(self, stream_slice: StreamSlice) -> AsyncJob:
        pass

    @abstractmethod
    def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
        pass

    @abstractmethod
    def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]:
        pass
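
The interface above maps naturally onto a vendor's asynchronous export API. Below is a hedged sketch of what a concrete implementation could look like; the endpoints, response fields, and status strings are all invented for illustration, and only standard requests calls are used.

from typing import Any, Iterable, Mapping, Set

import requests

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus

_STATUS_MAPPING = {  # hypothetical vendor statuses mapped onto AsyncJobStatus
    "pending": AsyncJobStatus.RUNNING,
    "done": AsyncJobStatus.COMPLETED,
    "error": AsyncJobStatus.FAILED,
}


class HttpExportJobRepository(AsyncJobRepository):
    """Sketch of a repository backed by a hypothetical async-export HTTP API."""

    def __init__(self, base_url: str, session: requests.Session) -> None:
        self._base_url = base_url
        self._session = session

    def start(self, stream_slice: StreamSlice) -> AsyncJob:
        # StreamSlice behaves like a mapping of the slice values, so it can be sent as the request body.
        response = self._session.post(f"{self._base_url}/exports", json=dict(stream_slice))
        response.raise_for_status()
        return AsyncJob(response.json()["id"], stream_slice)

    def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
        for job in jobs:
            response = self._session.get(f"{self._base_url}/exports/{job.api_job_id()}")
            response.raise_for_status()
            job.update_status(_STATUS_MAPPING[response.json()["status"]])

    def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]:
        response = self._session.get(f"{self._base_url}/exports/{job.api_job_id()}/records")
        response.raise_for_status()
        yield from response.json()["records"]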
airbyte-cdk/python/airbyte_cdk/sources/declarative/async_job/status.py

@@ -0,0 +1,24 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.


from enum import Enum

_TERMINAL = True


class AsyncJobStatus(Enum):
    RUNNING = ("RUNNING", not _TERMINAL)
    COMPLETED = ("COMPLETED", _TERMINAL)
    FAILED = ("FAILED", _TERMINAL)
    TIMED_OUT = ("TIMED_OUT", _TERMINAL)

    def __init__(self, value: str, is_terminal: bool) -> None:
        self._value = value
        self._is_terminal = is_terminal

    def is_terminal(self) -> bool:
        """
        A status is terminal when a job status can't be updated anymore. For example, if a job is completed, it will stay completed, but a
        running job might become completed, failed or timed out.
        """
        return self._is_terminal
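
A small sketch of the terminal-state semantics the docstring describes; it relies only on the enum defined above.

from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus

# RUNNING is the only non-terminal status; AsyncJob.update_status stops its timer on any terminal status.
assert not AsyncJobStatus.RUNNING.is_terminal()
assert all(
    status.is_terminal()
    for status in (AsyncJobStatus.COMPLETED, AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
)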
