Skip to content

Commit bbf69ae

Browse files
authored
Concurrent CDK: support partitioned states (#36811)
1 parent f29f7bb commit bbf69ae

21 files changed

+637
-461
lines changed

airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus
88
from airbyte_cdk.models import Type as MessageType
99
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
10+
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
1011
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
1112
from airbyte_cdk.sources.message import MessageRepository
1213
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
@@ -17,7 +18,9 @@
1718
from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel
1819
from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
1920
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
21+
from airbyte_cdk.utils import AirbyteTracedException
2022
from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message
23+
from airbyte_protocol.models import StreamDescriptor
2124

2225

2326
class ConcurrentReadProcessor:
@@ -56,6 +59,7 @@ def __init__(
5659
self._message_repository = message_repository
5760
self._partition_reader = partition_reader
5861
self._streams_done: Set[str] = set()
62+
self._exceptions_per_stream_name: dict[str, List[Exception]] = {}
5963

6064
def on_partition_generation_completed(self, sentinel: PartitionGenerationCompletedSentinel) -> Iterable[AirbyteMessage]:
6165
"""
@@ -126,14 +130,16 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
126130
yield message
127131
yield from self._message_repository.consume_queue()
128132

129-
def on_exception(self, exception: Exception) -> Iterable[AirbyteMessage]:
133+
def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMessage]:
130134
"""
131135
This method is called when an exception is raised.
132136
1. Stop all running streams
133137
2. Raise the exception
134138
"""
135-
yield from self._stop_streams()
136-
raise exception
139+
self._exceptions_per_stream_name.setdefault(exception.stream_name, []).append(exception.exception)
140+
yield AirbyteTracedException.from_exception(exception).as_airbyte_message(
141+
stream_descriptor=StreamDescriptor(name=exception.stream_name)
142+
)
137143

138144
def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
139145
"""
@@ -177,13 +183,7 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]:
177183
yield from self._message_repository.consume_queue()
178184
self._logger.info(f"Finished syncing {stream.name}")
179185
self._streams_done.add(stream_name)
180-
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.COMPLETE)
181-
182-
def _stop_streams(self) -> Iterable[AirbyteMessage]:
183-
self._thread_pool_manager.shutdown()
184-
for stream_name in self._streams_to_running_partitions.keys():
185-
stream = self._stream_name_to_instance[stream_name]
186-
if not self._is_stream_done(stream_name):
187-
self._logger.info(f"Marking stream {stream.name} as STOPPED")
188-
self._logger.info(f"Finished syncing {stream.name}")
189-
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.INCOMPLETE)
186+
stream_status = (
187+
AirbyteStreamStatus.INCOMPLETE if self._exceptions_per_stream_name.get(stream_name, []) else AirbyteStreamStatus.COMPLETE
188+
)
189+
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), stream_status)

airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from airbyte_cdk.models import AirbyteMessage
1010
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
1111
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
12+
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
1213
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
1314
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
1415
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
@@ -123,11 +124,6 @@ def _consume_from_queue(
123124
concurrent_stream_processor: ConcurrentReadProcessor,
124125
) -> Iterable[AirbyteMessage]:
125126
while airbyte_message_or_record_or_exception := queue.get():
126-
try:
127-
self._threadpool.shutdown_if_exception()
128-
except Exception as exception:
129-
concurrent_stream_processor.on_exception(exception)
130-
131127
yield from self._handle_item(
132128
airbyte_message_or_record_or_exception,
133129
concurrent_stream_processor,
@@ -142,7 +138,7 @@ def _handle_item(
142138
concurrent_stream_processor: ConcurrentReadProcessor,
143139
) -> Iterable[AirbyteMessage]:
144140
# handle queue item and call the appropriate handler depending on the type of the queue item
145-
if isinstance(queue_item, Exception):
141+
if isinstance(queue_item, StreamThreadException):
146142
yield from concurrent_stream_processor.on_exception(queue_item)
147143
elif isinstance(queue_item, PartitionGenerationCompletedSentinel):
148144
yield from concurrent_stream_processor.on_partition_generation_completed(queue_item)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
from typing import Any
4+
5+
6+
class StreamThreadException(Exception):
7+
def __init__(self, exception: Exception, stream_name: str):
8+
self._exception = exception
9+
self._stream_name = stream_name
10+
11+
@property
12+
def stream_name(self) -> str:
13+
return self._stream_name
14+
15+
@property
16+
def exception(self) -> Exception:
17+
return self._exception
18+
19+
def __str__(self) -> str:
20+
return f"Exception while syncing stream {self._stream_name}: {self._exception}"
21+
22+
def __eq__(self, other: Any) -> bool:
23+
if isinstance(other, StreamThreadException):
24+
return self._exception == other._exception and self._stream_name == other._stream_name
25+
return False

airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -71,26 +71,26 @@ def _prune_futures(self, futures: List[Future[Any]]) -> None:
7171
)
7272
futures.pop(index)
7373

74-
def shutdown(self) -> None:
74+
def _shutdown(self) -> None:
75+
# Without a way to stop the threads that have already started, this will not stop the Python application. We are fine today with
76+
# this imperfect approach because we only do this in case of `self._most_recently_seen_exception` which we don't expect to happen
7577
self._threadpool.shutdown(wait=False, cancel_futures=True)
7678

7779
def is_done(self) -> bool:
7880
return all([f.done() for f in self._futures])
7981

80-
def shutdown_if_exception(self) -> None:
81-
"""
82-
This method will raise if there is an exception so that the caller can use it.
83-
"""
84-
if self._most_recently_seen_exception:
85-
self._stop_and_raise_exception(self._most_recently_seen_exception)
86-
8782
def check_for_errors_and_shutdown(self) -> None:
8883
"""
8984
Check if any of the futures have an exception, and raise it if so. If all futures are done, shutdown the threadpool.
9085
If the futures are not done, raise an exception.
9186
:return:
9287
"""
93-
self.shutdown_if_exception()
88+
if self._most_recently_seen_exception:
89+
self._logger.exception(
90+
"An unknown exception has occurred while reading concurrently",
91+
exc_info=self._most_recently_seen_exception,
92+
)
93+
self._stop_and_raise_exception(self._most_recently_seen_exception)
9494

9595
exceptions_from_futures = [f for f in [future.exception() for future in self._futures] if f is not None]
9696
if exceptions_from_futures:
@@ -102,8 +102,8 @@ def check_for_errors_and_shutdown(self) -> None:
102102
exception = RuntimeError(f"Failed reading with futures not done: {futures_not_done}")
103103
self._stop_and_raise_exception(exception)
104104
else:
105-
self.shutdown()
105+
self._shutdown()
106106

107107
def _stop_and_raise_exception(self, exception: BaseException) -> None:
108-
self.shutdown()
108+
self._shutdown()
109109
raise exception

airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py

+100-13
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#
44
import functools
55
from abc import ABC, abstractmethod
6-
from datetime import datetime
7-
from typing import Any, List, Mapping, MutableMapping, Optional, Protocol, Tuple
6+
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Protocol, Tuple
87

98
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
109
from airbyte_cdk.sources.message import MessageRepository
@@ -18,19 +17,41 @@ def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
1817
return functools.reduce(lambda a, b: a[b], path, mapping)
1918

2019

21-
class Comparable(Protocol):
20+
class GapType(Protocol):
21+
"""
22+
This is the representation of gaps between two cursor values. Examples:
23+
* if cursor values are datetimes, GapType is timedelta
24+
* if cursor values are integer, GapType will also be integer
25+
"""
26+
27+
pass
28+
29+
30+
class CursorValueType(Protocol):
2231
"""Protocol for annotating comparable types."""
2332

2433
@abstractmethod
25-
def __lt__(self: "Comparable", other: "Comparable") -> bool:
34+
def __lt__(self: "CursorValueType", other: "CursorValueType") -> bool:
35+
pass
36+
37+
@abstractmethod
38+
def __ge__(self: "CursorValueType", other: "CursorValueType") -> bool:
39+
pass
40+
41+
@abstractmethod
42+
def __add__(self: "CursorValueType", other: GapType) -> "CursorValueType":
43+
pass
44+
45+
@abstractmethod
46+
def __sub__(self: "CursorValueType", other: GapType) -> "CursorValueType":
2647
pass
2748

2849

2950
class CursorField:
3051
def __init__(self, cursor_field_key: str) -> None:
3152
self.cursor_field_key = cursor_field_key
3253

33-
def extract_value(self, record: Record) -> Comparable:
54+
def extract_value(self, record: Record) -> CursorValueType:
3455
cursor_value = record.data.get(self.cursor_field_key)
3556
if cursor_value is None:
3657
raise ValueError(f"Could not find cursor field {self.cursor_field_key} in record")
@@ -118,7 +139,10 @@ def __init__(
118139
connector_state_converter: AbstractStreamStateConverter,
119140
cursor_field: CursorField,
120141
slice_boundary_fields: Optional[Tuple[str, str]],
121-
start: Optional[Any],
142+
start: Optional[CursorValueType],
143+
end_provider: Callable[[], CursorValueType],
144+
lookback_window: Optional[GapType] = None,
145+
slice_range: Optional[GapType] = None,
122146
) -> None:
123147
self._stream_name = stream_name
124148
self._stream_namespace = stream_namespace
@@ -129,15 +153,18 @@ def __init__(
129153
# To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
130154
self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
131155
self._start = start
156+
self._end_provider = end_provider
132157
self._most_recent_record: Optional[Record] = None
133158
self._has_closed_at_least_one_slice = False
134159
self.start, self._concurrent_state = self._get_concurrent_state(stream_state)
160+
self._lookback_window = lookback_window
161+
self._slice_range = slice_range
135162

136163
@property
137164
def state(self) -> MutableMapping[str, Any]:
138165
return self._concurrent_state
139166

140-
def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[datetime, MutableMapping[str, Any]]:
167+
def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[CursorValueType, MutableMapping[str, Any]]:
141168
if self._connector_state_converter.is_state_message_compatible(state):
142169
return self._start or self._connector_state_converter.zero_value, self._connector_state_converter.deserialize(state)
143170
return self._connector_state_converter.convert_from_sequential_state(self._cursor_field, state, self._start)
@@ -203,23 +230,20 @@ def _emit_state_message(self) -> None:
203230
self._connector_state_manager.update_state_for_stream(
204231
self._stream_name,
205232
self._stream_namespace,
206-
self._connector_state_converter.convert_to_sequential_state(self._cursor_field, self.state),
233+
self._connector_state_converter.convert_to_state_message(self._cursor_field, self.state),
207234
)
208-
# TODO: if we migrate stored state to the concurrent state format
209-
# (aka stop calling self._connector_state_converter.convert_to_sequential_state`), we'll need to cast datetimes to string or
210-
# int before emitting state
211235
state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace)
212236
self._message_repository.emit_message(state_message)
213237

214238
def _merge_partitions(self) -> None:
215239
self.state["slices"] = self._connector_state_converter.merge_intervals(self.state["slices"])
216240

217-
def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
241+
def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType:
218242
try:
219243
_slice = partition.to_slice()
220244
if not _slice:
221245
raise KeyError(f"Could not find key `{key}` in empty slice")
222-
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a Comparable
246+
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType
223247
except KeyError as exception:
224248
raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception
225249

@@ -229,3 +253,66 @@ def ensure_at_least_one_state_emitted(self) -> None:
229253
called.
230254
"""
231255
self._emit_state_message()
256+
257+
def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
258+
"""
259+
Generating slices based on a few parameters:
260+
* lookback_window: Buffer to remove from END_KEY of the highest slice
261+
* slice_range: Max difference between two slices. If the difference between two slices is greater, multiple slices will be created
262+
* start: `_split_per_slice_range` will clip any value to `self._start which means that:
263+
* if upper is less than self._start, no slices will be generated
264+
* if lower is less than self._start, self._start will be used as the lower boundary (lookback_window will not be considered in that case)
265+
266+
Note that the slices will overlap at their boundaries. We therefore expect to have at least the lower or the upper boundary to be
267+
inclusive in the API that is queried.
268+
"""
269+
self._merge_partitions()
270+
271+
if self._start is not None and self._is_start_before_first_slice():
272+
yield from self._split_per_slice_range(self._start, self.state["slices"][0][self._connector_state_converter.START_KEY])
273+
274+
if len(self.state["slices"]) == 1:
275+
yield from self._split_per_slice_range(
276+
self._calculate_lower_boundary_of_last_slice(self.state["slices"][0][self._connector_state_converter.END_KEY]),
277+
self._end_provider(),
278+
)
279+
elif len(self.state["slices"]) > 1:
280+
for i in range(len(self.state["slices"]) - 1):
281+
yield from self._split_per_slice_range(
282+
self.state["slices"][i][self._connector_state_converter.END_KEY],
283+
self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
284+
)
285+
yield from self._split_per_slice_range(
286+
self._calculate_lower_boundary_of_last_slice(self.state["slices"][-1][self._connector_state_converter.END_KEY]),
287+
self._end_provider(),
288+
)
289+
else:
290+
raise ValueError("Expected at least one slice")
291+
292+
def _is_start_before_first_slice(self) -> bool:
293+
return self._start is not None and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY]
294+
295+
def _calculate_lower_boundary_of_last_slice(self, lower_boundary: CursorValueType) -> CursorValueType:
296+
if self._lookback_window:
297+
return lower_boundary - self._lookback_window
298+
return lower_boundary
299+
300+
def _split_per_slice_range(self, lower: CursorValueType, upper: CursorValueType) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
301+
if lower >= upper:
302+
return
303+
304+
if self._start and upper < self._start:
305+
return
306+
307+
lower = max(lower, self._start) if self._start else lower
308+
if not self._slice_range or lower + self._slice_range >= upper:
309+
yield lower, upper
310+
else:
311+
stop_processing = False
312+
current_lower_boundary = lower
313+
while not stop_processing:
314+
current_upper_boundary = min(current_lower_boundary + self._slice_range, upper)
315+
yield current_lower_boundary, current_upper_boundary
316+
current_lower_boundary = current_upper_boundary
317+
if current_upper_boundary >= upper:
318+
stop_processing = True

airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from queue import Queue
66

77
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
8+
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
89
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
910
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
1011
from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
@@ -52,4 +53,5 @@ def generate_partitions(self, stream: AbstractStream) -> None:
5253
self._queue.put(partition)
5354
self._queue.put(PartitionGenerationCompletedSentinel(stream))
5455
except Exception as e:
55-
self._queue.put(e)
56+
self._queue.put(StreamThreadException(e, stream.name))
57+
self._queue.put(PartitionGenerationCompletedSentinel(stream))

airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
from queue import Queue
55

6+
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
67
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
78
from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem
89

@@ -35,4 +36,5 @@ def process_partition(self, partition: Partition) -> None:
3536
self._queue.put(record)
3637
self._queue.put(PartitionCompleteSentinel(partition))
3738
except Exception as e:
38-
self._queue.put(e)
39+
self._queue.put(StreamThreadException(e, partition.stream_name()))
40+
self._queue.put(PartitionCompleteSentinel(partition))

airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from typing import Union
5+
from typing import Any, Union
66

77
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
88
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
@@ -21,6 +21,11 @@ def __init__(self, partition: Partition):
2121
"""
2222
self.partition = partition
2323

24+
def __eq__(self, other: Any) -> bool:
25+
if isinstance(other, PartitionCompleteSentinel):
26+
return self.partition == other.partition
27+
return False
28+
2429

2530
"""
2631
Typedef representing the items that can be added to the ThreadBasedConcurrentStream

0 commit comments

Comments
 (0)