Skip to content

Commit 86ee91e

Browse files
authored
Connector builder: read input state if it exists (#37495)
1 parent 28209fd commit 86ee91e

File tree

6 files changed

+71
-46
lines changed

6 files changed

+71
-46
lines changed

airbyte-cdk/python/airbyte_cdk/connector_builder/connector_builder_handler.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import dataclasses
66
from datetime import datetime
7-
from typing import Any, Mapping
7+
from typing import Any, List, Mapping
88

99
from airbyte_cdk.connector_builder.message_grouper import MessageGrouper
10-
from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, ConfiguredAirbyteCatalog
10+
from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
1111
from airbyte_cdk.models import Type
1212
from airbyte_cdk.models import Type as MessageType
1313
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
@@ -54,12 +54,12 @@ def create_source(config: Mapping[str, Any], limits: TestReadLimits) -> Manifest
5454

5555

5656
def read_stream(
57-
source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, limits: TestReadLimits
57+
source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, state: List[AirbyteStateMessage], limits: TestReadLimits
5858
) -> AirbyteMessage:
5959
try:
6060
handler = MessageGrouper(limits.max_pages_per_slice, limits.max_slices, limits.max_records)
6161
stream_name = configured_catalog.streams[0].stream.name # The connector builder only supports a single stream
62-
stream_read = handler.get_message_groups(source, config, configured_catalog, limits.max_records)
62+
stream_read = handler.get_message_groups(source, config, configured_catalog, state, limits.max_records)
6363
return AirbyteMessage(
6464
type=MessageType.RECORD,
6565
record=AirbyteRecordMessage(data=dataclasses.asdict(stream_read), stream=stream_name, emitted_at=_emitted_at()),

airbyte-cdk/python/airbyte_cdk/connector_builder/main.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@
99
from airbyte_cdk.connector import BaseConnector
1010
from airbyte_cdk.connector_builder.connector_builder_handler import TestReadLimits, create_source, get_limits, read_stream, resolve_manifest
1111
from airbyte_cdk.entrypoint import AirbyteEntrypoint
12-
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog
12+
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
1313
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
14+
from airbyte_cdk.sources.source import Source
1415
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
1516

1617

17-
def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog]]:
18+
def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]:
1819
# TODO: Add functionality for the `debug` logger.
1920
# Currently, no `debug`-level logs are displayed while reading a stream for a connector created through `connector-builder`.
2021
parsed_args = AirbyteEntrypoint.parse_args(args)
21-
config_path, catalog_path = parsed_args.config, parsed_args.catalog
22+
config_path, catalog_path, state_path = parsed_args.config, parsed_args.catalog, parsed_args.state
2223
if parsed_args.command != "read":
2324
raise ValueError("Only read commands are allowed for Connector Builder requests.")
2425

@@ -32,38 +33,41 @@ def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str,
3233
command = config["__command"]
3334
if command == "test_read":
3435
catalog = ConfiguredAirbyteCatalog.parse_obj(BaseConnector.read_config(catalog_path))
36+
state = Source.read_state(state_path)
3537
else:
3638
catalog = None
39+
state = []
3740

3841
if "__injected_declarative_manifest" not in config:
3942
raise ValueError(
4043
f"Invalid config: `__injected_declarative_manifest` should be provided at the root of the config but config only has keys {list(config.keys())}"
4144
)
4245

43-
return command, config, catalog
46+
return command, config, catalog, state
4447

4548

4649
def handle_connector_builder_request(
4750
source: ManifestDeclarativeSource,
4851
command: str,
4952
config: Mapping[str, Any],
5053
catalog: Optional[ConfiguredAirbyteCatalog],
54+
state: List[AirbyteStateMessage],
5155
limits: TestReadLimits,
5256
) -> AirbyteMessage:
5357
if command == "resolve_manifest":
5458
return resolve_manifest(source)
5559
elif command == "test_read":
5660
assert catalog is not None, "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None."
57-
return read_stream(source, config, catalog, limits)
61+
return read_stream(source, config, catalog, state, limits)
5862
else:
5963
raise ValueError(f"Unrecognized command {command}.")
6064

6165

6266
def handle_request(args: List[str]) -> AirbyteMessage:
63-
command, config, catalog = get_config_and_catalog_from_args(args)
67+
command, config, catalog, state = get_config_and_catalog_from_args(args)
6468
limits = get_limits(config)
6569
source = create_source(config, limits)
66-
return handle_connector_builder_request(source, command, config, catalog, limits).json(exclude_unset=True)
70+
return handle_connector_builder_request(source, command, config, catalog, state, limits).json(exclude_unset=True)
6771

6872

6973
if __name__ == "__main__":

airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AirbyteControlMessage,
2929
AirbyteLogMessage,
3030
AirbyteMessage,
31+
AirbyteStateMessage,
3132
AirbyteTraceMessage,
3233
ConfiguredAirbyteCatalog,
3334
OrchestratorType,
@@ -75,6 +76,7 @@ def get_message_groups(
7576
source: DeclarativeSource,
7677
config: Mapping[str, Any],
7778
configured_catalog: ConfiguredAirbyteCatalog,
79+
state: List[AirbyteStateMessage],
7880
record_limit: Optional[int] = None,
7981
) -> StreamRead:
8082
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
@@ -96,7 +98,7 @@ def get_message_groups(
9698
latest_config_update: AirbyteControlMessage = None
9799
auxiliary_requests = []
98100
for message_group in self._get_message_groups(
99-
self._read_stream(source, config, configured_catalog),
101+
self._read_stream(source, config, configured_catalog, state),
100102
schema_inferrer,
101103
datetime_format_inferrer,
102104
record_limit,
@@ -181,7 +183,7 @@ def _get_message_groups(
181183
and message.type == MessageType.LOG
182184
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX)
183185
):
184-
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
186+
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=[latest_state_message] if latest_state_message else [])
185187
current_slice_descriptor = self._parse_slice_description(message.log.message)
186188
current_slice_pages = []
187189
at_least_one_page_in_group = False
@@ -228,7 +230,7 @@ def _get_message_groups(
228230
else:
229231
if current_page_request or current_page_response or current_page_records:
230232
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
231-
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
233+
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=[latest_state_message] if latest_state_message else [])
232234

233235
@staticmethod
234236
def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessage, json_message: Optional[Dict[str, Any]]) -> bool:
@@ -279,12 +281,13 @@ def _close_page(
279281
current_page_records.clear()
280282

281283
def _read_stream(
282-
self, source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog
284+
self, source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog,
285+
state: List[AirbyteStateMessage]
283286
) -> Iterator[AirbyteMessage]:
284287
# the generator can raise an exception
285288
# Iterate over the generated messages. If `next` raises an exception, catch it and yield it as an Airbyte message built from an AirbyteTracedException.
286289
try:
287-
yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, {})
290+
yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, state)
288291
except Exception as e:
289292
error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}"
290293
yield AirbyteTracedException.from_exception(e, message=error_message).as_airbyte_message()

airbyte-cdk/python/airbyte_cdk/connector_builder/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class StreamReadPages:
3232
class StreamReadSlices:
3333
pages: List[StreamReadPages]
3434
slice_descriptor: Optional[Dict[str, Any]]
35-
state: Optional[Dict[str, Any]] = None
35+
state: Optional[List[Dict[str, Any]]] = None
3636

3737

3838
@dataclass

airbyte-cdk/python/unit_tests/connector_builder/test_connector_builder_handler.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
AirbyteLogMessage,
2929
AirbyteMessage,
3030
AirbyteRecordMessage,
31+
AirbyteStateMessage,
3132
AirbyteStream,
33+
AirbyteStreamState,
3234
ConfiguredAirbyteCatalog,
3335
ConfiguredAirbyteStream,
3436
ConnectorSpecification,
3537
DestinationSyncMode,
3638
Level,
39+
StreamDescriptor,
3740
SyncMode,
3841
)
3942
from airbyte_cdk.models import Type
@@ -50,6 +53,18 @@
5053
_stream_options = {"name": _stream_name, "primary_key": _stream_primary_key, "url_base": _stream_url_base}
5154
_page_size = 2
5255

56+
_A_STATE = [AirbyteStateMessage(
57+
type="STREAM",
58+
stream=AirbyteStreamState(
59+
stream_descriptor=StreamDescriptor(
60+
name=_stream_name
61+
),
62+
stream_state={
63+
"key": "value"
64+
}
65+
)
66+
)]
67+
5368
MANIFEST = {
5469
"version": "0.30.3",
5570
"definitions": {
@@ -266,7 +281,7 @@ def test_resolve_manifest(valid_resolve_manifest_config_file):
266281
config["__command"] = command
267282
source = ManifestDeclarativeSource(MANIFEST)
268283
limits = TestReadLimits()
269-
resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), limits)
284+
resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), _A_STATE, limits)
270285

271286
expected_resolved_manifest = {
272287
"type": "DeclarativeSource",
@@ -455,10 +470,11 @@ def test_read():
455470
),
456471
)
457472
limits = TestReadLimits()
458-
with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read):
473+
with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read) as mock:
459474
output_record = handle_connector_builder_request(
460-
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits
475+
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits
461476
)
477+
mock.assert_called_with(source, config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits.max_records)
462478
output_record.record.emitted_at = 1
463479
assert output_record == expected_airbyte_message
464480

@@ -492,7 +508,7 @@ def test_config_update():
492508
return_value=refresh_request_response,
493509
):
494510
output = handle_connector_builder_request(
495-
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), TestReadLimits()
511+
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, TestReadLimits()
496512
)
497513
assert output.record.data["latest_config_update"]
498514

@@ -529,7 +545,7 @@ def check_config_against_spec(self):
529545

530546
source = MockManifestDeclarativeSource()
531547
limits = TestReadLimits()
532-
response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits)
548+
response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits)
533549

534550
expected_stream_read = StreamRead(
535551
logs=[LogMessage("error_message - a stack trace", "ERROR")],
@@ -716,7 +732,7 @@ def test_read_source(mock_http_stream):
716732

717733
source = create_source(config, limits)
718734

719-
output_data = read_stream(source, config, catalog, limits).record.data
735+
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
720736
slices = output_data["slices"]
721737

722738
assert len(slices) == max_slices
@@ -761,7 +777,7 @@ def test_read_source_single_page_single_slice(mock_http_stream):
761777

762778
source = create_source(config, limits)
763779

764-
output_data = read_stream(source, config, catalog, limits).record.data
780+
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
765781
slices = output_data["slices"]
766782

767783
assert len(slices) == max_slices
@@ -817,7 +833,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
817833
source = create_source(config, limits)
818834

819835
with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
820-
output_data = read_stream(source, config, catalog, limits).record.data
836+
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
821837
if expected_error:
822838
assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
823839
error_message = output_data["logs"][0]
@@ -875,7 +891,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected
875891
source = create_source(config, limits)
876892

877893
with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
878-
output_data = read_stream(source, config, catalog, limits).record.data
894+
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
879895
if expected_error:
880896
assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
881897
error_message = output_data["logs"][0]

0 commit comments

Comments
 (0)