Skip to content

Commit fac0dc4

Browse files
Joe Reuter and aaronsteers
authored and committed
airbyte-lib: Stream state (#34778)
Co-authored-by: Aaron Steers <[email protected]>
1 parent bca7587 commit fac0dc4

File tree

9 files changed

+276
-19
lines changed

9 files changed

+276
-19
lines changed

airbyte-lib/airbyte_lib/_file_writers/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
if TYPE_CHECKING:
1919
import pyarrow as pa
2020

21+
from airbyte_protocol.models import (
22+
AirbyteStateMessage,
23+
)
24+
2125

2226
DEFAULT_BATCH_SIZE = 10000
2327

@@ -109,3 +113,14 @@ def cleanup_batch(
109113
Subclasses should override `_cleanup_batch` instead.
110114
"""
111115
self._cleanup_batch(stream_name, batch_id, batch_handle)
116+
117+
@overrides
118+
def _finalize_state_messages(
119+
self,
120+
stream_name: str,
121+
state_messages: list[AirbyteStateMessage],
122+
) -> None:
123+
"""
124+
State messages are not used in file writers, so this method is a no-op.
125+
"""
126+
pass

airbyte-lib/airbyte_lib/_processors.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def __init__(
7878
)
7979
raise TypeError(err_msg)
8080

81+
self.source_catalog: ConfiguredAirbyteCatalog | None = None
82+
self._source_name: str | None = None
83+
8184
self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
8285
self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
8386

@@ -301,6 +304,16 @@ def _finalize_batches(
301304

302305
return batches_to_finalize
303306

307+
@abc.abstractmethod
308+
def _finalize_state_messages(
309+
self,
310+
stream_name: str,
311+
state_messages: list[AirbyteStateMessage],
312+
) -> None:
313+
"""Handle state messages.
314+
Might be a no-op if the processor doesn't handle incremental state."""
315+
pass
316+
304317
@final
305318
@contextlib.contextmanager
306319
def _finalizing_batches(
@@ -318,6 +331,7 @@ def _finalizing_batches(
318331

319332
progress.log_batches_finalizing(stream_name, len(batches_to_finalize))
320333
yield batches_to_finalize
334+
self._finalize_state_messages(stream_name, state_messages_to_finalize)
321335
progress.log_batches_finalized(stream_name, len(batches_to_finalize))
322336

323337
self._finalized_batches[stream_name].update(batches_to_finalize)

airbyte-lib/airbyte_lib/caches/_catalog_manager.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import json
77
from typing import TYPE_CHECKING, Callable
88

9-
from sqlalchemy import Column, String
9+
from sqlalchemy import Column, DateTime, String
1010
from sqlalchemy.ext.declarative import declarative_base
1111
from sqlalchemy.orm import Session
12+
from sqlalchemy.sql import func
1213

1314
from airbyte_protocol.models import (
15+
AirbyteStateMessage,
1416
AirbyteStream,
1517
ConfiguredAirbyteCatalog,
1618
ConfiguredAirbyteStream,
@@ -25,6 +27,9 @@
2527
from sqlalchemy.engine import Engine
2628

2729
STREAMS_TABLE_NAME = "_airbytelib_streams"
30+
STATE_TABLE_NAME = "_airbytelib_state"
31+
32+
GLOBAL_STATE_STREAM_NAMES = ["_GLOBAL", "_LEGACY"]
2833

2934
Base = declarative_base()
3035

@@ -38,7 +43,24 @@ class CachedStream(Base): # type: ignore[valid-type,misc]
3843
catalog_metadata = Column(String)
3944

4045

46+
class StreamState(Base): # type: ignore[valid-type,misc]
47+
__tablename__ = STATE_TABLE_NAME
48+
49+
source_name = Column(String)
50+
stream_name = Column(String)
51+
table_name = Column(String, primary_key=True)
52+
state_json = Column(String)
53+
last_updated = Column(DateTime(timezone=True), onupdate=func.now(), default=func.now())
54+
55+
4156
class CatalogManager:
57+
"""
58+
A class to manage the stream catalog of data synced to a cache:
59+
* What streams exist and to what tables they map
60+
* The JSON schema for each stream
61+
* The state of each stream if available
62+
"""
63+
4264
def __init__(
4365
self,
4466
engine: Engine,
@@ -68,6 +90,56 @@ def _ensure_internal_tables(self) -> None:
6890
engine = self._engine
6991
Base.metadata.create_all(engine)
7092

93+
def save_state(
94+
self,
95+
source_name: str,
96+
state: AirbyteStateMessage,
97+
stream_name: str,
98+
) -> None:
99+
self._ensure_internal_tables()
100+
engine = self._engine
101+
with Session(engine) as session:
102+
session.query(StreamState).filter(
103+
StreamState.table_name == self._table_name_resolver(stream_name)
104+
).delete()
105+
session.commit()
106+
session.add(
107+
StreamState(
108+
source_name=source_name,
109+
stream_name=stream_name,
110+
table_name=self._table_name_resolver(stream_name),
111+
state_json=state.json(),
112+
)
113+
)
114+
session.commit()
115+
116+
def get_state(
117+
self,
118+
source_name: str,
119+
streams: list[str],
120+
) -> list[dict] | None:
121+
self._ensure_internal_tables()
122+
engine = self._engine
123+
with Session(engine) as session:
124+
states = (
125+
session.query(StreamState)
126+
.filter(
127+
StreamState.source_name == source_name,
128+
StreamState.stream_name.in_([*streams, *GLOBAL_STATE_STREAM_NAMES]),
129+
)
130+
.all()
131+
)
132+
if not states:
133+
return None
134+
# Only return the states if the table name matches what the current cache
135+
# would generate. Otherwise consider it part of a different cache.
136+
states = [
137+
state
138+
for state in states
139+
if state.table_name == self._table_name_resolver(state.stream_name)
140+
]
141+
return [json.loads(state.state_json) for state in states]
142+
71143
def register_source(
72144
self,
73145
source_name: str,

airbyte-lib/airbyte_lib/caches/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from sqlalchemy.sql.base import Executable
4747

4848
from airbyte_protocol.models import (
49+
AirbyteStateMessage,
4950
ConfiguredAirbyteCatalog,
5051
)
5152

@@ -561,6 +562,36 @@ def _finalize_batches(
561562
# Return the batch handles as measure of work completed.
562563
return batches_to_finalize
563564

565+
@overrides
566+
def _finalize_state_messages(
567+
self,
568+
stream_name: str,
569+
state_messages: list[AirbyteStateMessage],
570+
) -> None:
571+
"""Handle state messages by passing them to the catalog manager."""
572+
if not self._catalog_manager:
573+
raise exc.AirbyteLibInternalError(
574+
message="Catalog manager should exist but does not.",
575+
)
576+
if state_messages and self._source_name:
577+
self._catalog_manager.save_state(
578+
source_name=self._source_name,
579+
stream_name=stream_name,
580+
state=state_messages[-1],
581+
)
582+
583+
def get_state(self) -> list[dict]:
584+
"""Return the current state of the source."""
585+
if not self._source_name:
586+
return []
587+
if not self._catalog_manager:
588+
raise exc.AirbyteLibInternalError(
589+
message="Catalog manager should exist but does not.",
590+
)
591+
return (
592+
self._catalog_manager.get_state(self._source_name, list(self._streams_with_data)) or []
593+
)
594+
564595
def _execute_sql(self, sql: str | TextClause | Executable) -> CursorResult:
565596
"""Execute the given SQL statement."""
566597
if isinstance(sql, str):
@@ -881,6 +912,7 @@ def register_source(
881912
882913
This method is called by the source when it is initialized.
883914
"""
915+
self._source_name = source_name
884916
self._ensure_schema_exists()
885917
super().register_source(
886918
source_name,

airbyte-lib/airbyte_lib/source.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from airbyte_protocol.models import (
1414
AirbyteCatalog,
1515
AirbyteMessage,
16+
AirbyteStateMessage,
1617
ConfiguredAirbyteCatalog,
1718
ConfiguredAirbyteStream,
1819
ConnectorSpecification,
@@ -46,7 +47,7 @@
4647

4748

4849
@contextmanager
49-
def as_temp_files(files_contents: list[dict]) -> Generator[list[str], Any, None]:
50+
def as_temp_files(files_contents: list[Any]) -> Generator[list[str], Any, None]:
5051
"""Write the given contents to temporary files and yield the file paths as strings."""
5152
temp_files: list[Any] = []
5253
try:
@@ -140,6 +141,10 @@ def set_config(
140141

141142
self._config_dict = config
142143

144+
def get_config(self) -> dict[str, Any]:
145+
"""Get the config for the connector."""
146+
return self._config
147+
143148
@property
144149
def _config(self) -> dict[str, Any]:
145150
if self._config_dict is None:
@@ -252,7 +257,7 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
252257
# TODO: Set sync modes and primary key to a sensible adaptive default
253258
ConfiguredAirbyteStream(
254259
stream=stream,
255-
sync_mode=SyncMode.full_refresh,
260+
sync_mode=SyncMode.incremental,
256261
destination_sync_mode=DestinationSyncMode.overwrite,
257262
primary_key=stream.source_defined_primary_key,
258263
)
@@ -309,7 +314,6 @@ def _with_missing_columns(records: Iterable[dict[str, Any]]) -> Iterator[dict[st
309314
self._read_with_catalog(
310315
streaming_cache_info,
311316
configured_catalog,
312-
force_full_refresh=True, # Always full refresh when skipping the cache
313317
),
314318
)
315319
)
@@ -360,8 +364,7 @@ def uninstall(self) -> None:
360364
def _read(
361365
self,
362366
cache_info: CacheTelemetryInfo,
363-
*,
364-
force_full_refresh: bool,
367+
state: list[AirbyteStateMessage] | None = None,
365368
) -> Iterable[AirbyteMessage]:
366369
"""
367370
Call read on the connector.
@@ -379,15 +382,14 @@ def _read(
379382
yield from self._read_with_catalog(
380383
cache_info,
381384
catalog=self.configured_catalog,
382-
force_full_refresh=force_full_refresh,
385+
state=state,
383386
)
384387

385388
def _read_with_catalog(
386389
self,
387390
cache_info: CacheTelemetryInfo,
388391
catalog: ConfiguredAirbyteCatalog,
389-
*,
390-
force_full_refresh: bool,
392+
state: list[AirbyteStateMessage] | None = None,
391393
) -> Iterator[AirbyteMessage]:
392394
"""Call read on the connector.
393395
@@ -397,21 +399,27 @@ def _read_with_catalog(
397399
* Listen to the messages and return the AirbyteRecordMessages that come along.
398400
* Send out telemetry on the performed sync (with information about which source was used and
399401
the type of the cache)
400-
401-
TODO: When we add support for incremental syncs, we should only send `--state <state_file>`
402-
if force_full_refresh is False.
403402
"""
404-
_ = force_full_refresh # TODO: Use this decide whether to send `--state <state_file>`
405403
source_tracking_information = self.executor.get_telemetry_info()
406404
send_telemetry(source_tracking_information, cache_info, SyncState.STARTED)
407405
try:
408-
with as_temp_files([self._config, catalog.json()]) as [
406+
with as_temp_files(
407+
[self._config, catalog.json(), json.dumps(state) if state else "[]"]
408+
) as [
409409
config_file,
410410
catalog_file,
411+
state_file,
411412
]:
412413
yield from self._execute(
413-
# TODO: Add support for incremental syncs by sending `--state <state_file>`
414-
["read", "--config", config_file, "--catalog", catalog_file],
414+
[
415+
"read",
416+
"--config",
417+
config_file,
418+
"--catalog",
419+
catalog_file,
420+
"--state",
421+
state_file,
422+
],
415423
)
416424
except Exception:
417425
send_telemetry(
@@ -520,11 +528,12 @@ def read(
520528
incoming_source_catalog=self.configured_catalog,
521529
stream_names=set(self.get_selected_streams()),
522530
)
531+
state = cache.get_state() if not force_full_refresh else None
523532
cache.process_airbyte_messages(
524533
self._tally_records(
525534
self._read(
526535
cache.get_telemetry_info(),
527-
force_full_refresh=force_full_refresh,
536+
state=state,
528537
),
529538
),
530539
write_strategy=write_strategy,

airbyte-lib/docs/generated/airbyte_lib.html

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)