15
15
AirbyteLogMessage ,
16
16
AirbyteMessage ,
17
17
AirbyteRecordMessage ,
18
+ AirbyteStateMessage ,
19
+ AirbyteStreamState ,
18
20
Level ,
19
21
OrchestratorType ,
22
+ StreamDescriptor ,
20
23
)
21
24
from airbyte_cdk .models import Type as MessageType
22
25
from unit_tests .connector_builder .utils import create_configured_catalog
@@ -470,6 +473,7 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No
470
473
request_response_log_message (request , response , url ),
471
474
record_message ("hashiras" , {"name" : "Obanai Iguro" }),
472
475
request_response_log_message (request , response , url ),
476
+ state_message ("hashiras" , {"a_timestamp" : 123 }),
473
477
]
474
478
),
475
479
)
@@ -486,13 +490,16 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No
486
490
assert stream_read .slices [0 ].slice_descriptor == {"descriptor" : "first_slice" }
487
491
assert len (stream_read .slices [0 ].pages ) == 1
488
492
assert len (stream_read .slices [0 ].pages [0 ].records ) == 1
493
+ assert stream_read .slices [0 ].state is None
489
494
490
495
assert stream_read .slices [1 ].slice_descriptor == {"descriptor" : "second_slice" }
491
496
assert len (stream_read .slices [1 ].pages ) == 3
492
497
assert len (stream_read .slices [1 ].pages [0 ].records ) == 2
493
498
assert len (stream_read .slices [1 ].pages [1 ].records ) == 1
494
499
assert len (stream_read .slices [1 ].pages [2 ].records ) == 0
495
500
501
+ assert stream_read .slices [1 ].state .stream .stream_state == {"a_timestamp" : 123 }
502
+
496
503
497
504
@patch ("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read" )
498
505
def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limit_reached (mock_entrypoint_read : Mock ) -> None :
@@ -698,6 +705,13 @@ def record_message(stream: str, data: Mapping[str, Any]) -> AirbyteMessage:
698
705
return AirbyteMessage (type = MessageType .RECORD , record = AirbyteRecordMessage (stream = stream , data = data , emitted_at = 1234 ))
699
706
700
707
708
+ def state_message (stream : str , data : Mapping [str , Any ]) -> AirbyteMessage :
709
+ return AirbyteMessage (type = MessageType .STATE , state = AirbyteStateMessage (stream = AirbyteStreamState (
710
+ stream_descriptor = StreamDescriptor (name = stream ),
711
+ stream_state = data
712
+ )))
713
+
714
+
701
715
def slice_message (slice_descriptor : str = '{"key": "value"}' ) -> AirbyteMessage :
702
716
return AirbyteMessage (type = MessageType .LOG , log = AirbyteLogMessage (level = Level .INFO , message = "slice:" + slice_descriptor ))
703
717
0 commit comments