Skip to content

Commit 1d9608c

Browse files
authored
[per-stream cdk] Support deserialization of legacy and per-stream state (#16205)
* interpret legacy and new per-stream format into AirbyteStateMessages * add ConnectorStateManager stubs for future work * remove frozen for the time being until we need to hash descriptors * add validation that AirbyteStateMessage has at least one of stream, global, or data fields * pr feedback and clean up of the code * remove changes to airbyte_protocol and perform validation in read_state() * fix import formatting
1 parent fd66f1f commit 1d9608c

File tree

5 files changed

+322
-23
lines changed

5 files changed

+322
-23
lines changed

airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
33
#
44

5-
6-
import copy
75
import logging
86
from abc import ABC, abstractmethod
97
from datetime import datetime
108
from functools import lru_cache
11-
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple
9+
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union
1210

1311
from airbyte_cdk.models import (
1412
AirbyteCatalog,
@@ -22,6 +20,7 @@
2220
SyncMode,
2321
)
2422
from airbyte_cdk.models import Type as MessageType
23+
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
2524
from airbyte_cdk.sources.source import Source
2625
from airbyte_cdk.sources.streams import Stream
2726
from airbyte_cdk.sources.streams.http.http import HttpStream
@@ -91,10 +90,12 @@ def read(
9190
logger: logging.Logger,
9291
config: Mapping[str, Any],
9392
catalog: ConfiguredAirbyteCatalog,
94-
state: MutableMapping[str, Any] = None,
93+
state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None,
9594
) -> Iterator[AirbyteMessage]:
9695
"""Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.io/architecture/airbyte-protocol."""
97-
connector_state = copy.deepcopy(state or {})
96+
state_manager = ConnectorStateManager(state=state)
97+
connector_state = state_manager.get_legacy_state()
98+
9899
logger.info(f"Starting syncing {self.name}")
99100
config, internal_config = split_config(config)
100101
# TODO assert all streams exist in the connector
@@ -133,6 +134,10 @@ def read(
133134

134135
logger.info(f"Finished syncing {self.name}")
135136

137+
@property
138+
def per_stream_state_enabled(self):
139+
return False # While CDK per-stream is in active development we should keep this off
140+
136141
def _read_stream(
137142
self,
138143
logger: logging.Logger,
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#
2+
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
import copy
6+
from typing import Any, List, Mapping, MutableMapping, Union
7+
8+
from airbyte_cdk.models import AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType
9+
10+
11+
class ConnectorStateManager:
12+
"""
13+
ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL / LEGACY) under a common
14+
interface. It also provides methods to extract and update state
15+
"""
16+
17+
# In the immediate, we only persist legacy which will be used during abstract_source.read(). In the subsequent PRs we will
18+
# initialize the ConnectorStateManager according to the new per-stream interface received from the platform
19+
def __init__(self, state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None):
20+
if not state:
21+
self.legacy = {}
22+
elif self.is_migrated_legacy_state(state):
23+
# The legacy state format received from the platform is parsed and stored as a single AirbyteStateMessage when reading
24+
# the file. This is used for input backwards compatibility.
25+
self.legacy = state[0].data
26+
elif isinstance(state, MutableMapping):
27+
# In the event that legacy state comes in as its original JSON object format, no changes to the input need to be made
28+
self.legacy = state
29+
else:
30+
raise ValueError("Input state should come in the form of list of Airbyte state messages or a mapping of states")
31+
32+
def get_stream_state(self, namespace: str, stream_name: str) -> AirbyteStateBlob:
33+
# todo implement in upcoming PRs
34+
pass
35+
36+
def get_legacy_state(self) -> MutableMapping[str, Any]:
37+
"""
38+
Returns a deep copy of the current legacy state dictionary made up of the state of all streams for a connector
39+
:return: A copy of the legacy state
40+
"""
41+
return copy.deepcopy(self.legacy, {})
42+
43+
def update_state_for_stream(self, namespace: str, stream_name: str, value: Mapping[str, Any]):
44+
# todo implement in upcoming PRs
45+
pass
46+
47+
@staticmethod
48+
def is_migrated_legacy_state(state: List[AirbyteStateMessage]) -> bool:
49+
return (
50+
isinstance(state, List)
51+
and len(state) == 1
52+
and isinstance(state[0], AirbyteStateMessage)
53+
and state[0].type == AirbyteStateType.LEGACY
54+
)

airbyte-cdk/python/airbyte_cdk/sources/source.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import json
77
import logging
88
from abc import ABC, abstractmethod
9-
from collections import defaultdict
10-
from typing import Any, Dict, Generic, Iterable, Mapping, MutableMapping, TypeVar
9+
from typing import Any, Generic, Iterable, List, Mapping, MutableMapping, TypeVar, Union
1110

1211
from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
13-
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, ConfiguredAirbyteCatalog
12+
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, AirbyteStateType, ConfiguredAirbyteCatalog
1413

1514
TState = TypeVar("TState")
1615
TCatalog = TypeVar("TCatalog")
@@ -39,15 +38,37 @@ def discover(self, logger: logging.Logger, config: TConfig) -> AirbyteCatalog:
3938
"""
4039

4140

42-
class Source(DefaultConnectorMixin, BaseSource[Mapping[str, Any], MutableMapping[str, Any], ConfiguredAirbyteCatalog], ABC):
41+
class Source(
42+
DefaultConnectorMixin,
43+
BaseSource[Mapping[str, Any], Union[List[AirbyteStateMessage], MutableMapping[str, Any]], ConfiguredAirbyteCatalog],
44+
ABC,
45+
):
4346
# can be overridden to change an input state
44-
def read_state(self, state_path: str) -> Dict[str, Any]:
47+
def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
48+
"""
49+
Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either
50+
a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the
51+
incoming input type, it will always be transformed and output as a list of AirbyteStateMessage(s).
52+
:param state_path: The filepath to where the stream states are located
53+
:return: The complete stream state based on the connector's previous sync
54+
"""
4555
if state_path:
4656
state_obj = json.loads(open(state_path, "r").read())
47-
else:
48-
state_obj = {}
49-
state = defaultdict(dict, state_obj)
50-
return state
57+
if not state_obj:
58+
return []
59+
is_per_stream_state = isinstance(state_obj, List)
60+
if is_per_stream_state:
61+
parsed_state_messages = []
62+
for state in state_obj:
63+
parsed_message = AirbyteStateMessage.parse_obj(state)
64+
if not parsed_message.stream and not parsed_message.data and not parsed_message.global_:
65+
raise ValueError("AirbyteStateMessage should contain either a stream, global, or state field")
66+
parsed_state_messages.append(parsed_message)
67+
return parsed_state_messages
68+
else:
69+
# When the legacy JSON object format is received, always outputting an AirbyteStateMessage simplifies processing downstream
70+
return [AirbyteStateMessage(type=AirbyteStateType.LEGACY, data=state_obj)]
71+
return []
5172

5273
# can be overridden to change an input catalog
5374
def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#
2+
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
from contextlib import nullcontext as does_not_raise
6+
7+
import pytest
8+
from airbyte_cdk.models import AirbyteStateMessage, AirbyteStateType
9+
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
10+
11+
12+
@pytest.mark.parametrize(
13+
"input_state, expected_legacy_state, expected_error",
14+
[
15+
pytest.param(
16+
[AirbyteStateMessage(type=AirbyteStateType.LEGACY, data={"actresses": {"id": "seehorn_rhea"}})],
17+
{"actresses": {"id": "seehorn_rhea"}},
18+
does_not_raise(),
19+
id="test_legacy_input_state",
20+
),
21+
pytest.param(
22+
{
23+
"actors": {"created_at": "1962-10-22"},
24+
"actresses": {"id": "seehorn_rhea"},
25+
},
26+
{"actors": {"created_at": "1962-10-22"}, "actresses": {"id": "seehorn_rhea"}},
27+
does_not_raise(),
28+
id="test_supports_legacy_json_blob",
29+
),
30+
pytest.param({}, {}, does_not_raise(), id="test_initialize_empty_mapping_by_default"),
31+
pytest.param([], {}, does_not_raise(), id="test_initialize_empty_state"),
32+
pytest.param("strings_are_not_allowed", None, pytest.raises(ValueError), id="test_value_error_is_raised_on_invalid_state_input"),
33+
],
34+
)
35+
def test_get_legacy_state(input_state, expected_legacy_state, expected_error):
36+
with expected_error:
37+
state_manager = ConnectorStateManager(input_state)
38+
actual_legacy_state = state_manager.get_legacy_state()
39+
assert actual_legacy_state == expected_legacy_state

0 commit comments

Comments
 (0)