Skip to content

Commit 2ac5248

Browse files
Emit record counts in state messages for concurrent streams (#35907)
Co-authored-by: brianjlai <[email protected]> Co-authored-by: Brian Lai <[email protected]>
1 parent c8bec40 commit 2ac5248

File tree

8 files changed

+297
-23
lines changed

8 files changed

+297
-23
lines changed

airbyte-cdk/python/airbyte_cdk/entrypoint.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@
1010
import socket
1111
import sys
1212
import tempfile
13+
from collections import defaultdict
1314
from functools import wraps
14-
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
15+
from typing import Any, DefaultDict, Iterable, List, Mapping, MutableMapping, Optional, Union
1516
from urllib.parse import urlparse
1617

1718
import requests
1819
from airbyte_cdk.connector import TConfig
1920
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
2021
from airbyte_cdk.logger import init_logger
21-
from airbyte_cdk.models import AirbyteMessage, Status, Type
22-
from airbyte_cdk.models.airbyte_protocol import ConnectorSpecification # type: ignore [attr-defined]
22+
from airbyte_cdk.models import AirbyteMessage, FailureType, Status, Type
23+
from airbyte_cdk.models.airbyte_protocol import AirbyteStateStats, ConnectorSpecification # type: ignore [attr-defined]
2324
from airbyte_cdk.sources import Source
25+
from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
2426
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config
25-
from airbyte_cdk.utils import is_cloud_environment
27+
from airbyte_cdk.utils import is_cloud_environment, message_utils
2628
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
2729
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
2830
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
29-
from airbyte_protocol.models import FailureType
3031
from requests import PreparedRequest, Response, Session
3132

3233
logger = init_logger("airbyte")
@@ -160,8 +161,27 @@ def read(
160161
if self.source.check_config_against_spec:
161162
self.validate_connection(source_spec, config)
162163

163-
yield from self.source.read(self.logger, config, catalog, state)
164-
yield from self._emit_queued_messages(self.source)
164+
stream_message_counter: DefaultDict[HashableStreamDescriptor, int] = defaultdict(int)
165+
for message in self.source.read(self.logger, config, catalog, state):
166+
yield self.handle_record_counts(message, stream_message_counter)
167+
for message in self._emit_queued_messages(self.source):
168+
yield self.handle_record_counts(message, stream_message_counter)
169+
170+
@staticmethod
171+
def handle_record_counts(message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, int]) -> AirbyteMessage:
172+
if message.type == Type.RECORD:
173+
stream_message_count[message_utils.get_stream_descriptor(message)] += 1
174+
175+
elif message.type == Type.STATE:
176+
stream_descriptor = message_utils.get_stream_descriptor(message)
177+
178+
# Set record count from the counter onto the state message
179+
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats()
180+
message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0)
181+
182+
# Reset the counter
183+
stream_message_count[stream_descriptor] = 0
184+
return message
165185

166186
@staticmethod
167187
def validate_connection(source_spec: ConnectorSpecification, config: TConfig) -> None:

airbyte-cdk/python/airbyte_cdk/sources/connector_state_manager.py

-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai
8282
Generates an AirbyteMessage using the current per-stream state of a specified stream in either the per-stream or legacy format
8383
:param stream_name: The name of the stream for the message that is being created
8484
:param namespace: The namespace of the stream for the message that is being created
85-
:param send_per_stream_state: Decides which state format the message should be generated as
8685
:return: The Airbyte state message to be emitted by the connector during a sync
8786
"""
8887
hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
4+
from airbyte_protocol.models import AirbyteMessage, Type
5+
6+
7+
def get_stream_descriptor(message: AirbyteMessage) -> HashableStreamDescriptor:
8+
if message.type == Type.RECORD:
9+
return HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)
10+
elif message.type == Type.STATE:
11+
if not message.state.stream or not message.state.stream.stream_descriptor:
12+
raise ValueError("State message was not in per-stream state format, which is required for record counts.")
13+
return HashableStreamDescriptor(
14+
name=message.state.stream.stream_descriptor.name, namespace=message.state.stream.stream_descriptor.namespace
15+
)
16+
else:
17+
raise NotImplementedError(f"get_stream_descriptor is not implemented for message type '{message.type}'.")

airbyte-cdk/python/unit_tests/sources/file_based/test_scenarios.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import AbstractConcurrentFileBasedCursor
1717
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput
1818
from airbyte_cdk.test.entrypoint_wrapper import read as entrypoint_read
19+
from airbyte_cdk.utils import message_utils
1920
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
2021
from airbyte_protocol.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalog
2122
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario
@@ -71,7 +72,7 @@ def assert_exception(expected_exception: type[BaseException], output: Entrypoint
7172

7273

7374
def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[AbstractSource]) -> None:
74-
records, log_messages = output.records_and_state_messages, output.logs
75+
records_and_state_messages, log_messages = output.records_and_state_messages, output.logs
7576
logs = [message.log for message in log_messages if message.log.level.value in scenario.log_levels]
7677
if scenario.expected_records is None:
7778
return
@@ -85,7 +86,7 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
8586
),
8687
)
8788
sorted_records = sorted(
88-
filter(lambda r: r.record, records),
89+
filter(lambda r: r.record, records_and_state_messages),
8990
key=lambda record: ",".join(
9091
f"{k}={v}" for k, v in sorted(record.record.data.items(), key=lambda items: (items[0], items[1])) if k != "emitted_at"
9192
),
@@ -104,8 +105,9 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
104105
assert actual.record.stream == expected["stream"]
105106

106107
expected_states = list(filter(lambda e: "data" not in e, expected_records))
107-
states = list(filter(lambda r: r.state, records))
108+
states = list(filter(lambda r: r.state, records_and_state_messages))
108109
assert len(states) > 0, "No state messages emitted. Successful syncs should emit at least one stream state."
110+
_verify_state_record_counts(sorted_records, states)
109111

110112
if hasattr(scenario.source, "cursor_cls") and issubclass(scenario.source.cursor_cls, AbstractConcurrentFileBasedCursor):
111113
# Only check the last state emitted because we don't know the order the others will be in.
@@ -126,9 +128,33 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
126128
_verify_analytics(analytics, scenario.expected_analytics)
127129

128130

131+
def _verify_state_record_counts(records: List[AirbyteMessage], states: List[AirbyteMessage]) -> None:
132+
actual_record_counts = {}
133+
for record in records:
134+
stream_descriptor = message_utils.get_stream_descriptor(record)
135+
actual_record_counts[stream_descriptor] = actual_record_counts.get(stream_descriptor, 0) + 1
136+
137+
state_record_count_sums = {}
138+
for state_message in states:
139+
stream_descriptor = message_utils.get_stream_descriptor(state_message)
140+
state_record_count_sums[stream_descriptor] = (
141+
state_record_count_sums.get(stream_descriptor, 0)
142+
+ state_message.state.sourceStats.recordCount
143+
)
144+
145+
for stream, actual_count in actual_record_counts.items():
146+
assert state_record_count_sums.get(stream) == actual_count
147+
148+
# We can have extra keys in state_record_count_sums if we processed a stream and reported 0 records
149+
extra_keys = state_record_count_sums.keys() - actual_record_counts.keys()
150+
for stream in extra_keys:
151+
assert state_record_count_sums[stream] == 0
152+
153+
129154
def _verify_analytics(analytics: List[AirbyteMessage], expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]]) -> None:
130155
if expected_analytics:
131-
assert len(analytics) == len(expected_analytics), \
156+
assert len(analytics) == len(
157+
expected_analytics), \
132158
f"Number of actual analytics messages ({len(analytics)}) did not match expected ({len(expected_analytics)})"
133159
for actual, expected in zip(analytics, expected_analytics):
134160
actual_type, actual_value = actual.trace.analytics.type, actual.trace.analytics.value

airbyte-cdk/python/unit_tests/sources/mock_server_tests/test_mock_server_abstract_source.py

+11
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def test_full_refresh_sync(self, http_mocker):
205205
validate_message_order([Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
206206
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users"
207207
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
208+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2
208209

209210
@HttpMocker()
210211
def test_full_refresh_with_slices(self, http_mocker):
@@ -232,6 +233,7 @@ def test_full_refresh_with_slices(self, http_mocker):
232233
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
233234
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "dividers"
234235
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
236+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 4
235237

236238

237239
@freezegun.freeze_time(_NOW)
@@ -264,8 +266,10 @@ def test_incremental_sync(self, http_mocker):
264266
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
265267
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets"
266268
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_0}
269+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3
267270
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets"
268271
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_1}
272+
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2
269273

270274
@HttpMocker()
271275
def test_incremental_running_as_full_refresh(self, http_mocker):
@@ -295,6 +299,7 @@ def test_incremental_running_as_full_refresh(self, http_mocker):
295299
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
296300
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets"
297301
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_1}
302+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 5
298303

299304
@HttpMocker()
300305
def test_legacy_incremental_sync(self, http_mocker):
@@ -324,8 +329,10 @@ def test_legacy_incremental_sync(self, http_mocker):
324329
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
325330
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "legacies"
326331
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_0}
332+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3
327333
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "legacies"
328334
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_1}
335+
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2
329336

330337

331338
@freezegun.freeze_time(_NOW)
@@ -395,12 +402,16 @@ def test_incremental_and_full_refresh_streams(self, http_mocker):
395402
], actual_messages.records_and_state_messages)
396403
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users"
397404
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
405+
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2
398406
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets"
399407
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_0}
408+
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 3
400409
assert actual_messages.state_messages[2].state.stream.stream_descriptor.name == "planets"
401410
assert actual_messages.state_messages[2].state.stream.stream_state == {"created_at": last_record_date_1}
411+
assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2
402412
assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "dividers"
403413
assert actual_messages.state_messages[3].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
414+
assert actual_messages.state_messages[3].state.sourceStats.recordCount == 4
404415

405416

406417
def emits_successful_sync_status_messages(status_messages: List[AirbyteStreamStatus]) -> bool:

airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@
255255
primary_key=[],
256256
cursor_field=None,
257257
logger=logging.getLogger("test_logger"),
258-
cursor=FinalStateCursor(stream_name="stream1", stream_namespace=None, message_repository=_message_repository),
258+
cursor=FinalStateCursor(stream_name="stream2", stream_namespace=None, message_repository=_message_repository),
259259
),
260260
]
261261
)

0 commit comments

Comments
 (0)