Skip to content

Commit a2e908d

Browse files
authored
connector builder: Set state on stream slices (#37109)
1 parent 58201ca commit a2e908d

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def _get_message_groups(
164164
current_slice_pages: List[StreamReadPages] = []
165165
current_page_request: Optional[HttpRequest] = None
166166
current_page_response: Optional[HttpResponse] = None
167+
latest_state_message: Optional[Dict[str, Any]] = None
167168

168169
while records_count < limit and (message := next(messages, None)):
169170
json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None
@@ -180,7 +181,7 @@ def _get_message_groups(
180181
and message.type == MessageType.LOG
181182
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX)
182183
):
183-
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor)
184+
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
184185
current_slice_descriptor = self._parse_slice_description(message.log.message)
185186
current_slice_pages = []
186187
at_least_one_page_in_group = False
@@ -222,10 +223,12 @@ def _get_message_groups(
222223
datetime_format_inferrer.accumulate(message.record)
223224
elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG:
224225
yield message.control
226+
elif message.type == MessageType.STATE:
227+
latest_state_message = message.state
225228
else:
226229
if current_page_request or current_page_response or current_page_records:
227230
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
228-
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor)
231+
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
229232

230233
@staticmethod
231234
def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessage, json_message: Optional[Dict[str, Any]]) -> bool:

airbyte-cdk/python/unit_tests/connector_builder/test_message_grouper.py

+14
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
AirbyteLogMessage,
1616
AirbyteMessage,
1717
AirbyteRecordMessage,
18+
AirbyteStateMessage,
19+
AirbyteStreamState,
1820
Level,
1921
OrchestratorType,
22+
StreamDescriptor,
2023
)
2124
from airbyte_cdk.models import Type as MessageType
2225
from unit_tests.connector_builder.utils import create_configured_catalog
@@ -470,6 +473,7 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No
470473
request_response_log_message(request, response, url),
471474
record_message("hashiras", {"name": "Obanai Iguro"}),
472475
request_response_log_message(request, response, url),
476+
state_message("hashiras", {"a_timestamp": 123}),
473477
]
474478
),
475479
)
@@ -486,13 +490,16 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No
486490
assert stream_read.slices[0].slice_descriptor == {"descriptor": "first_slice"}
487491
assert len(stream_read.slices[0].pages) == 1
488492
assert len(stream_read.slices[0].pages[0].records) == 1
493+
assert stream_read.slices[0].state is None
489494

490495
assert stream_read.slices[1].slice_descriptor == {"descriptor": "second_slice"}
491496
assert len(stream_read.slices[1].pages) == 3
492497
assert len(stream_read.slices[1].pages[0].records) == 2
493498
assert len(stream_read.slices[1].pages[1].records) == 1
494499
assert len(stream_read.slices[1].pages[2].records) == 0
495500

501+
assert stream_read.slices[1].state.stream.stream_state == {"a_timestamp": 123}
502+
496503

497504
@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read")
498505
def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limit_reached(mock_entrypoint_read: Mock) -> None:
@@ -698,6 +705,13 @@ def record_message(stream: str, data: Mapping[str, Any]) -> AirbyteMessage:
698705
return AirbyteMessage(type=MessageType.RECORD, record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=1234))
699706

700707

708+
def state_message(stream: str, data: Mapping[str, Any]) -> AirbyteMessage:
709+
return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(stream=AirbyteStreamState(
710+
stream_descriptor=StreamDescriptor(name=stream),
711+
stream_state=data
712+
)))
713+
714+
701715
def slice_message(slice_descriptor: str = '{"key": "value"}') -> AirbyteMessage:
702716
return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="slice:" + slice_descriptor))
703717

0 commit comments

Comments
 (0)