|
8 | 8 |
|
9 | 9 | from airbyte_cdk.models import AirbyteMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor
|
10 | 10 | from airbyte_cdk.models import Type as MessageType
|
| 11 | +from airbyte_protocol_dataclasses.models import * |
11 | 12 |
|
12 | 13 |
|
13 | 14 | @dataclass(frozen=True)
|
@@ -95,35 +96,34 @@ def _extract_from_state_message(
|
95 | 96 | :param state: The incoming state input
|
96 | 97 | :return: A tuple of shared state and per stream state assembled from the incoming state list
|
97 | 98 | """
|
98 |
| - if state is None: |
| 99 | + if not state: |
99 | 100 | return None, {}
|
100 | 101 |
|
101 |
| - is_global = cls._is_global_state(state) |
102 |
| - |
103 |
| - if is_global: |
104 |
| - global_state = state[0].global_ # type: ignore # We verified state is a list in _is_global_state |
105 |
| - shared_state = copy.deepcopy(global_state.shared_state, {}) # type: ignore[union-attr] # global_state has shared_state |
| 102 | + if cls._is_global_state(state): |
| 103 | + global_state = state[0].global_ |
| 104 | + shared_state = global_state.shared_state # type: ignore[union-attr] |
106 | 105 | streams = {
|
107 | 106 | HashableStreamDescriptor(
|
108 | 107 | name=per_stream_state.stream_descriptor.name, namespace=per_stream_state.stream_descriptor.namespace
|
109 | 108 | ): per_stream_state.stream_state
|
110 |
| - for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state |
| 109 | + for per_stream_state in global_state.stream_states # type: ignore[union-attr] |
111 | 110 | }
|
112 | 111 | return shared_state, streams
|
113 |
| - else: |
114 |
| - streams = { |
115 |
| - HashableStreamDescriptor( |
116 |
| - name=per_stream_state.stream.stream_descriptor.name, namespace=per_stream_state.stream.stream_descriptor.namespace # type: ignore[union-attr] # stream has stream_descriptor |
117 |
| - ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state |
118 |
| - for per_stream_state in state |
119 |
| - if per_stream_state.type == AirbyteStateType.STREAM and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True |
120 |
| - } |
121 |
| - return None, streams |
| 112 | + |
| 113 | + streams = { |
| 114 | + HashableStreamDescriptor( |
| 115 | + name=per_stream_state.stream.stream_descriptor.name, |
| 116 | + namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] |
| 117 | + ): per_stream_state.stream.stream_state # type: ignore[union-attr] |
| 118 | + for per_stream_state in state |
| 119 | + if per_stream_state.type == AirbyteStateType.STREAM |
| 120 | + } |
| 121 | + return None, streams |
122 | 122 |
|
123 | 123 | @staticmethod
|
124 | 124 | def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
|
125 | 125 | return (
|
126 |
| - isinstance(state, List) |
| 126 | + isinstance(state, list) |
127 | 127 | and len(state) == 1
|
128 | 128 | and isinstance(state[0], AirbyteStateMessage)
|
129 | 129 | and state[0].type == AirbyteStateType.GLOBAL
|
|
0 commit comments