diff --git a/airbyte-lib/airbyte_lib/_processors.py b/airbyte-lib/airbyte_lib/_processors.py index d5eba9f2c00f1..4418463a8a1d9 100644 --- a/airbyte-lib/airbyte_lib/_processors.py +++ b/airbyte-lib/airbyte_lib/_processors.py @@ -32,6 +32,7 @@ from airbyte_lib import exceptions as exc from airbyte_lib._util import protocol_util # Internal utility functions +from airbyte_lib.progress import progress if TYPE_CHECKING: @@ -40,7 +41,7 @@ from airbyte_lib.config import CacheConfigBase -DEFAULT_BATCH_SIZE = 10000 +DEFAULT_BATCH_SIZE = 10_000 class BatchHandle: @@ -95,7 +96,7 @@ def register_source( For now, only one source at a time is supported. If this method is called multiple times, the last call will overwrite the previous one. - TODO: Expand this to handle mutliple sources. + TODO: Expand this to handle multiple sources. """ _ = source_name self.source_catalog = incoming_source_catalog @@ -157,6 +158,7 @@ def process_airbyte_messages( if len(stream_batch) >= max_batch_size: record_batch = pa.Table.from_pylist(stream_batch) self._process_batch(stream_name, record_batch) + progress.log_batch_written(stream_name, len(stream_batch)) stream_batch.clear() elif message.type is Type.STATE: @@ -180,14 +182,16 @@ def process_airbyte_messages( ) # We are at the end of the stream. Process whatever else is queued. - for stream_name, batch in stream_batches.items(): - if batch: - record_batch = pa.Table.from_pylist(batch) + for stream_name, stream_batch in stream_batches.items(): + if stream_batch: + record_batch = pa.Table.from_pylist(stream_batch) self._process_batch(stream_name, record_batch) + progress.log_batch_written(stream_name, len(stream_batch)) # Finalize any pending batches for stream_name in list(self._pending_batches.keys()): self._finalize_batches(stream_name) + progress.log_stream_finalized(stream_name) @final def _process_batch( @@ -287,7 +291,10 @@ def _finalizing_batches( state_messages_to_finalize = self._pending_state_messages[stream_name].copy() self._pending_batches[stream_name].clear() self._pending_state_messages[stream_name].clear() + + progress.log_batches_finalizing(stream_name, len(batches_to_finalize)) yield batches_to_finalize + progress.log_batches_finalized(stream_name, len(batches_to_finalize)) self._finalized_batches[stream_name].update(batches_to_finalize) self._finalized_state_messages[stream_name] += state_messages_to_finalize diff --git a/airbyte-lib/airbyte_lib/progress.py b/airbyte-lib/airbyte_lib/progress.py new file mode 100644 index 0000000000000..d1b7c5355fe1d --- /dev/null +++ b/airbyte-lib/airbyte_lib/progress.py @@ -0,0 +1,320 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +"""A simple progress bar for the command line and IPython notebooks.""" +from __future__ import annotations + +import datetime +import math +import time +from contextlib import suppress +from typing import cast + +from rich.errors import LiveError +from rich.live import Live as RichLive +from rich.markdown import Markdown as RichMarkdown + + +try: + IS_NOTEBOOK = True + from IPython import display as ipy_display + +except ImportError: + ipy_display = None + IS_NOTEBOOK = False + + +MAX_UPDATE_FREQUENCY = 1000 +"""The max number of records to read before updating the progress bar.""" + + +def _to_time_str(timestamp: float) -> str: + """Convert a timestamp float to a local time string. + + For now, we'll just use UTC to avoid breaking tests. In the future, we should + return a local time string. 
+ """ + datetime_obj = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + # TODO: Uncomment this line when we can get tests to properly account for local timezones. + # For now, we'll just use UTC to avoid breaking tests. + # datetime_obj = datetime_obj.astimezone() + return datetime_obj.strftime("%H:%M:%S") + + +def _get_elapsed_time_str(seconds: int) -> str: + """Return duration as a string. + + Seconds are included until 10 minutes is exceeded. + Minutes are always included after 1 minute elapsed. + Hours are always included after 1 hour elapsed. + """ + if seconds <= 60: # noqa: PLR2004 # Magic numbers OK here. + return f"{seconds} seconds" + + if seconds < 60 * 10: + minutes = seconds // 60 + seconds = seconds % 60 + return f"{minutes}min {seconds}s" + + if seconds < 60 * 60: + minutes = seconds // 60 + seconds = seconds % 60 + return f"{minutes}min" + + hours = seconds // (60 * 60) + minutes = (seconds % (60 * 60)) // 60 + return f"{hours}hr {minutes}min" + + +class ReadProgress: + """A simple progress bar for the command line and IPython notebooks.""" + + def __init__(self) -> None: + """Initialize the progress tracker.""" + # Streams expected (for progress bar) + self.num_streams_expected = 0 + + # Reads + self.read_start_time = time.time() + self.read_end_time: float | None = None + self.total_records_read = 0 + + # Writes + self.total_records_written = 0 + self.total_batches_written = 0 + self.written_stream_names: set[str] = set() + + # Finalization + self.finalize_start_time: float | None = None + self.finalize_end_time: float | None = None + self.total_records_finalized = 0 + self.total_batches_finalized = 0 + self.finalized_stream_names: set[str] = set() + + self.last_update_time: float | None = None + + self.rich_view: RichLive | None = None + if not IS_NOTEBOOK: + # If we're in a terminal, use a Rich view to display the progress updates. + self.rich_view = RichLive() + try: + self.rich_view.start() + except LiveError: + self.rich_view = None + + def __del__(self) -> None: + """Close the Rich view.""" + if self.rich_view: + with suppress(Exception): + self.rich_view.stop() + + def log_success(self) -> None: + """Log success and stop tracking progress.""" + if self.finalize_end_time is None: + # If we haven't already finalized, do so now. 
+ + self.finalize_end_time = time.time() + + self.update_display(force_refresh=True) + if self.rich_view: + with suppress(Exception): + self.rich_view.stop() + + def reset(self, num_streams_expected: int) -> None: + """Reset the progress tracker.""" + # Streams expected (for progress bar) + self.num_streams_expected = num_streams_expected + + # Reads + self.read_start_time = time.time() + self.read_end_time = None + self.total_records_read = 0 + + # Writes + self.total_records_written = 0 + self.total_batches_written = 0 + self.written_stream_names = set() + + # Finalization + self.finalize_start_time = None + self.finalize_end_time = None + self.total_records_finalized = 0 + self.total_batches_finalized = 0 + self.finalized_stream_names = set() + + @property + def elapsed_seconds(self) -> int: + """Return the number of seconds elapsed since the read operation started.""" + if self.finalize_end_time: + return int(self.finalize_end_time - self.read_start_time) + + return int(time.time() - self.read_start_time) + + @property + def elapsed_time_string(self) -> str: + """Return duration as a string.""" + return _get_elapsed_time_str(self.elapsed_seconds) + + @property + def elapsed_seconds_since_last_update(self) -> float | None: + """Return the number of seconds elapsed since the last update.""" + if self.last_update_time is None: + return None + + return time.time() - self.last_update_time + + @property + def elapsed_read_seconds(self) -> int: + """Return the number of seconds elapsed since the read operation started.""" + if self.read_end_time is None: + return int(time.time() - self.read_start_time) + + return int(self.read_end_time - self.read_start_time) + + @property + def elapsed_read_time_string(self) -> str: + """Return duration as a string.""" + return _get_elapsed_time_str(self.elapsed_read_seconds) + + @property + def elapsed_finalization_seconds(self) -> int: + """Return the number of seconds elapsed since the read operation started.""" + if self.finalize_start_time is None: + return 0 + if self.finalize_end_time is None: + return int(time.time() - self.finalize_start_time) + return int(self.finalize_end_time - self.finalize_start_time) + + @property + def elapsed_finalization_time_str(self) -> str: + """Return duration as a string.""" + return _get_elapsed_time_str(self.elapsed_finalization_seconds) + + def log_records_read(self, new_total_count: int) -> None: + """Load a number of records read.""" + self.total_records_read = new_total_count + + # This is some math to make updates adaptive to the scale of records read. + # We want to update the display more often when the count is low, and less + # often when the count is high. + updated_period = min( + MAX_UPDATE_FREQUENCY, 10 ** math.floor(math.log10(self.total_records_read) / 4) + ) + if self.total_records_read % updated_period != 0: + return + + self.update_display() + + def log_batch_written(self, stream_name: str, batch_size: int) -> None: + """Log that a batch has been written. + + Args: + stream_name: The name of the stream. + batch_size: The number of records in the batch. + """ + self.total_records_written += batch_size + self.total_batches_written += 1 + self.written_stream_names.add(stream_name) + self.update_display() + + def log_batches_finalizing(self, stream_name: str, num_batches: int) -> None: + """Log that batch are ready to be finalized. + + In our current implementation, we ignore the stream name and number of batches. 
+ We just use this as a signal that we're finished reading and have begun to + finalize any accumulated batches. + """ + _ = stream_name, num_batches # unused for now + if self.finalize_start_time is None: + self.read_end_time = time.time() + self.finalize_start_time = self.read_end_time + + self.update_display(force_refresh=True) + + def log_batches_finalized(self, stream_name: str, num_batches: int) -> None: + """Log that a batch has been finalized.""" + _ = stream_name # unused for now + self.total_batches_finalized += num_batches + self.update_display(force_refresh=True) + + def log_stream_finalized(self, stream_name: str) -> None: + """Log that a stream has been finalized.""" + self.finalized_stream_names.add(stream_name) + if len(self.finalized_stream_names) == self.num_streams_expected: + self.log_success() + + self.update_display(force_refresh=True) + + def update_display(self, *, force_refresh: bool = False) -> None: + """Update the display.""" + # Don't update more than twice per second unless force_refresh is True. + if ( + not force_refresh + and self.last_update_time # if not set, then we definitely need to update + and cast(float, self.elapsed_seconds_since_last_update) < 0.5 # noqa: PLR2004 + ): + return + + status_message = self._get_status_message() + + if IS_NOTEBOOK: + # We're in a notebook so use the IPython display. + ipy_display.clear_output(wait=True) + ipy_display.display(ipy_display.Markdown(status_message)) + + elif self.rich_view is not None: + self.rich_view.update(RichMarkdown(status_message)) + + self.last_update_time = time.time() + + def _get_status_message(self) -> str: + """Compile and return a status message.""" + # Format start time as a friendly string in local timezone: + start_time_str = _to_time_str(self.read_start_time) + records_per_second: float = 0.0 + if self.elapsed_read_seconds > 0: + records_per_second = round(self.total_records_read / self.elapsed_read_seconds, 1) + status_message = ( + f"## Read Progress\n\n" + f"Started reading at {start_time_str}.\n\n" + f"Read **{self.total_records_read:,}** records " + f"over **{self.elapsed_read_time_string}** " + f"({records_per_second:,} records / second).\n\n" + ) + if self.total_records_written > 0: + status_message += ( + f"Wrote **{self.total_records_written:,}** records " + f"over {self.total_batches_written:,} batches.\n\n" + ) + if self.read_end_time is not None: + read_end_time_str = _to_time_str(self.read_end_time) + status_message += f"Finished reading at {read_end_time_str}.\n\n" + if self.finalize_start_time is not None: + finalize_start_time_str = _to_time_str(self.finalize_start_time) + status_message += f"Started finalizing streams at {finalize_start_time_str}.\n\n" + status_message += ( + f"Finalized **{self.total_batches_finalized}** batches " + f"over {self.elapsed_finalization_time_str}.\n\n" + ) + if self.finalized_stream_names: + status_message += ( + f"Completed {len(self.finalized_stream_names)} " + + (f"out of {self.num_streams_expected} " if self.num_streams_expected else "") + + "streams:\n\n" + ) + for stream_name in self.finalized_stream_names: + status_message += f" - {stream_name}\n" + + status_message += "\n\n" + + if self.finalize_end_time is not None: + completion_time_str = _to_time_str(self.finalize_end_time) + status_message += ( + f"Completed writing at {completion_time_str}. 
" + f"Total time elapsed: {self.elapsed_time_string}\n\n" + ) + status_message += "\n------------------------------------------------\n" + + return status_message + + +progress = ReadProgress() diff --git a/airbyte-lib/airbyte_lib/source.py b/airbyte-lib/airbyte_lib/source.py index 4db25b3afae09..120d25adba05e 100644 --- a/airbyte-lib/airbyte_lib/source.py +++ b/airbyte-lib/airbyte_lib/source.py @@ -26,6 +26,7 @@ from airbyte_lib._factories.cache_factories import get_default_cache from airbyte_lib._util import protocol_util # Internal utility functions from airbyte_lib.datasets._lazy import LazyDataset +from airbyte_lib.progress import progress from airbyte_lib.results import ReadResult from airbyte_lib.telemetry import ( CacheTelemetryInfo, @@ -76,7 +77,6 @@ def __init__( If config is provided, it will be validated against the spec if validate is True. """ - self._processed_records = 0 self.executor = executor self.name = name self._processed_records = 0 @@ -408,9 +408,12 @@ def _tally_records( ) -> Generator[AirbyteRecordMessage, Any, None]: """This method simply tallies the number of records processed and yields the messages.""" self._processed_records = 0 # Reset the counter before we start + progress.reset(len(self._selected_stream_names or [])) + for message in messages: self._processed_records += 1 yield message + progress.log_records_read(self._processed_records) def read(self, cache: SQLCacheBase | None = None) -> ReadResult: if cache is None: diff --git a/airbyte-lib/docs/generated/airbyte_lib/caches.html b/airbyte-lib/docs/generated/airbyte_lib/caches.html index bbebc9ac9f467..b39a5230e6582 100644 --- a/airbyte-lib/docs/generated/airbyte_lib/caches.html +++ b/airbyte-lib/docs/generated/airbyte_lib/caches.html @@ -679,7 +679,7 @@
Inherited Members

For now, only one source at a time is supported. If this method is called multiple times, the last call will overwrite the previous one.

-TODO: Expand this to handle mutliple sources.

+TODO: Expand this to handle multiple sources.

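The adaptive throttling in ReadProgress.log_records_read() above is the least obvious piece of progress.py: the display refreshes on every record at first, then less and less often as the total grows. The following is a sketch for explanation only, not part of the patch; it copies the formula and the MAX_UPDATE_FREQUENCY cap from the code above, with example counts chosen arbitrarily.

# Illustration of the refresh-throttling formula in ReadProgress.log_records_read().
# Not part of the patch; it restates min(MAX_UPDATE_FREQUENCY, 10 ** floor(log10(n) / 4)).
import math

MAX_UPDATE_FREQUENCY = 1000

def update_period(total_records_read: int) -> int:
    """Return how many records should pass between display refreshes."""
    return min(
        MAX_UPDATE_FREQUENCY,
        10 ** math.floor(math.log10(total_records_read) / 4),
    )

assert update_period(1) == 1              # refresh on every record at first
assert update_period(9_999) == 1          # still every record below 10,000 total
assert update_period(20_000) == 10        # from 10,000 records, every 10th record
assert update_period(123_456_789) == 100  # from 100,000,000 records, every 100th
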
diff --git a/airbyte-lib/examples/run_faker.py b/airbyte-lib/examples/run_faker.py new file mode 100644 index 0000000000000..3583a370000d0 --- /dev/null +++ b/airbyte-lib/examples/run_faker.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A simple test of AirbyteLib, using the Faker source connector. + +Usage (from airbyte-lib root directory): +> poetry run python ./examples/run_faker.py + +No setup is needed, but you may need to delete the .venv-source-faker folder +if your installation gets interrupted or corrupted. +""" +from __future__ import annotations + +import airbyte_lib as ab + + +SCALE = 1_000_000 # Number of records to generate between users and purchases. + + +source = ab.get_connector( + "source-faker", + pip_url="-e ../airbyte-integrations/connectors/source-faker", + config={"count": SCALE / 2}, + install_if_missing=True, +) +source.check() +source.set_streams(["products", "users", "purchases"]) + +cache = ab.new_local_cache() +result = source.read(cache) + +for name, records in result.cache.streams.items(): + print(f"Stream {name}: {len(list(records))} records") diff --git a/airbyte-lib/poetry.lock b/airbyte-lib/poetry.lock index 5dafb19e798a1..5ac62115d29d8 100644 --- a/airbyte-lib/poetry.lock +++ b/airbyte-lib/poetry.lock @@ -525,6 +525,20 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "freezegun" +version = "1.4.0" +description = "Let your Python tests travel through time" +optional = false +python-versions = ">=3.7" +files = [ + {file = "freezegun-1.4.0-py3-none-any.whl", hash = "sha256:55e0fc3c84ebf0a96a5aa23ff8b53d70246479e9a68863f1fcac5a3e52f19dd6"}, + {file = "freezegun-1.4.0.tar.gz", hash = "sha256:10939b0ba0ff5adaecf3b06a5c2f73071d9678e507c5eaedb23c761d56ac774b"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "genson" version = "1.2.2" @@ -877,6 +891,30 @@ six = ">=1.11.0" format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors"] format-nongpl = ["idna", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "webcolors"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "2.1.4" @@ -946,6 +984,17 @@ files = [ {file = "MarkupSafe-2.1.4.tar.gz", hash = "sha256:3aae9af4cac263007fd6309c64c6ab4506dd2b79382d9d19a1994f9240b8db4f"}, ] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mypy" version = "1.8.0" @@ -1915,6 +1964,24 @@ redis = ["redis (>=3)"] security = ["itsdangerous (>=2.0)"] yaml = ["pyyaml (>=5.4)"] +[[package]] +name = "rich" +version = "13.7.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"}, + {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.17.1" @@ -2494,4 +2561,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "5eba75179b62f56be141db82121ca9e1c623944306172d7bdacaf5388e6f3384" +content-hash = "7ed61ca7eaed73dbd7e0800aa87fcd5b2583739048d53e9e564fe2a6defa483f" diff --git a/airbyte-lib/pyproject.toml b/airbyte-lib/pyproject.toml index a4749bbd750bb..4625e688dc411 100644 --- a/airbyte-lib/pyproject.toml +++ b/airbyte-lib/pyproject.toml @@ -28,6 +28,7 @@ pyarrow = "^14.0.2" # Psycopg3 is not supported in SQLAlchemy 1.x: # psycopg = {extras = ["binary", "pool"], version = "^3.1.16"} +rich = "^13.7.0" [tool.poetry.group.dev.dependencies] @@ -44,6 +45,7 @@ ruff = "^0.1.11" types-jsonschema = "^4.20.0.0" google-cloud-secret-manager = "^2.17.0" types-requests = "2.31.0.4" +freezegun = "^1.4.0" [build-system] requires = ["poetry-core"] diff --git a/airbyte-lib/tests/conftest.py b/airbyte-lib/tests/conftest.py index caabe34cfed53..17009133a89d2 100644 --- a/airbyte-lib/tests/conftest.py +++ b/airbyte-lib/tests/conftest.py @@ -12,6 +12,7 @@ import docker import psycopg2 as psycopg import pytest +from _pytest.nodes import Item from google.cloud import secretmanager from pytest_docker.plugin import get_docker_ip @@ -25,6 
+26,32 @@ PYTEST_POSTGRES_PORT = 5432 +def pytest_collection_modifyitems(items: list[Item]) -> None: + """Override default pytest behavior, sorting our tests in a sensible execution order. + + In general, we want faster tests to run first, so that we can get feedback faster. + + Running lint tests first is helpful because they are fast and can catch typos and other errors. + + Otherwise tests are run based on an alpha-based natural sort, where 'unit' tests run after + 'integration' tests because 'u' comes after 'i' alphabetically. + """ + def test_priority(item: Item) -> int: + if 'lint_tests' in str(item.fspath): + return 1 # lint tests have high priority + elif 'unit_tests' in str(item.fspath): + return 2 # unit tests have highest priority + elif 'docs_tests' in str(item.fspath): + return 3 # doc tests have medium priority + elif 'integration_tests' in str(item.fspath): + return 4 # integration tests have the lowest priority + else: + return 5 # all other tests have lower priority + + # Sort the items list in-place based on the test_priority function + items.sort(key=test_priority) + + def is_port_in_use(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", port)) == 0 @@ -128,6 +155,7 @@ def new_pg_cache_config(pg_dsn): ) yield config + @pytest.fixture def snowflake_config(): if "GCP_GSM_CREDENTIALS" not in os.environ: diff --git a/airbyte-lib/tests/unit_tests/test_progress.py b/airbyte-lib/tests/unit_tests/test_progress.py new file mode 100644 index 0000000000000..377df860bb57a --- /dev/null +++ b/airbyte-lib/tests/unit_tests/test_progress.py @@ -0,0 +1,174 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +import datetime +from textwrap import dedent +import time +import pytest +from freezegun import freeze_time +from airbyte_lib.progress import ReadProgress, _get_elapsed_time_str, _to_time_str +from dateutil.tz import tzlocal + +# Calculate the offset from UTC in hours +tz_offset_hrs = int(datetime.datetime.now(tzlocal()).utcoffset().total_seconds() / 3600) + + +@freeze_time("2022-01-01") +def test_read_progress_initialization(): + progress = ReadProgress() + assert progress.num_streams_expected == 0 + assert progress.read_start_time == 1640995200.0 # Unix timestamp for 2022-01-01 + assert progress.total_records_read == 0 + assert progress.total_records_written == 0 + assert progress.total_batches_written == 0 + assert progress.written_stream_names == set() + assert progress.finalize_start_time is None + assert progress.finalize_end_time is None + assert progress.total_records_finalized == 0 + assert progress.total_batches_finalized == 0 + assert progress.finalized_stream_names == set() + assert progress.last_update_time is None + + +@freeze_time("2022-01-01") +def test_read_progress_reset(): + progress = ReadProgress() + progress.reset(5) + assert progress.num_streams_expected == 5 + assert progress.read_start_time == 1640995200.0 + assert progress.total_records_read == 0 + assert progress.total_records_written == 0 + assert progress.total_batches_written == 0 + assert progress.written_stream_names == set() + assert progress.finalize_start_time is None + assert progress.finalize_end_time is None + assert progress.total_records_finalized == 0 + assert progress.total_batches_finalized == 0 + assert progress.finalized_stream_names == set() + +@freeze_time("2022-01-01") +def test_read_progress_log_records_read(): + progress = ReadProgress() + progress.log_records_read(100) + assert progress.total_records_read == 100 + 
+@freeze_time("2022-01-01") +def test_read_progress_log_batch_written(): + progress = ReadProgress() + progress.log_batch_written("stream1", 50) + assert progress.total_records_written == 50 + assert progress.total_batches_written == 1 + assert progress.written_stream_names == {"stream1"} + +@freeze_time("2022-01-01") +def test_read_progress_log_batches_finalizing(): + progress = ReadProgress() + progress.log_batches_finalizing("stream1", 1) + assert progress.finalize_start_time == 1640995200.0 + +@freeze_time("2022-01-01") +def test_read_progress_log_batches_finalized(): + progress = ReadProgress() + progress.log_batches_finalized("stream1", 1) + assert progress.total_batches_finalized == 1 + +@freeze_time("2022-01-01") +def test_read_progress_log_stream_finalized(): + progress = ReadProgress() + progress.log_stream_finalized("stream1") + assert progress.finalized_stream_names == {"stream1"} + + +def test_get_elapsed_time_str(): + assert _get_elapsed_time_str(30) == "30 seconds" + assert _get_elapsed_time_str(90) == "1min 30s" + assert _get_elapsed_time_str(600) == "10min" + assert _get_elapsed_time_str(3600) == "1hr 0min" + + +@freeze_time("2022-01-01 0:00:00") +def test_get_time_str(): + assert _to_time_str(time.time()) == "00:00:00" + + +def _assert_lines(expected_lines, actual_lines: list[str] | str): + if isinstance(actual_lines, list): + actual_lines = "\n".join(actual_lines) + for line in expected_lines: + assert line in actual_lines, f"Missing line: {line}" + +def test_get_status_message_after_finalizing_records(): + + # Test that we can render the initial status message before starting to read + with freeze_time("2022-01-01 00:00:00"): + progress = ReadProgress() + expected_lines = [ + "Started reading at 00:00:00.", + "Read **0** records over **0 seconds** (0.0 records / second).", + ] + _assert_lines(expected_lines, progress._get_status_message()) + + # Test after reading some records + with freeze_time("2022-01-01 00:01:00"): + progress.log_records_read(100) + expected_lines = [ + "Started reading at 00:00:00.", + "Read **100** records over **60 seconds** (1.7 records / second).", + ] + _assert_lines(expected_lines, progress._get_status_message()) + + # Advance the day and reset the progress + with freeze_time("2022-01-02 00:00:00"): + progress = ReadProgress() + progress.reset(1) + expected_lines = [ + "Started reading at 00:00:00.", + "Read **0** records over **0 seconds** (0.0 records / second).", + ] + _assert_lines(expected_lines, progress._get_status_message()) + + # Test after writing some records and starting to finalize + with freeze_time("2022-01-02 00:01:00"): + progress.log_records_read(100) + progress.log_batch_written("stream1", 50) + progress.log_batches_finalizing("stream1", 1) + expected_lines = [ + "## Read Progress", + "Started reading at 00:00:00.", + "Read **100** records over **60 seconds** (1.7 records / second).", + "Wrote **50** records over 1 batches.", + "Finished reading at 00:01:00.", + "Started finalizing streams at 00:01:00.", + ] + _assert_lines(expected_lines, progress._get_status_message()) + + # Test after finalizing some records + with freeze_time("2022-01-02 00:02:00"): + progress.log_batches_finalized("stream1", 1) + expected_lines = [ + "## Read Progress", + "Started reading at 00:00:00.", + "Read **100** records over **60 seconds** (1.7 records / second).", + "Wrote **50** records over 1 batches.", + "Finished reading at 00:01:00.", + "Started finalizing streams at 00:01:00.", + "Finalized **1** batches over 60 seconds.", + ] + 
_assert_lines(expected_lines, progress._get_status_message()) + + # Test after finalizing all records + with freeze_time("2022-01-02 00:02:00"): + progress.log_stream_finalized("stream1") + message = progress._get_status_message() + expected_lines = [ + "## Read Progress", + "Started reading at 00:00:00.", + "Read **100** records over **60 seconds** (1.7 records / second).", + "Wrote **50** records over 1 batches.", + "Finished reading at 00:01:00.", + "Started finalizing streams at 00:01:00.", + "Finalized **1** batches over 60 seconds.", + "Completed 1 out of 1 streams:", + "- stream1", + "Total time elapsed: 2min 0s", + ] + _assert_lines(expected_lines, message)
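
Taken together, the hooks added to source.py and _processors.py drive the shared progress singleton through a read / write / finalize lifecycle, and the tests above exercise the same methods directly. A rough end-to-end sketch of that call sequence follows; it is illustrative only, and the stream name and record counts are invented.

# Illustrative sketch of the call sequence this patch wires up; not part of the patch.
from airbyte_lib.progress import progress

progress.reset(num_streams_expected=1)        # Source._tally_records() resets the tracker
for running_total in range(1, 501):
    progress.log_records_read(running_total)  # called with the running total per record read

progress.log_batch_written("users", 500)      # after a record batch is written to the cache
progress.log_batches_finalizing("users", 1)   # marks the end of reading / start of finalization
progress.log_batches_finalized("users", 1)    # after the cache finalizes the pending batch
progress.log_stream_finalized("users")        # last expected stream, so log_success() fires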