Skip to content

Commit 03b7e1a

Browse files
authored
feat(airbyte-cdk): Add limitation for number of partitions to PerPartitionCursor (#42406)
1 parent 1de50aa commit 03b7e1a

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Union
5+
import logging
6+
from collections import OrderedDict
7+
from typing import Any, Callable, Iterable, Mapping, Optional, Union
68

79
from airbyte_cdk.models import FailureType
810
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
@@ -41,6 +43,7 @@ class PerPartitionCursor(DeclarativeCursor):
4143
Therefore, we need to manage state per partition.
4244
"""
4345

46+
DEFAULT_MAX_PARTITIONS_NUMBER = 10000
4447
_NO_STATE: Mapping[str, Any] = {}
4548
_NO_CURSOR_STATE: Mapping[str, Any] = {}
4649
_KEY = 0
@@ -49,12 +52,17 @@ class PerPartitionCursor(DeclarativeCursor):
4952
def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter):
    """
    Keep the collaborators on the instance and start with an empty per-partition cursor registry.

    :param cursor_factory: stored as `_cursor_factory`
    :param partition_router: stored as `_partition_router`
    """
    self._partition_router = partition_router
    self._cursor_factory = cursor_factory
    self._partition_serializer = PerPartitionKeySerializer()
    # An insertion-ordered mapping is used so that, once the maximum number of partitions is
    # reached, the oldest partitions can be removed efficiently while the most recent are kept.
    self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
5459

5560
def stream_slices(self) -> Iterable[StreamSlice]:
5661
slices = self._partition_router.stream_slices()
5762
for partition in slices:
63+
# Ensure the maximum number of partitions is not exceeded
64+
self._ensure_partition_limit()
65+
5866
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
5967
if not cursor:
6068
cursor = self._create_cursor(self._NO_CURSOR_STATE)
@@ -63,6 +71,14 @@ def stream_slices(self) -> Iterable[StreamSlice]:
6371
for cursor_slice in cursor.stream_slices():
6472
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
6573

74+
def _ensure_partition_limit(self) -> None:
75+
"""
76+
Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.
77+
"""
78+
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
79+
oldest_partition = self._cursor_per_partition.popitem(last=False)[0] # Remove the oldest partition
80+
logging.warning(f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}.")
81+
6682
def set_initial_state(self, stream_state: StreamState) -> None:
6783
"""
6884
Set the initial state for the cursors.

airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py

+104-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,21 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from unittest.mock import patch
5+
from unittest.mock import MagicMock, patch
66

7-
from airbyte_cdk.models import SyncMode
8-
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice
7+
from airbyte_cdk.models import (
8+
AirbyteStateBlob,
9+
AirbyteStateMessage,
10+
AirbyteStateType,
11+
AirbyteStream,
12+
AirbyteStreamState,
13+
ConfiguredAirbyteCatalog,
14+
ConfiguredAirbyteStream,
15+
DestinationSyncMode,
16+
StreamDescriptor,
17+
SyncMode,
18+
)
19+
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, StreamSlice
920
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
1021
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
1122
from airbyte_cdk.sources.types import Record
@@ -268,3 +279,93 @@ def test_substream_without_input_state():
268279
cursor_slice={"start_time": "2022-02-01", "end_time": "2022-02-28"},
269280
),
270281
]
282+
283+
284+
def test_partition_limitation():
    """
    Read the "Rates" stream over three list partitions while
    PerPartitionCursor.DEFAULT_MAX_PARTITIONS_NUMBER is patched to 2, and check that the last
    emitted state only holds partitions "2" and "3".
    """
    manifest = (
        ManifestBuilder()
        .with_list_partition_router("Rates", "partition_field", ["1", "2", "3"])
        .with_incremental_sync(
            "Rates",
            start_datetime="2022-01-01",
            end_datetime="2022-02-28",
            datetime_format="%Y-%m-%d",
            cursor_field=CURSOR_FIELD,
            step="P1M",
            cursor_granularity="P1D",
        )
        .build()
    )
    source = ManifestDeclarativeSource(source_config=manifest)

    first_slice, second_slice, third_slice = (
        StreamSlice(partition={"partition_field": partition_value}, cursor_slice={}) for partition_value in ("1", "2", "3")
    )

    def record_at(cursor_value, stream_slice):
        # Every record shares the same payload; only the cursor value and owning slice vary.
        return Record({"a record key": "a record value", CURSOR_FIELD: cursor_value}, stream_slice)

    # Used as `side_effect` below: each patched `_read_pages` call returns the next entry.
    pages = [
        [record_at("2022-01-15", first_slice), record_at("2022-01-16", first_slice)],
        [record_at("2022-02-15", first_slice)],
        [record_at("2022-01-16", second_slice)],
        [],
        [],
        [record_at("2022-02-17", third_slice)],
    ]

    rates_stream = AirbyteStream(name="Rates", json_schema={}, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental])
    catalog = ConfiguredAirbyteCatalog(
        streams=[
            ConfiguredAirbyteStream(
                stream=rates_stream,
                sync_mode=SyncMode.incremental,
                destination_sync_mode=DestinationSyncMode.append,
            )
        ]
    )

    # NOTE(review): this state is keyed to the stream name "post_comment_votes" while the
    # configured stream is "Rates" -- confirm whether the mismatch is intentional.
    incoming_partition_states = [
        {"partition": {"partition_field": partition_value}, "cursor": {CURSOR_FIELD: cursor_value}}
        for partition_value, cursor_value in (("1", "2022-01-01"), ("2", "2022-01-02"), ("3", "2022-01-03"))
    ]
    initial_state = [
        AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name="post_comment_votes", namespace=None),
                stream_state=AirbyteStateBlob.parse_obj({"states": incoming_partition_states}),
            ),
        )
    ]
    logger = MagicMock()

    with patch.object(SimpleRetriever, "_read_pages", side_effect=pages), patch.object(
        PerPartitionCursor, "DEFAULT_MAX_PARTITIONS_NUMBER", 2
    ):
        output = list(source.read(logger, {}, catalog, initial_state))

    emitted_states = [message.state.stream.stream_state.dict() for message in output if message.state]
    expected_last_state = {
        "states": [
            {
                "partition": {"partition_field": "2"},
                "cursor": {CURSOR_FIELD: "2022-01-16"},
            },
            {
                "partition": {"partition_field": "3"},
                "cursor": {CURSOR_FIELD: "2022-02-17"},
            },
        ]
    }
    assert emitted_states[-1] == expected_last_state

0 commit comments

Comments
 (0)