 import time
 from abc import ABC
 from contextlib import closing
-from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union
+from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

 import pandas as pd
 import pendulum
 import requests  # type: ignore[import]
 from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode
-from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
+from airbyte_cdk.sources.streams.core import Stream, StreamData
 from airbyte_cdk.sources.streams.http import HttpStream
 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 from numpy import nan
@@ -38,6 +38,7 @@ class SalesforceStream(HttpStream, ABC):
     page_size = 2000
     transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
     encoding = DEFAULT_ENCODING
+    MAX_PROPERTIES_LENGTH = Salesforce.REQUEST_SIZE_LIMITS - 2000

     def __init__(
         self, sf_api: Salesforce, pk: str, stream_name: str, sobject_options: Mapping[str, Any] = None, schema: dict = None, **kwargs
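A note on the budget this constant sets: `MAX_PROPERTIES_LENGTH` caps the combined character length of the SELECT-ed field names, keeping roughly 2000 characters of the overall request-size limit in reserve for the rest of the URL and query (path, FROM/WHERE/ORDER BY). A minimal sketch of the arithmetic; the `16_384` stand-in for `Salesforce.REQUEST_SIZE_LIMITS` is an assumption for illustration, the real constant lives in the connector's `api.py`:

    # Hypothetical stand-in for Salesforce.REQUEST_SIZE_LIMITS; the real value
    # is defined in the connector's api.py.
    REQUEST_SIZE_LIMITS = 16_384

    # Keep ~2000 characters in reserve for everything that is not the field list:
    # the endpoint path, "SELECT ... FROM <object>", WHERE/ORDER BY clauses, etc.
    MAX_PROPERTIES_LENGTH = REQUEST_SIZE_LIMITS - 2000

    field_list = ",".join(f"Custom_Field_{i}__c" for i in range(800))
    print(len(field_list) > MAX_PROPERTIES_LENGTH)  # True: this field list must be chunked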
@@ -65,6 +66,31 @@ def url_base(self) -> str:
     def availability_strategy(self) -> Optional["AvailabilityStrategy"]:
         return None

+    @property
+    def too_many_properties(self):
+        selected_properties = self.get_json_schema().get("properties", {})
+        properties_length = len(",".join(p for p in selected_properties))
+        return properties_length > self.MAX_PROPERTIES_LENGTH
+
+    def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
+        yield from response.json()["records"]
+
+    def get_json_schema(self) -> Mapping[str, Any]:
+        if not self.schema:
+            self.schema = self.sf_api.generate_schema(self.name)
+        return self.schema
+
+    def get_error_display_message(self, exception: BaseException) -> Optional[str]:
+        if isinstance(exception, exceptions.ConnectionError):
+            return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later."
+        return super().get_error_display_message(exception)
+
+
+class RestSalesforceStream(SalesforceStream):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.primary_key or not self.too_many_properties
+
     def path(self, next_page_token: Mapping[str, Any] = None, **kwargs: Any) -> str:
         if next_page_token:
             """
@@ -80,7 +106,11 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str,
         return {"next_token": next_token} if next_token else None

     def request_params(
-        self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
+        self,
+        stream_state: Mapping[str, Any],
+        stream_slice: Mapping[str, Any] = None,
+        next_page_token: Mapping[str, Any] = None,
+        property_chunk: Mapping[str, Any] = None,
     ) -> MutableMapping[str, Any]:
         """
         Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm
@@ -91,32 +121,44 @@ def request_params(
             """
             return {}

-        selected_properties = self.get_json_schema().get("properties", {})
-        query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
+        property_chunk = property_chunk or {}
+        query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} "

         if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
             query += f"ORDER BY {self.primary_key} ASC"

         return {"q": query}

-    def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
-        yield from response.json()["records"]
+    def chunk_properties(self) -> Iterable[Mapping[str, Any]]:
+        selected_properties = self.get_json_schema().get("properties", {})

-    def get_json_schema(self) -> Mapping[str, Any]:
-        if not self.schema:
-            self.schema = self.sf_api.generate_schema(self.name)
-        return self.schema
+        summary_length = 0
+        local_properties = {}
+        for property_name, value in selected_properties.items():
+            current_property_length = len(property_name) + 1  # properties are split with commas
+            if current_property_length + summary_length >= self.MAX_PROPERTIES_LENGTH:
+                yield local_properties
+                local_properties = {}
+                summary_length = 0
+
+            local_properties[property_name] = value
+            summary_length += current_property_length
+
+        if local_properties:
+            yield local_properties

     def read_records(
         self,
         sync_mode: SyncMode,
         cursor_field: List[str] = None,
         stream_slice: Mapping[str, Any] = None,
         stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Mapping[str, Any]]:
+    ) -> Iterable[StreamData]:
         try:
-            yield from super().read_records(
-                sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
+            yield from self._read_pages(
+                lambda req, res, state, _slice: self.parse_response(res, stream_slice=_slice, stream_state=state),
+                stream_slice,
+                stream_state,
             )
         except exceptions.HTTPError as error:
             """
@@ -135,10 +177,83 @@ def read_records(
                     return
             raise error

-    def get_error_display_message(self, exception: BaseException) -> Optional[str]:
-        if isinstance(exception, exceptions.ConnectionError):
-            return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later."
-        return super().get_error_display_message(exception)
+    def _read_pages(
+        self,
+        records_generator_fn: Callable[
+            [requests.PreparedRequest, requests.Response, Mapping[str, Any], Mapping[str, Any]], Iterable[StreamData]
+        ],
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[StreamData]:
+        stream_state = stream_state or {}
+        pagination_complete = False
+        records = {}
+        next_pages = {}
+
+        while not pagination_complete:
+            index = 0
+            for index, property_chunk in enumerate(self.chunk_properties()):
+                request, response = self._fetch_next_page(stream_slice, stream_state, next_pages.get(index), property_chunk)
+                next_pages[index] = self.next_page_token(response)
+                chunk_page_records = records_generator_fn(request, response, stream_state, stream_slice)
+                if not self.too_many_properties:
+                    # this is the case when a stream has no primary key
+                    # (is allowed when properties length does not exceed the maximum value)
+                    # so there would be a single iteration, therefore we may and should yield records immediately
+                    yield from chunk_page_records
+                    break
+                chunk_page_records = {record[self.primary_key]: record for record in chunk_page_records}
+
+                for record_id, record in chunk_page_records.items():
+                    if record_id not in records:
+                        records[record_id] = (record, 1)
+                        continue
+                    incomplete_record, counter = records[record_id]
+                    incomplete_record.update(record)
+                    counter += 1
+                    records[record_id] = (incomplete_record, counter)
+
+            for record_id, (record, counter) in records.items():
+                if counter != index + 1:
+                    # Because we make multiple calls to query N records (each call to fetch X properties of all the N records),
+                    # there's a chance that the number of records corresponding to the query may change between the calls. This
+                    # may result in data inconsistency. We skip such records for now and log a warning message.
+                    self.logger.warning(
+                        f"Inconsistent record with primary key {record_id} found. It consists of {counter} chunks instead of {index + 1}. "
+                        f"Skipping it."
+                    )
+                    continue
+                yield record
+
+            records = {}
+
+            if not any(next_pages.values()):
+                pagination_complete = True
+
+        # Always return an empty generator just in case no records were ever yielded
+        yield from []
+
+    def _fetch_next_page(
+        self,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+        next_page_token: Mapping[str, Any] = None,
+        property_chunk: Mapping[str, Any] = None,
+    ) -> Tuple[requests.PreparedRequest, requests.Response]:
+        request_headers = self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
+        request = self._create_prepared_request(
+            path=self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
+            headers=dict(request_headers, **self.authenticator.get_auth_header()),
+            params=self.request_params(
+                stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, property_chunk=property_chunk
+            ),
+            json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
+            data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
+        )
+        request_kwargs = self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
+
+        response = self._send_request(request, request_kwargs)
+        return request, response


 class BulkSalesforceStream(SalesforceStream):
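The subtle part of `_read_pages` is the merge: when properties are chunked, each request in a round fetches the same page of records with a different subset of fields, so partial records are stitched together by primary key, and a per-record counter drops any record that did not appear in every chunk (the underlying data may change between calls). A minimal sketch of just that stitching, decoupled from HTTP (record shapes are illustrative):

    from typing import Any, Dict, Iterable, List, Mapping

    def merge_chunked_pages(pages: List[List[Dict[str, Any]]], primary_key: str) -> Iterable[Mapping[str, Any]]:
        # `pages` holds one record list per property chunk, all describing the same page.
        records: Dict[Any, tuple] = {}
        index = 0
        for index, chunk_page_records in enumerate(pages):
            for record in chunk_page_records:
                record_id = record[primary_key]
                if record_id not in records:
                    records[record_id] = (dict(record), 1)
                    continue
                incomplete_record, counter = records[record_id]
                incomplete_record.update(record)
                records[record_id] = (incomplete_record, counter + 1)
        for record_id, (record, counter) in records.items():
            if counter != index + 1:
                continue  # missing from some chunk: inconsistent, skip (the stream logs a warning)
            yield record

    pages = [
        [{"Id": "001", "Name": "Acme"}, {"Id": "002", "Name": "Globex"}],
        [{"Id": "001", "AnnualRevenue": 1000}],  # "002" vanished between the two calls
    ]
    print(list(merge_chunked_pages(pages, "Id")))
    # [{'Id': '001', 'Name': 'Acme', 'AnnualRevenue': 1000}] -- "002" is skipped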
@@ -406,10 +521,10 @@ def get_standard_instance(self) -> SalesforceStream:
             sobject_options=self.sobject_options,
             authenticator=self.authenticator,
         )
-        new_cls: Type[SalesforceStream] = SalesforceStream
+        new_cls: Type[SalesforceStream] = RestSalesforceStream
         if isinstance(self, BulkIncrementalSalesforceStream):
             stream_kwargs.update({"replication_key": self.replication_key, "start_date": self.start_date})
-            new_cls = IncrementalSalesforceStream
+            new_cls = IncrementalRestSalesforceStream

         return new_cls(**stream_kwargs)

@@ -426,7 +541,7 @@ def transform_empty_string_to_none(instance: Any, schema: Any):
     return instance


-class IncrementalSalesforceStream(SalesforceStream, ABC):
+class IncrementalRestSalesforceStream(RestSalesforceStream, ABC):
     state_checkpoint_interval = 500

     def __init__(self, replication_key: str, start_date: Optional[str], **kwargs):
@@ -442,20 +557,24 @@ def format_start_date(start_date: Optional[str]) -> Optional[str]:
         return None

     def request_params(
-        self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
+        self,
+        stream_state: Mapping[str, Any],
+        stream_slice: Mapping[str, Any] = None,
+        next_page_token: Mapping[str, Any] = None,
+        property_chunk: Mapping[str, Any] = None,
     ) -> MutableMapping[str, Any]:
         if next_page_token:
             """
             If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
             """
             return {}

-        selected_properties = self.get_json_schema().get("properties", {})
+        property_chunk = property_chunk or {}

         stream_date = stream_state.get(self.cursor_field)
         start_date = stream_date or self.start_date

-        query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
+        query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} "
         if start_date:
             query += f"WHERE {self.cursor_field} >= {start_date} "
         if self.name not in UNSUPPORTED_FILTERING_STREAMS:
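For a concrete sense of the incremental variant: given a property chunk and a saved cursor value, the method builds a SOQL query like the one below (object name, fields, and cursor value are illustrative; the ordering branch is cut off in this hunk, so it is only noted in a comment):

    # Illustrative inputs; in the stream these come from the schema chunk and stream_state.
    property_chunk = {"Id": {}, "Name": {}, "SystemModstamp": {}}
    name = "Account"
    cursor_field = "SystemModstamp"
    start_date = "2023-01-01T00:00:00Z"

    query = f"SELECT {','.join(property_chunk.keys())} FROM {name} "
    if start_date:
        query += f"WHERE {cursor_field} >= {start_date} "
    # Streams not in UNSUPPORTED_FILTERING_STREAMS also get an ordering clause (truncated in this hunk).
    print({"q": query})
    # {'q': 'SELECT Id,Name,SystemModstamp FROM Account WHERE SystemModstamp >= 2023-01-01T00:00:00Z '}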
@@ -477,7 +596,7 @@ def get_updated_state(self, current_stream_state: MutableMapping[str, Any], late
         return {self.cursor_field: latest_benchmark}


-class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalSalesforceStream):
+class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalRestSalesforceStream):
     def next_page_token(self, last_record: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
         if self.name not in UNSUPPORTED_FILTERING_STREAMS:
             page_token: str = last_record[self.cursor_field]