Skip to content

feat: add support for updating records for stripe #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 136 additions & 17 deletions airbyte-integrations/connectors/source-stripe/source_stripe/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import requests
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.streams import IncrementalMixin
from datetime import datetime

STRIPE_ERROR_CODES: List = [
# stream requires additional permissions
Expand Down Expand Up @@ -135,6 +137,7 @@ def get_updated_state(self, current_stream_state: MutableMapping[str, Any], late
Return the latest state by comparing the cursor value in the latest record with the stream's most recent state object
and returning an updated state object.
"""

return {self.cursor_field: max(latest_record.get(self.cursor_field), current_stream_state.get(self.cursor_field, 0))}

def stream_slices(
Expand All @@ -159,12 +162,100 @@ def get_start_timestamp(self, stream_state) -> int:

return start_point

class IncrementalStripeStreamWithUpdates(IncrementalStripeStream):
    """
    Base class for incremental streams that also pick up record updates through the events API.

    The first sync reads the original resource, whose records are already up to date.
    Later syncs read the events emitted since the previous sync instead.
    """

    # Event type names handed to the Updates stream; concrete streams override this.
    event_types = None
    # Record/state key that carries the creation time of the originating event.
    update_field = "event_created"
    # State key under which the flag "the read of the original resource finished" is stored.
    state_completed_key = "completed"
    # State key under which the timestamp of the latest state update is stored.
    state_lastSync_key = "lastSync"
    # Becomes True once the last record of the original resource has been reached.
    completed = False
def lookahead(self, iterable):
    """Pass every value of ``iterable`` through, flagging when the last one is reached.

    ``self.completed`` is set to ``True`` right before the final value is
    yielded, so whatever consumes that final value already sees the flag.

    An empty iterable yields nothing and is treated as completed as well.
    Calling ``next()`` unguarded here would let ``StopIteration`` escape
    inside a generator, which Python converts into ``RuntimeError`` (PEP 479).
    """
    it = iter(iterable)
    try:
        pending = next(it)
    except StopIteration:
        # Nothing to read: the source is exhausted, so the read is complete.
        self.completed = True
        return
    # Always stay one value behind so we know when ``pending`` is the last one.
    for upcoming in it:
        yield pending
        pending = upcoming
    self.completed = True
    yield pending
def shouldFetchFromOriginalResource(self, stream_state):
    """Decide whether to read the original resource instead of the events API.

    Side effect: ``self.completed`` is refreshed from the stored state.

    Returns ``True`` (read the original resource) when any of these holds:
      * there is no stored state (``None`` or empty),
      * the stored state says the previous full read did not complete,
      * the stored state has no last-sync timestamp,
      * the last sync happened 30 or more days ago.
    Otherwise returns ``False`` (read updates from the events API).
    """
    # Local import keeps this block self-contained; only ``datetime`` is imported at file level.
    from datetime import datetime, timezone

    # NOTE(review): 30 days presumably mirrors how long Stripe keeps events available - confirm.
    max_days_since_last_sync = 30

    state = stream_state or {}
    self.completed = bool(state.get(self.state_completed_key))
    if not state or not self.completed:
        return True

    last_sync = state.get(self.state_lastSync_key)
    if last_sync is None:
        # Without a timestamp the age of the state is unknown; fall back to a full read.
        return True

    # Both sides are timezone-aware UTC, so the host's local UTC offset cannot skew the difference.
    elapsed = datetime.now(tz=timezone.utc) - datetime.fromtimestamp(last_sync, tz=timezone.utc)
    return elapsed.days >= max_days_since_last_sync
def read_records(self, stream_slice, stream_state, **kwargs) -> Iterable[Mapping[str, Any]]:
    """Yield records either from the original resource or from the events API.

    ``shouldFetchFromOriginalResource`` picks the source. A read of the
    original resource ignores the stored state (an empty state is passed to
    the parent) and is wrapped in ``lookahead`` so that ``self.completed``
    flips to ``True`` when its last record is reached.
    """
    if not self.shouldFetchFromOriginalResource(stream_state):
        yield from self.get_updates(stream_state, **kwargs)
        return

    # A fresh read of the original resource starts out as "not completed".
    self.completed = False
    original_records = super().read_records(stream_slice=stream_slice, stream_state={}, **kwargs)
    yield from self.lookahead(original_records)

def get_updates(self, stream_state, **kwargs) -> Iterable[Mapping[str, Any]]:
    """Yield records extracted from events of this stream's ``event_types``.

    An ``Updates`` stream is built with this stream's authenticator, account
    id and start date; every slice it produces for ``stream_state`` is read,
    and each record has its primary key moved to ``record_id`` before it is
    yielded.
    """
    events_stream = Updates(
        event_types=self.event_types,
        authenticator=self.authenticator,
        account_id=self.account_id,
        start_date=self.start_date,
    )
    for window in events_stream.stream_slices(sync_mode="incremental", stream_state=stream_state):
        changed_records = events_stream.read_records(stream_slice=window, stream_state=stream_state, **kwargs)
        for changed in changed_records:
            self.set_record_id(changed)
            yield changed

def set_record_id(self, record):
    """
    Move the record's primary key value to the ``record_id`` field, in place.

    The original primary key entry is removed afterwards: per the original
    author's note the id is reserved by the warehouse and, when present, is
    used for deduplication. Raises ``KeyError`` if the record has no primary key.
    """
    key = self.primary_key
    original_id = record[key]
    record["record_id"] = original_id
    del record[key]

def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
    """Yield the parent's parsed records, prepared for the update-aware flow.

    Each record gets ``update_field`` filled from its ``cursor_field`` value
    when it is missing or ``None``, and has its primary key moved to
    ``record_id``.
    """
    for record in super().parse_response(response, **kwargs):
        lacks_event_time = record.get(self.update_field) is None
        if lacks_event_time:
            # No originating event here, so fall back to the record's own cursor value.
            record[self.update_field] = record[self.cursor_field]
        self.set_record_id(record)
        yield record

class Customers(IncrementalStripeStream):
def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Build the new state object from the latest record and the current state.

    The returned mapping holds, in this order: the events cursor (``update_field``),
    the resource cursor (``cursor_field``), the completed flag and the time of this
    state update.
    """
    resource_cursor = max(latest_record.get(self.cursor_field), current_stream_state.get(self.cursor_field, 0))
    # The last-sync marker is always the current time.
    synced_at = pendulum.now().int_timestamp
    # Once the main sync has completed, the events cursor advances with the event creation time.
    # Until then it stays at 0 so that events are read from the beginning and no data is lost.
    events_cursor = (
        max(latest_record.get(self.update_field), current_stream_state.get(self.update_field, 0))
        if self.completed
        else 0
    )
    return {
        self.update_field: events_cursor,
        self.cursor_field: resource_cursor,
        self.state_completed_key: self.completed,
        self.state_lastSync_key: synced_at,
    }

class Customers(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/customers/list
"""

event_types = ["customer.created", "customer.updated"]
cursor_field = "created"

def path(self, **kwargs) -> str:
Expand All @@ -183,11 +274,11 @@ def path(self, **kwargs) -> str:
return "balance_transactions"


class Charges(IncrementalStripeStream):
class Charges(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/charges/list
"""

event_types = ["charge.created","charge.updated"]
cursor_field = "created"

def path(self, **kwargs) -> str:
Expand Down Expand Up @@ -224,11 +315,11 @@ def path(self, **kwargs):
return "coupons"


class Disputes(IncrementalStripeStream):
class Disputes(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/disputes/list
"""

event_types = ["charge.dispute.created", "charge.dispute.updated"]
cursor_field = "created"

def path(self, **kwargs):
Expand All @@ -245,6 +336,33 @@ class Events(IncrementalStripeStream):
def path(self, **kwargs):
return "events"

class Updates(Events):
    """
    Stream that returns updated records by reading them out of the events stream.

    Each yielded record is the object embedded in an event, extended with an
    ``event_created`` field holding the creation time of that event.
    """

    cursor_field = "event_created"

    def __init__(self, event_types=None, **kwargs):
        """
        :param event_types: event type names to request,
            e.g. ["customer.created", "customer.updated"] for the Customers stream.
        :param kwargs: forwarded unchanged to the parent stream constructor.
        :raises ValueError: when ``event_types`` is missing or empty.
        """
        # Validate first so a misconfigured stream fails before the parent is initialised.
        # ValueError is still an Exception, so existing ``except Exception`` callers keep working.
        if not event_types:
            raise ValueError("event_types is required for the Updates stream")
        super().__init__(**kwargs)
        self.event_types = event_types

    def request_params(self, stream_slice: Mapping[str, Any] = None, **kwargs):
        """Return the parent's request params plus the event type filter."""
        params = super().request_params(stream_slice=stream_slice, **kwargs)
        if self.event_types:
            params["types[]"] = self.event_types
        return params

    def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
        """Yield the object embedded in each event, stamped with the event's creation time."""
        for event in super().parse_response(response, **kwargs):
            # The actual record lives in the data object,
            # example {"object": "event", "data": {"object": {...}}}.
            # ``or {}`` also covers keys that are present but null.
            record = (event.get("data") or {}).get("object") or {}
            # Add event_created to the record using the created date of the event.
            record[self.cursor_field] = event.get("created")
            yield record

class StripeSubStream(SingleEmptySliceMixin, StripeStream, ABC):
"""
Expand Down Expand Up @@ -351,11 +469,11 @@ def read_records(self, sync_mode: SyncMode, stream_slice: Optional[Mapping[str,
yield item


class Invoices(IncrementalStripeStream):
class Invoices(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/invoices/list
"""

event_types = ["invoice.created", "invoice.updated"]
cursor_field = "created"

def path(self, **kwargs):
Expand All @@ -378,11 +496,11 @@ def path(self, stream_slice: Mapping[str, Any] = None, **kwargs):
return f"invoices/{stream_slice[self.parent_id]}/lines"


class InvoiceItems(IncrementalStripeStream):
class InvoiceItems(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/invoiceitems/list
"""

event_types = ["invoiceitem.created", "invoiceitem.updated"]
cursor_field = "date"
name = "invoice_items"

Expand All @@ -401,33 +519,34 @@ def path(self, **kwargs):
return "payouts"


class Plans(IncrementalStripeStream):
class Plans(IncrementalStripeStreamWithUpdates):
    """
    Plans stream; updates are read from ``plan.created`` and ``plan.updated`` events.

    API docs: https://stripe.com/docs/api/plans/list
    """

    cursor_field = "created"
    event_types = ["plan.created", "plan.updated"]

    def path(self, **kwargs):
        """Return the API path of the plans list endpoint."""
        return "plans"


class Products(IncrementalStripeStream):
class Products(IncrementalStripeStreamWithUpdates):
    """
    Products stream; updates are read from ``product.created`` and ``product.updated`` events.

    API docs: https://stripe.com/docs/api/products/list
    """

    cursor_field = "created"
    event_types = ["product.created", "product.updated"]

    def path(self, **kwargs):
        """Return the API path of the products list endpoint."""
        return "products"


class Subscriptions(IncrementalStripeStream):
class Subscriptions(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/subscriptions/list
"""

event_types = ["customer.subscription.created", "customer.subscription.updated"]
cursor_field = "created"
status = "all"

Expand Down Expand Up @@ -461,11 +580,11 @@ def request_params(self, stream_slice: Mapping[str, Any] = None, **kwargs):
return params


class Transfers(IncrementalStripeStream):
class Transfers(IncrementalStripeStreamWithUpdates):
"""
API docs: https://stripe.com/docs/api/transfers/list
"""

event_types = ["transfer.created", "transfer.updated"]
cursor_field = "created"

def path(self, **kwargs):
Expand Down