|
28 | 28 | AirbyteLogMessage,
|
29 | 29 | AirbyteMessage,
|
30 | 30 | AirbyteRecordMessage,
|
| 31 | + AirbyteStateMessage, |
31 | 32 | AirbyteStream,
|
| 33 | + AirbyteStreamState, |
32 | 34 | ConfiguredAirbyteCatalog,
|
33 | 35 | ConfiguredAirbyteStream,
|
34 | 36 | ConnectorSpecification,
|
35 | 37 | DestinationSyncMode,
|
36 | 38 | Level,
|
| 39 | + StreamDescriptor, |
37 | 40 | SyncMode,
|
38 | 41 | )
|
39 | 42 | from airbyte_cdk.models import Type
|
|
50 | 53 | _stream_options = {"name": _stream_name, "primary_key": _stream_primary_key, "url_base": _stream_url_base}
|
51 | 54 | _page_size = 2
|
52 | 55 |
|
| 56 | +_A_STATE = [AirbyteStateMessage( |
| 57 | + type="STREAM", |
| 58 | + stream=AirbyteStreamState( |
| 59 | + stream_descriptor=StreamDescriptor( |
| 60 | + name=_stream_name |
| 61 | + ), |
| 62 | + stream_state={ |
| 63 | + "key": "value" |
| 64 | + } |
| 65 | + ) |
| 66 | +)] |
| 67 | + |
53 | 68 | MANIFEST = {
|
54 | 69 | "version": "0.30.3",
|
55 | 70 | "definitions": {
|
@@ -266,7 +281,7 @@ def test_resolve_manifest(valid_resolve_manifest_config_file):
|
266 | 281 | config["__command"] = command
|
267 | 282 | source = ManifestDeclarativeSource(MANIFEST)
|
268 | 283 | limits = TestReadLimits()
|
269 |
| - resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), limits) |
| 284 | + resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), _A_STATE, limits) |
270 | 285 |
|
271 | 286 | expected_resolved_manifest = {
|
272 | 287 | "type": "DeclarativeSource",
|
@@ -455,10 +470,11 @@ def test_read():
|
455 | 470 | ),
|
456 | 471 | )
|
457 | 472 | limits = TestReadLimits()
|
458 |
| - with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read): |
| 473 | + with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read) as mock: |
459 | 474 | output_record = handle_connector_builder_request(
|
460 |
| - source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits |
| 475 | + source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits |
461 | 476 | )
|
| 477 | + mock.assert_called_with(source, config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits.max_records) |
462 | 478 | output_record.record.emitted_at = 1
|
463 | 479 | assert output_record == expected_airbyte_message
|
464 | 480 |
|
@@ -492,7 +508,7 @@ def test_config_update():
|
492 | 508 | return_value=refresh_request_response,
|
493 | 509 | ):
|
494 | 510 | output = handle_connector_builder_request(
|
495 |
| - source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), TestReadLimits() |
| 511 | + source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, TestReadLimits() |
496 | 512 | )
|
497 | 513 | assert output.record.data["latest_config_update"]
|
498 | 514 |
|
@@ -529,7 +545,7 @@ def check_config_against_spec(self):
|
529 | 545 |
|
530 | 546 | source = MockManifestDeclarativeSource()
|
531 | 547 | limits = TestReadLimits()
|
532 |
| - response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits) |
| 548 | + response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits) |
533 | 549 |
|
534 | 550 | expected_stream_read = StreamRead(
|
535 | 551 | logs=[LogMessage("error_message - a stack trace", "ERROR")],
|
@@ -716,7 +732,7 @@ def test_read_source(mock_http_stream):
|
716 | 732 |
|
717 | 733 | source = create_source(config, limits)
|
718 | 734 |
|
719 |
| - output_data = read_stream(source, config, catalog, limits).record.data |
| 735 | + output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data |
720 | 736 | slices = output_data["slices"]
|
721 | 737 |
|
722 | 738 | assert len(slices) == max_slices
|
@@ -761,7 +777,7 @@ def test_read_source_single_page_single_slice(mock_http_stream):
|
761 | 777 |
|
762 | 778 | source = create_source(config, limits)
|
763 | 779 |
|
764 |
| - output_data = read_stream(source, config, catalog, limits).record.data |
| 780 | + output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data |
765 | 781 | slices = output_data["slices"]
|
766 | 782 |
|
767 | 783 | assert len(slices) == max_slices
|
@@ -817,7 +833,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
|
817 | 833 | source = create_source(config, limits)
|
818 | 834 |
|
819 | 835 | with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
|
820 |
| - output_data = read_stream(source, config, catalog, limits).record.data |
| 836 | + output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data |
821 | 837 | if expected_error:
|
822 | 838 | assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
|
823 | 839 | error_message = output_data["logs"][0]
|
@@ -875,7 +891,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected
|
875 | 891 | source = create_source(config, limits)
|
876 | 892 |
|
877 | 893 | with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
|
878 |
| - output_data = read_stream(source, config, catalog, limits).record.data |
| 894 | + output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data |
879 | 895 | if expected_error:
|
880 | 896 | assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
|
881 | 897 | error_message = output_data["logs"][0]
|
|
0 commit comments