|
6 | 6 | from unittest.mock import Mock
|
7 | 7 |
|
8 | 8 | import pytest
|
9 |
| -from airbyte_cdk.models import FailureType |
10 | 9 | from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
|
11 | 10 | from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice
|
12 | 11 | from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
|
13 | 12 | from airbyte_cdk.sources.types import Record
|
14 |
| -from airbyte_cdk.utils import AirbyteTracedException |
15 | 13 |
|
16 | 14 | PARTITION = {
|
17 | 15 | "partition_key string": "partition value",
|
@@ -519,10 +517,37 @@ def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_pa
|
519 | 517 | assert stream_state == expected_state
|
520 | 518 |
|
521 | 519 |
|
522 |
| -def test_given_invalid_state_when_set_initial_state_then_raise_config_error(mocked_cursor_factory, mocked_partition_router) -> None: |
523 |
| - cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) |
524 |
| - |
525 |
| - with pytest.raises(AirbyteTracedException) as exception: |
526 |
| - cursor.set_initial_state({"invalid_state": 1}) |
| 520 | +def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory, mocked_partition_router) -> None: |
| 521 | + first_partition = {"first_partition_key": "first_partition_value"} |
| 522 | + second_partition = {"second_partition_key": "second_partition_value"} |
| 523 | + global_state = {"global_state_format_key": "global_state_format_value"} |
527 | 524 |
|
528 |
| - assert exception.value.failure_type == FailureType.config_error |
| 525 | + mocked_partition_router.stream_slices.return_value = [ |
| 526 | + StreamSlice(partition=first_partition, cursor_slice={}), |
| 527 | + StreamSlice(partition=second_partition, cursor_slice={}), |
| 528 | + ] |
| 529 | + mocked_cursor_factory.create.side_effect = [ |
| 530 | + MockedCursorBuilder().with_stream_state(global_state).build(), |
| 531 | + MockedCursorBuilder().with_stream_state(global_state).build(), |
| 532 | + ] |
| 533 | + cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) |
| 534 | + global_state = {"global_state_format_key": "global_state_format_value"} |
| 535 | + cursor.set_initial_state(global_state) |
| 536 | + assert cursor._state_to_migrate_from == global_state |
| 537 | + list(cursor.stream_slices()) |
| 538 | + assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_count == 1 |
| 539 | + assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_args[0] == ( |
| 540 | + {"global_state_format_key": "global_state_format_value"}, |
| 541 | + ) |
| 542 | + assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_count == 1 |
| 543 | + assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_args[0] == ( |
| 544 | + {"global_state_format_key": "global_state_format_value"}, |
| 545 | + ) |
| 546 | + expected_state = [ |
| 547 | + {"cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"first_partition_key": "first_partition_value"}}, |
| 548 | + { |
| 549 | + "cursor": {"global_state_format_key": "global_state_format_value"}, |
| 550 | + "partition": {"second_partition_key": "second_partition_value"}, |
| 551 | + }, |
| 552 | + ] |
| 553 | + assert cursor.get_stream_state()["states"] == expected_state |
0 commit comments