3
3
#
4
4
import functools
5
5
from abc import ABC , abstractmethod
6
- from datetime import datetime
7
- from typing import Any , List , Mapping , MutableMapping , Optional , Protocol , Tuple
6
+ from typing import Any , Callable , Iterable , List , Mapping , MutableMapping , Optional , Protocol , Tuple
8
7
9
8
from airbyte_cdk .sources .connector_state_manager import ConnectorStateManager
10
9
from airbyte_cdk .sources .message import MessageRepository
@@ -18,19 +17,41 @@ def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
18
17
return functools .reduce (lambda a , b : a [b ], path , mapping )
19
18
20
19
21
- class Comparable (Protocol ):
20
+ class GapType (Protocol ):
21
+ """
22
+ This is the representation of gaps between two cursor values. Examples:
23
+ * if cursor values are datetimes, GapType is timedelta
24
+ * if cursor values are integer, GapType will also be integer
25
+ """
26
+
27
+ pass
28
+
29
+
30
+ class CursorValueType (Protocol ):
22
31
"""Protocol for annotating comparable types."""
23
32
24
33
@abstractmethod
25
- def __lt__ (self : "Comparable" , other : "Comparable" ) -> bool :
34
+ def __lt__ (self : "CursorValueType" , other : "CursorValueType" ) -> bool :
35
+ pass
36
+
37
+ @abstractmethod
38
+ def __ge__ (self : "CursorValueType" , other : "CursorValueType" ) -> bool :
39
+ pass
40
+
41
+ @abstractmethod
42
+ def __add__ (self : "CursorValueType" , other : GapType ) -> "CursorValueType" :
43
+ pass
44
+
45
+ @abstractmethod
46
+ def __sub__ (self : "CursorValueType" , other : GapType ) -> "CursorValueType" :
26
47
pass
27
48
28
49
29
50
class CursorField :
30
51
def __init__ (self , cursor_field_key : str ) -> None :
31
52
self .cursor_field_key = cursor_field_key
32
53
33
- def extract_value (self , record : Record ) -> Comparable :
54
+ def extract_value (self , record : Record ) -> CursorValueType :
34
55
cursor_value = record .data .get (self .cursor_field_key )
35
56
if cursor_value is None :
36
57
raise ValueError (f"Could not find cursor field { self .cursor_field_key } in record" )
@@ -118,7 +139,10 @@ def __init__(
118
139
connector_state_converter : AbstractStreamStateConverter ,
119
140
cursor_field : CursorField ,
120
141
slice_boundary_fields : Optional [Tuple [str , str ]],
121
- start : Optional [Any ],
142
+ start : Optional [CursorValueType ],
143
+ end_provider : Callable [[], CursorValueType ],
144
+ lookback_window : Optional [GapType ] = None ,
145
+ slice_range : Optional [GapType ] = None ,
122
146
) -> None :
123
147
self ._stream_name = stream_name
124
148
self ._stream_namespace = stream_namespace
@@ -129,15 +153,18 @@ def __init__(
129
153
# To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
130
154
self ._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple ()
131
155
self ._start = start
156
+ self ._end_provider = end_provider
132
157
self ._most_recent_record : Optional [Record ] = None
133
158
self ._has_closed_at_least_one_slice = False
134
159
self .start , self ._concurrent_state = self ._get_concurrent_state (stream_state )
160
+ self ._lookback_window = lookback_window
161
+ self ._slice_range = slice_range
135
162
136
163
@property
137
164
def state (self ) -> MutableMapping [str , Any ]:
138
165
return self ._concurrent_state
139
166
140
- def _get_concurrent_state (self , state : MutableMapping [str , Any ]) -> Tuple [datetime , MutableMapping [str , Any ]]:
167
+ def _get_concurrent_state (self , state : MutableMapping [str , Any ]) -> Tuple [CursorValueType , MutableMapping [str , Any ]]:
141
168
if self ._connector_state_converter .is_state_message_compatible (state ):
142
169
return self ._start or self ._connector_state_converter .zero_value , self ._connector_state_converter .deserialize (state )
143
170
return self ._connector_state_converter .convert_from_sequential_state (self ._cursor_field , state , self ._start )
@@ -203,23 +230,20 @@ def _emit_state_message(self) -> None:
203
230
self ._connector_state_manager .update_state_for_stream (
204
231
self ._stream_name ,
205
232
self ._stream_namespace ,
206
- self ._connector_state_converter .convert_to_sequential_state (self ._cursor_field , self .state ),
233
+ self ._connector_state_converter .convert_to_state_message (self ._cursor_field , self .state ),
207
234
)
208
- # TODO: if we migrate stored state to the concurrent state format
209
- # (aka stop calling self._connector_state_converter.convert_to_sequential_state`), we'll need to cast datetimes to string or
210
- # int before emitting state
211
235
state_message = self ._connector_state_manager .create_state_message (self ._stream_name , self ._stream_namespace )
212
236
self ._message_repository .emit_message (state_message )
213
237
214
238
def _merge_partitions (self ) -> None :
215
239
self .state ["slices" ] = self ._connector_state_converter .merge_intervals (self .state ["slices" ])
216
240
217
- def _extract_from_slice (self , partition : Partition , key : str ) -> Comparable :
241
+ def _extract_from_slice (self , partition : Partition , key : str ) -> CursorValueType :
218
242
try :
219
243
_slice = partition .to_slice ()
220
244
if not _slice :
221
245
raise KeyError (f"Could not find key `{ key } ` in empty slice" )
222
- return self ._connector_state_converter .parse_value (_slice [key ]) # type: ignore # we expect the devs to specify a key that would return a Comparable
246
+ return self ._connector_state_converter .parse_value (_slice [key ]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType
223
247
except KeyError as exception :
224
248
raise KeyError (f"Partition is expected to have key `{ key } ` but could not be found" ) from exception
225
249
@@ -229,3 +253,66 @@ def ensure_at_least_one_state_emitted(self) -> None:
229
253
called.
230
254
"""
231
255
self ._emit_state_message ()
256
+
257
+ def generate_slices (self ) -> Iterable [Tuple [CursorValueType , CursorValueType ]]:
258
+ """
259
+ Generating slices based on a few parameters:
260
+ * lookback_window: Buffer to remove from END_KEY of the highest slice
261
+ * slice_range: Max difference between two slices. If the difference between two slices is greater, multiple slices will be created
262
+ * start: `_split_per_slice_range` will clip any value to `self._start which means that:
263
+ * if upper is less than self._start, no slices will be generated
264
+ * if lower is less than self._start, self._start will be used as the lower boundary (lookback_window will not be considered in that case)
265
+
266
+ Note that the slices will overlap at their boundaries. We therefore expect to have at least the lower or the upper boundary to be
267
+ inclusive in the API that is queried.
268
+ """
269
+ self ._merge_partitions ()
270
+
271
+ if self ._start is not None and self ._is_start_before_first_slice ():
272
+ yield from self ._split_per_slice_range (self ._start , self .state ["slices" ][0 ][self ._connector_state_converter .START_KEY ])
273
+
274
+ if len (self .state ["slices" ]) == 1 :
275
+ yield from self ._split_per_slice_range (
276
+ self ._calculate_lower_boundary_of_last_slice (self .state ["slices" ][0 ][self ._connector_state_converter .END_KEY ]),
277
+ self ._end_provider (),
278
+ )
279
+ elif len (self .state ["slices" ]) > 1 :
280
+ for i in range (len (self .state ["slices" ]) - 1 ):
281
+ yield from self ._split_per_slice_range (
282
+ self .state ["slices" ][i ][self ._connector_state_converter .END_KEY ],
283
+ self .state ["slices" ][i + 1 ][self ._connector_state_converter .START_KEY ],
284
+ )
285
+ yield from self ._split_per_slice_range (
286
+ self ._calculate_lower_boundary_of_last_slice (self .state ["slices" ][- 1 ][self ._connector_state_converter .END_KEY ]),
287
+ self ._end_provider (),
288
+ )
289
+ else :
290
+ raise ValueError ("Expected at least one slice" )
291
+
292
+ def _is_start_before_first_slice (self ) -> bool :
293
+ return self ._start is not None and self ._start < self .state ["slices" ][0 ][self ._connector_state_converter .START_KEY ]
294
+
295
+ def _calculate_lower_boundary_of_last_slice (self , lower_boundary : CursorValueType ) -> CursorValueType :
296
+ if self ._lookback_window :
297
+ return lower_boundary - self ._lookback_window
298
+ return lower_boundary
299
+
300
+ def _split_per_slice_range (self , lower : CursorValueType , upper : CursorValueType ) -> Iterable [Tuple [CursorValueType , CursorValueType ]]:
301
+ if lower >= upper :
302
+ return
303
+
304
+ if self ._start and upper < self ._start :
305
+ return
306
+
307
+ lower = max (lower , self ._start ) if self ._start else lower
308
+ if not self ._slice_range or lower + self ._slice_range >= upper :
309
+ yield lower , upper
310
+ else :
311
+ stop_processing = False
312
+ current_lower_boundary = lower
313
+ while not stop_processing :
314
+ current_upper_boundary = min (current_lower_boundary + self ._slice_range , upper )
315
+ yield current_lower_boundary , current_upper_boundary
316
+ current_lower_boundary = current_upper_boundary
317
+ if current_upper_boundary >= upper :
318
+ stop_processing = True
0 commit comments