Skip to content

Commit e3ce82e

Browse files
authored
feat(airbyte-cdk): add global_state => per_partition transformation (#45122)
Signed-off-by: Artem Inzhyyants <[email protected]>
1 parent bd20b74 commit e3ce82e

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py

+9 -11
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,10 @@
66
from collections import OrderedDict
77
from typing import Any, Callable, Iterable, Mapping, Optional, Union
88

9-
from airbyte_cdk.models import FailureType
109
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
1110
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
1211
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
1312
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
14-
from airbyte_cdk.utils import AirbyteTracedException
1513

1614

1715
class CursorFactory:
@@ -48,6 +46,7 @@ class PerPartitionCursor(DeclarativeCursor):
4846
_NO_CURSOR_STATE: Mapping[str, Any] = {}
4947
_KEY = 0
5048
_VALUE = 1
49+
_state_to_migrate_from: Mapping[str, Any] = {}
5150

5251
def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter):
5352
self._cursor_factory = cursor_factory
@@ -65,7 +64,8 @@ def stream_slices(self) -> Iterable[StreamSlice]:
6564

6665
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
6766
if not cursor:
68-
cursor = self._create_cursor(self._NO_CURSOR_STATE)
67+
partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE
68+
cursor = self._create_cursor(partition_state)
6969
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
7070

7171
for cursor_slice in cursor.stream_slices():
@@ -113,15 +113,13 @@ def set_initial_state(self, stream_state: StreamState) -> None:
113113
return
114114

115115
if "states" not in stream_state:
116-
raise AirbyteTracedException(
117-
internal_message=f"Could not sync parse the following state: {stream_state}",
118-
message="The state for is format invalid. Validate that the migration steps included a reset and that it was performed "
119-
"properly. Otherwise, please contact Airbyte support.",
120-
failure_type=FailureType.config_error,
121-
)
116+
# We assume that `stream_state` is in a global format that can be applied to all partitions.
117+
# Example: {"global_state_format_key": "global_state_format_value"}
118+
self._state_to_migrate_from = stream_state
122119

123-
for state in stream_state["states"]:
124-
self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"])
120+
else:
121+
for state in stream_state["states"]:
122+
self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"])
125123

126124
# Set parent state for partition routers based on parent streams
127125
self._partition_router.set_initial_state(stream_state)

airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py

+33 -8
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,10 @@
66
from unittest.mock import Mock
77

88
import pytest
9-
from airbyte_cdk.models import FailureType
109
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
1110
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice
1211
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
1312
from airbyte_cdk.sources.types import Record
14-
from airbyte_cdk.utils import AirbyteTracedException
1513

1614
PARTITION = {
1715
"partition_key string": "partition value",
@@ -519,10 +517,37 @@ def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_pa
519517
assert stream_state == expected_state
520518

521519

522-
def test_given_invalid_state_when_set_initial_state_then_raise_config_error(mocked_cursor_factory, mocked_partition_router) -> None:
523-
cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
524-
525-
with pytest.raises(AirbyteTracedException) as exception:
526-
cursor.set_initial_state({"invalid_state": 1})
520+
def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory, mocked_partition_router) -> None:
521+
first_partition = {"first_partition_key": "first_partition_value"}
522+
second_partition = {"second_partition_key": "second_partition_value"}
523+
global_state = {"global_state_format_key": "global_state_format_value"}
527524

528-
assert exception.value.failure_type == FailureType.config_error
525+
mocked_partition_router.stream_slices.return_value = [
526+
StreamSlice(partition=first_partition, cursor_slice={}),
527+
StreamSlice(partition=second_partition, cursor_slice={}),
528+
]
529+
mocked_cursor_factory.create.side_effect = [
530+
MockedCursorBuilder().with_stream_state(global_state).build(),
531+
MockedCursorBuilder().with_stream_state(global_state).build(),
532+
]
533+
cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
534+
global_state = {"global_state_format_key": "global_state_format_value"}
535+
cursor.set_initial_state(global_state)
536+
assert cursor._state_to_migrate_from == global_state
537+
list(cursor.stream_slices())
538+
assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_count == 1
539+
assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_args[0] == (
540+
{"global_state_format_key": "global_state_format_value"},
541+
)
542+
assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_count == 1
543+
assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_args[0] == (
544+
{"global_state_format_key": "global_state_format_value"},
545+
)
546+
expected_state = [
547+
{"cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"first_partition_key": "first_partition_value"}},
548+
{
549+
"cursor": {"global_state_format_key": "global_state_format_value"},
550+
"partition": {"second_partition_key": "second_partition_value"},
551+
},
552+
]
553+
assert cursor.get_stream_state()["states"] == expected_state

0 commit comments

Comments
 (0)