13
13
from airbyte_cdk .sources import AbstractSource
14
14
from airbyte_cdk .sources .streams import Stream
15
15
from airbyte_cdk .sources .streams .http import HttpStream
16
- from airbyte_cdk .sources .streams .http .auth import Oauth2Authenticator , TokenAuthenticator
16
+
17
+ from .utils import datetime_to_string , delete_milliseconds , get_api_endpoint , get_start_date , initialize_authenticator
17
18
18
19
19
20
class OktaStream (HttpStream , ABC ):
20
21
page_size = 200
21
22
22
def __init__(self, url_base: str, start_date: pendulum.datetime, *args, **kwargs):
    """
    Set up the stream with a custom API base URL and a sync start date.

    :param url_base: base URL of the API; trailing slashes are normalized
    :param start_date: stored on the instance as ``self.start_date``
    :param args: forwarded untouched to the parent constructor
    :param kwargs: forwarded untouched to the parent constructor
    """
    super().__init__(*args, **kwargs)
    # Strip every trailing slash, then append exactly one, so the stored
    # base URL always ends with a single "/".
    trimmed_base = url_base.rstrip("/")
    self._url_base = f"{trimmed_base}/"
    self.start_date = start_date
26
28
27
29
@property
28
30
def url_base (self ) -> str :
@@ -97,11 +99,10 @@ def request_params(
97
99
stream_slice : Mapping [str , any ] = None ,
98
100
next_page_token : Mapping [str , Any ] = None ,
99
101
) -> MutableMapping [str , Any ]:
100
- stream_state = stream_state or {}
101
102
params = super ().request_params (stream_state , stream_slice , next_page_token )
102
- latest_entry = stream_state .get (self .cursor_field )
103
- if latest_entry :
104
- params [ "filter" ] = f' { self . cursor_field } gt " { latest_entry } "'
103
+ latest_entry = stream_state .get (self .cursor_field ) if stream_state else datetime_to_string ( self . start_date )
104
+ filter_param = { "filter" : f' { self . cursor_field } gt " { latest_entry } "' }
105
+ params . update ( filter_param )
105
106
return params
106
107
107
108
@@ -120,7 +121,7 @@ class GroupMembers(IncrementalOktaStream):
120
121
use_cache = True
121
122
122
123
def stream_slices(self, **kwargs):
    """
    Yield one slice per group, in the form ``{"group_id": <group id>}``.

    The groups come from a ``Groups`` stream that shares this stream's
    authenticator, base URL and start date, read in full-refresh mode.
    """
    parent_stream = Groups(
        authenticator=self.authenticator,
        url_base=self.url_base,
        start_date=self.start_date,
    )
    parent_records = parent_stream.read_records(sync_mode=SyncMode.full_refresh)
    yield from ({"group_id": parent["id"]} for parent in parent_records)
126
127
@@ -134,10 +135,12 @@ def request_params(
134
135
stream_slice : Mapping [str , any ] = None ,
135
136
next_page_token : Mapping [str , Any ] = None ,
136
137
) -> MutableMapping [str , Any ]:
137
- params = OktaStream .request_params (self , stream_state , stream_slice , next_page_token )
138
- latest_entry = stream_state .get (self .cursor_field )
139
- if latest_entry :
140
- params ["after" ] = latest_entry
138
+ # The filter param should be ignored: SCIM filter expressions can't use the published
139
+ # attribute since it may conflict with the logic of the since, after, and until query params.
140
+ # Docs: https://developer.okta.com/docs/reference/api/system-log/#expression-filter
141
+ params = super (IncrementalOktaStream , self ).request_params (stream_state , stream_slice , next_page_token )
142
+ latest_entry = stream_state .get (self .cursor_field ) if stream_state else self .min_user_id
143
+ params ["after" ] = latest_entry
141
144
return params
142
145
143
146
def get_updated_state (self , current_stream_state : MutableMapping [str , Any ], latest_record : Mapping [str , Any ]) -> Mapping [str , Any ]:
@@ -154,7 +157,7 @@ class GroupRoleAssignments(OktaStream):
154
157
use_cache = True
155
158
156
159
def stream_slices(self, **kwargs):
    """
    Yield one slice per group, in the form ``{"group_id": <group id>}``.

    The groups come from a ``Groups`` stream that shares this stream's
    authenticator, base URL and start date, read in full-refresh mode.
    """
    parent_stream = Groups(
        authenticator=self.authenticator,
        url_base=self.url_base,
        start_date=self.start_date,
    )
    parent_records = parent_stream.read_records(sync_mode=SyncMode.full_refresh)
    yield from ({"group_id": parent["id"]} for parent in parent_records)
160
163
@@ -168,6 +171,28 @@ class Logs(IncrementalOktaStream):
168
171
cursor_field = "published"
169
172
primary_key = "uuid"
170
173
174
def __init__(self, url_base, **kwargs):
    """
    Set up the Logs stream.

    :param url_base: base URL of the API, passed to the parent constructor
    :param kwargs: remaining keyword arguments for the parent constructor
    """
    parent_kwargs = {"url_base": url_base, **kwargs}
    super().__init__(**parent_kwargs)
    # Backing field for the ``raise_on_http_errors`` property; starts enabled.
    self._raise_on_http_errors: bool = True
177
+
178
@property
def raise_on_http_errors(self) -> bool:
    """
    Return the current value of the instance flag ``_raise_on_http_errors``.
    """
    flag = self._raise_on_http_errors
    return flag
181
+
182
def should_retry(self, response: requests.Response) -> bool:
    """
    Decide whether the request that produced ``response`` should be retried.

    When the connector sends an abnormal state, the API returns an error with
    a 400 status code and the internal error code E0000001. The connector
    ignores that error so the sync can finish successfully: the error summary
    is written to the log, raising on HTTP errors is switched off, and no
    retry is requested.

    Every other response, including a 400 whose body is not a JSON object,
    is delegated to ``HttpStream.should_retry``.

    :param response: the HTTP response to inspect
    :return: ``False`` for the ignored 400/E0000001 error, otherwise the
        result of ``HttpStream.should_retry``
    """
    if response.status_code == 400:
        # Decode the body once, and tolerate bodies that are not valid JSON
        # so an unexpected 400 reaches the default handling instead of
        # crashing here.
        try:
            payload = response.json()
        except ValueError:
            payload = None

        if isinstance(payload, dict) and payload.get("errorCode") == "E0000001":
            self.logger.info(f"{payload.get('errorSummary')}")
            # NOTE(review): nothing visible here resets this flag to True, so
            # later HTTP errors on this instance would not be raised either.
            # Confirm that is intended.
            self._raise_on_http_errors = False
            return False
    return HttpStream.should_retry(self, response)
195
+
171
196
def path(self, **kwargs) -> str:
    """
    Return the path of the logs endpoint; keyword arguments are ignored.
    """
    endpoint = "logs"
    return endpoint
173
198
@@ -177,24 +202,27 @@ def request_params(
177
202
stream_slice : Mapping [str , any ] = None ,
178
203
next_page_token : Mapping [str , Any ] = None ,
179
204
) -> MutableMapping [str , Any ]:
180
- # The log stream use a different params to get data
181
- # https://developer.okta.com/docs/reference/api/system-log/#datetime-filter
182
- stream_state = stream_state or {}
183
- params = OktaStream .request_params (self , stream_state , stream_slice , next_page_token )
184
- latest_entry = stream_state .get (self .cursor_field )
185
- if latest_entry :
186
- params ["since" ] = latest_entry
187
- # [Test-driven Development] Set until When the cursor value from the stream state
188
- # is abnormally large, otherwise the server side that sets now to until
189
- # will throw an error: The "until" date must be later than the "since" date
190
- # https://developer.okta.com/docs/reference/api/system-log/#request-parameters
191
- parsed = pendulum .parse (latest_entry )
192
- utc_now = pendulum .utcnow ()
193
- if parsed > utc_now :
194
- params ["until" ] = latest_entry
195
-
205
+ # The log stream uses different params to get data.
206
+ # Docs: https://developer.okta.com/docs/reference/api/system-log/#datetime-filter
207
+ # The filter param should be ignored: SCIM filter expressions can't use the published
208
+ # attribute since it may conflict with the logic of the since, after, and until query params.
209
+ # Docs: https://developer.okta.com/docs/reference/api/system-log/#expression-filter
210
+ params = super (IncrementalOktaStream , self ).request_params (stream_state , stream_slice , next_page_token )
211
+ latest_entry = stream_state .get (self .cursor_field ) if stream_state else self .start_date
212
+ params ["since" ] = latest_entry
196
213
return params
197
214
215
def parse_response(
    self,
    response: requests.Response,
    **kwargs,
) -> Iterable[Mapping]:
    """
    Yield the records contained in a response.

    The body is decoded once. If it is not a JSON list, nothing is yielded.
    Otherwise each record's cursor field is replaced by the result of
    ``delete_milliseconds`` before the record is yielded.

    :param response: the HTTP response to parse
    :param kwargs: accepted for interface compatibility; not used
    :return: an iterable of records, mutated in place as described above
    """
    payload = response.json()
    if not isinstance(payload, list):
        # Any non-list body yields no records.
        return

    cursor = self.cursor_field
    for record in payload:
        # NOTE(review): delete_milliseconds comes from .utils; presumably it
        # drops the sub-second part of the timestamp. Confirm.
        record[cursor] = delete_milliseconds(record[cursor])
        yield record
225
+
198
226
199
227
class Users (IncrementalOktaStream ):
200
228
cursor_field = "lastUpdated"
@@ -242,7 +270,7 @@ class UserRoleAssignments(OktaStream):
242
270
use_cache = True
243
271
244
272
def stream_slices(self, **kwargs):
    """
    Yield one slice per user, in the form ``{"user_id": <user id>}``.

    The users come from a ``Users`` stream that shares this stream's
    authenticator, base URL and start date, read in full-refresh mode.
    """
    parent_stream = Users(
        authenticator=self.authenticator,
        url_base=self.url_base,
        start_date=self.start_date,
    )
    parent_records = parent_stream.read_records(sync_mode=SyncMode.full_refresh)
    yield from ({"user_id": parent["id"]} for parent in parent_records)
248
276
@@ -264,7 +292,7 @@ def parse_response(
264
292
yield from response .json ()["permissions" ]
265
293
266
294
def stream_slices(self, **kwargs):
    """
    Yield one slice per custom role, in the form ``{"role_id": <role id>}``.

    The roles come from a ``CustomRoles`` stream that shares this stream's
    authenticator, base URL and start date, read in full-refresh mode.
    """
    parent_stream = CustomRoles(
        authenticator=self.authenticator,
        url_base=self.url_base,
        start_date=self.start_date,
    )
    parent_records = parent_stream.read_records(sync_mode=SyncMode.full_refresh)
    yield from ({"role_id": parent["id"]} for parent in parent_records)
270
298
@@ -273,66 +301,11 @@ def path(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> str:
273
301
return f"iam/roles/{ role_id } /permissions"
274
302
275
303
276
- class OktaOauth2Authenticator (Oauth2Authenticator ):
277
- def get_refresh_request_body (self ) -> Mapping [str , Any ]:
278
- return {
279
- "grant_type" : "refresh_token" ,
280
- "refresh_token" : self .refresh_token ,
281
- }
282
-
283
- def refresh_access_token (self ) -> Tuple [str , int ]:
284
- try :
285
- response = requests .request (
286
- method = "POST" ,
287
- url = self .token_refresh_endpoint ,
288
- data = self .get_refresh_request_body (),
289
- auth = (self .client_id , self .client_secret ),
290
- )
291
- response .raise_for_status ()
292
- response_json = response .json ()
293
- return response_json ["access_token" ], response_json ["expires_in" ]
294
- except Exception as e :
295
- raise Exception (f"Error while refreshing access token: { e } " ) from e
296
-
297
-
298
304
class SourceOkta (AbstractSource ):
299
- def initialize_authenticator (self , config : Mapping [str , Any ]):
300
- if "token" in config :
301
- return TokenAuthenticator (config ["token" ], auth_method = "SSWS" )
302
-
303
- creds = config .get ("credentials" )
304
- if not creds :
305
- raise Exception ("Config validation error. `credentials` not specified." )
306
-
307
- auth_type = creds .get ("auth_type" )
308
- if not auth_type :
309
- raise Exception ("Config validation error. `auth_type` not specified." )
310
-
311
- if auth_type == "api_token" :
312
- return TokenAuthenticator (creds ["api_token" ], auth_method = "SSWS" )
313
-
314
- if auth_type == "oauth2.0" :
315
- return OktaOauth2Authenticator (
316
- token_refresh_endpoint = self .get_token_refresh_endpoint (config ),
317
- client_secret = creds ["client_secret" ],
318
- client_id = creds ["client_id" ],
319
- refresh_token = creds ["refresh_token" ],
320
- )
321
-
322
- @staticmethod
323
- def get_url_base (config : Mapping [str , Any ]) -> str :
324
- return config .get ("base_url" ) or f"https://{ config ['domain' ]} .okta.com"
325
-
326
- def get_api_endpoint (self , config : Mapping [str , Any ]) -> str :
327
- return parse .urljoin (self .get_url_base (config ), "/api/v1/" )
328
-
329
- def get_token_refresh_endpoint (self , config : Mapping [str , Any ]) -> str :
330
- return parse .urljoin (self .get_url_base (config ), "/oauth2/v1/token" )
331
-
332
305
def check_connection (self , logger , config ) -> Tuple [bool , any ]:
333
306
try :
334
- auth = self . initialize_authenticator (config )
335
- api_endpoint = self . get_api_endpoint (config )
307
+ auth = initialize_authenticator (config )
308
+ api_endpoint = get_api_endpoint (config )
336
309
url = parse .urljoin (api_endpoint , "users" )
337
310
338
311
response = requests .get (
@@ -349,13 +322,11 @@ def check_connection(self, logger, config) -> Tuple[bool, any]:
349
322
return False , "Failed to authenticate with the provided credentials"
350
323
351
324
def streams (self , config : Mapping [str , Any ]) -> List [Stream ]:
352
- auth = self .initialize_authenticator (config )
353
- api_endpoint = self .get_api_endpoint (config )
325
+ auth = initialize_authenticator (config )
326
+ api_endpoint = get_api_endpoint (config )
327
+ start_date = get_start_date (config )
354
328
355
- initialization_params = {
356
- "authenticator" : auth ,
357
- "url_base" : api_endpoint ,
358
- }
329
+ initialization_params = {"authenticator" : auth , "url_base" : api_endpoint , "start_date" : start_date }
359
330
360
331
return [
361
332
Groups (** initialization_params ),
0 commit comments