
Salesforce: retry on download_data and create_stream_job #36385


Merged: 12 commits, Apr 3, 2024
@@ -10,7 +10,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: b117307c-14b6-41aa-9422-947e34922962
dockerImageTag: 2.4.0
dockerImageTag: 2.4.1
dockerRepository: airbyte/source-salesforce
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
githubIssueLabel: source-salesforce
@@ -3,7 +3,7 @@ requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
version = "2.4.0"
version = "2.4.1"
name = "source-salesforce"
description = "Source implementation for Salesforce."
authors = [ "Airbyte <[email protected]>",]
@@ -6,6 +6,7 @@
import logging
from typing import Any, List, Mapping, Optional, Tuple

import backoff
import requests # type: ignore[import]
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from airbyte_cdk.utils import AirbyteTracedException
@@ -300,7 +301,7 @@ def get_validated_streams(self, config: Mapping[str, Any], catalog: ConfiguredAi
validated_streams = [stream_name for stream_name in stream_names if self.filter_streams(stream_name)]
return {stream_name: sobject_options for stream_name, sobject_options in stream_objects.items() if stream_name in validated_streams}

@default_backoff_handler(max_tries=5, factor=5)
@default_backoff_handler(max_tries=5, backoff_method=backoff.expo, backoff_params={"factor": 5})
def _make_request(
self, http_method: str, url: str, headers: dict = None, body: dict = None, stream: bool = False, params: dict = None
) -> requests.models.Response:
@@ -15,19 +15,50 @@
exceptions.ReadTimeout,
exceptions.ConnectionError,
exceptions.HTTPError,
# We've had a couple of customers hit ProtocolErrors, namely:
# * A self-managed instance during `BulkSalesforceStream.download_data`. This customer had an abnormally high number of ConnectionErrors,
# which seems to indicate problems with their network infrastructure in general. The exact error was: `urllib3.exceptions.ProtocolError: ('Connection broken: IncompleteRead(905 bytes read, 119 more expected)', IncompleteRead(905 bytes read, 119 more expected))`
# * A cloud customer with very long syncs. All of those syncs ended with the following error: `urllib3.exceptions.ProtocolError: ("Connection broken: InvalidChunkLength(got length b'', 0 bytes read)", InvalidChunkLength(got length b'', 0 bytes read))`
# Without more information, we make these retryable in the hope that re-issuing the same request will work.
exceptions.ChunkedEncodingError,
# We've had examples where the response from Salesforce was not a JSON response. Those were error cases, though. For example:
# https://github.com/airbytehq/airbyte-internal-issues/issues/6855. We assume this is an edge case and that a retry should help.
exceptions.JSONDecodeError,
)

_RETRYABLE_400_STATUS_CODES = {
# Using debug mode and breakpointing on the issue, we were able to validate that these issues are retryable. We've also opened a case
# with Salesforce to try to understand what is causing this, as the response does not have a body.
406,
# Most of the time, these responses don't have a body, but there was one from Salesforce Edge mentioning "We are setting things up. This process
# can take a few minutes. This page will auto-refresh when ready. If it takes too long, please contact support or visit our <a>status
# page</a> for more information." We therefore assume this is a transient error and will retry on it.
420,
codes.too_many_requests,
}
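
For a quick sanity check of how this set feeds the give-up rule below, here is a minimal sketch (not part of the diff; `is_fatal_client_error` is a hypothetical name): 4xx responses abort the retry loop unless their status code is in the retryable set.

from requests import codes

_RETRYABLE_400_STATUS_CODES = {406, 420, codes.too_many_requests}

def is_fatal_client_error(status_code: int) -> bool:
    # Mirrors the predicate in `should_give_up`: give up on 4xx codes
    # that are not known to be transient.
    return 400 <= status_code < 500 and status_code not in _RETRYABLE_400_STATUS_CODES

assert is_fatal_client_error(404)                           # plain client error: fatal
assert not is_fatal_client_error(420)                       # Salesforce Edge transient: retried
assert not is_fatal_client_error(codes.too_many_requests)   # 429: retried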


logger = logging.getLogger("airbyte")


def default_backoff_handler(max_tries: int, factor: int, **kwargs):
def default_backoff_handler(max_tries: int, backoff_method=None, backoff_params=None):
if backoff_method is None or backoff_params is None:
if not (backoff_method is None and backoff_params is None):
raise ValueError("Both `backoff_method` and `backoff_params` need to be provided if one is provided")
backoff_method = backoff.expo
backoff_params = {"factor": 15}

def log_retry_attempt(details):
_, exc, _ = sys.exc_info()
logger.info(str(exc))
logger.info(f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying...")

def should_give_up(exc):
give_up = exc.response is not None and exc.response.status_code != codes.too_many_requests and 400 <= exc.response.status_code < 500
give_up = (
exc.response is not None
and exc.response.status_code not in _RETRYABLE_400_STATUS_CODES
and 400 <= exc.response.status_code < 500
)

# Salesforce can return a request-limit error with a 403 status code.
if exc.response is not None and exc.response.status_code == codes.forbidden:
@@ -40,12 +71,11 @@ def should_give_up(exc):
return give_up

return backoff.on_exception(
backoff.expo,
backoff_method,
TRANSIENT_EXCEPTIONS,
jitter=None,
on_backoff=log_retry_attempt,
giveup=should_give_up,
max_tries=max_tries,
factor=factor,
**kwargs,
**backoff_params,
)
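
As a usage note, here is a minimal sketch (not part of the diff; the decorated functions are hypothetical) of how the refactored handler is parameterized: `backoff.expo` with a `factor` preserves the previous exponential behavior, while `backoff.constant` with an `interval` retries at a fixed pace, which `download_data` relies on further down.

import backoff
from source_salesforce.rate_limiting import default_backoff_handler

# Exponential waits (roughly factor * 2**attempt), as before the refactor.
@default_backoff_handler(max_tries=5, backoff_method=backoff.expo, backoff_params={"factor": 15})
def fetch_job_status():  # hypothetical function
    ...

# Constant waits: 5 seconds between each try.
@default_backoff_handler(max_tries=5, backoff_method=backoff.constant, backoff_params={"interval": 5})
def fetch_result_chunk():  # hypothetical function
    ...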
@@ -11,8 +11,9 @@
import uuid
from abc import ABC
from contextlib import closing
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

import backoff
import pandas as pd
import pendulum
import requests # type: ignore[import]
@@ -31,14 +32,15 @@
from .api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .availability_strategy import SalesforceAvailabilityStrategy
from .exceptions import SalesforceException, TmpFileIOError
from .rate_limiting import default_backoff_handler
from .rate_limiting import TRANSIENT_EXCEPTIONS, default_backoff_handler

# https://stackoverflow.com/a/54517228
CSV_FIELD_SIZE_LIMIT = int(ctypes.c_ulong(-1).value // 2)
csv.field_size_limit(CSV_FIELD_SIZE_LIMIT)

DEFAULT_ENCODING = "utf-8"
LOOKBACK_SECONDS = 600 # based on https://trailhead.salesforce.com/trailblazer-community/feed/0D54V00007T48TASAZ
_JOB_TRANSIENT_ERRORS_MAX_RETRY = 1


class SalesforceStream(HttpStream, ABC):
@@ -351,24 +353,38 @@ def path(self, next_page_token: Mapping[str, Any] = None, **kwargs: Any) -> str:

transformer = TypeTransformer(TransformConfig.CustomSchemaNormalization | TransformConfig.DefaultSchemaNormalization)

@default_backoff_handler(max_tries=5, factor=15)
@default_backoff_handler(max_tries=5, backoff_method=backoff.expo, backoff_params={"factor": 15})
def _send_http_request(self, method: str, url: str, json: dict = None, headers: dict = None, stream: bool = False):
"""
This method should be used when you don't have to read data from the HTTP body. Otherwise, you will have to retry when you actually
read the response buffer (which happens when calling `json` or `iter_content`).
"""
return self._non_retryable_send_http_request(method, url, json, headers, stream)

def _non_retryable_send_http_request(self, method: str, url: str, json: dict = None, headers: dict = None, stream: bool = False):
headers = self.authenticator.get_auth_header() if not headers else headers | self.authenticator.get_auth_header()
response = self._session.request(method, url=url, headers=headers, json=json, stream=stream)
if response.status_code not in [200, 204]:
self.logger.error(f"error body: {response.text}, sobject options: {self.sobject_options}")
response.raise_for_status()
return response

@default_backoff_handler(max_tries=5, backoff_method=backoff.expo, backoff_params={"factor": 15})
def _create_stream_job(self, query: str, url: str) -> Optional[str]:
json = {"operation": "queryAll", "query": query, "contentType": "CSV", "columnDelimiter": "COMMA", "lineEnding": "LF"}
response = self._non_retryable_send_http_request("POST", url, json=json)
job_id: str = response.json()["id"]
return job_id
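
The decorated `_create_stream_job` above, paired with the un-decorated wrapper that follows, is the heart of this change: the retryable unit is the request plus the body read, because `JSONDecodeError` and `ChunkedEncodingError` only surface when `.json()` consumes the response. A condensed sketch of the pattern, assuming the inner request performs no retries of its own (names are hypothetical):

import backoff
import requests
from source_salesforce.rate_limiting import default_backoff_handler

@default_backoff_handler(max_tries=5, backoff_method=backoff.expo, backoff_params={"factor": 15})
def request_and_parse(session: requests.Session, url: str, payload: dict) -> dict:
    # If reading the body fails, the whole function (request included) is retried.
    response = session.post(url, json=payload)
    response.raise_for_status()
    return response.json()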

def create_stream_job(self, query: str, url: str) -> Optional[str]:
"""
docs: https://developer.salesforce.com/docs/atlas.en-us.api_asynch.meta/api_asynch/create_job.html

Note that we want to retry on connection issues as well; those can occur when calling `.json()`. Even when a connection error
occurs while handling an HTTPError, we retry, because otherwise we can't take the right action.
"""
json = {"operation": "queryAll", "query": query, "contentType": "CSV", "columnDelimiter": "COMMA", "lineEnding": "LF"}
try:
response = self._send_http_request("POST", url, json=json)
job_id: str = response.json()["id"]
return job_id
return self._create_stream_job(query, url)
except exceptions.HTTPError as error:
if error.response.status_code in [codes.FORBIDDEN, codes.BAD_REQUEST]:
# A part of streams can't be used by BULK API. Every API version can have a custom list of
@@ -383,9 +399,7 @@ def create_stream_job(self, query: str, url: str) -> Optional[str]:
# updated query: "Select Name, (Select Subject,ActivityType from ActivityHistories) from Contact"
# The second variant forces customisation for every case (ActivityHistory, ActivityHistories etc).
# And the main problem is that these subqueries don't support the CSV response format.
error_data = error.response.json()[0]
error_code = error_data.get("errorCode")
error_message = error_data.get("message", "")
error_code, error_message = self._extract_error_code_and_message(error.response)
if error_message == "Selecting compound data not supported in Bulk Query" or (
error_code == "INVALIDENTITY" and "is not supported by the Bulk API" in error_message
):
Expand All @@ -401,7 +415,7 @@ def create_stream_job(self, query: str, url: str) -> Optional[str]:
elif error.response.status_code == codes.FORBIDDEN and error_code == "REQUEST_LIMIT_EXCEEDED":
self.logger.error(
f"Cannot receive data for stream '{self.name}' ,"
f"sobject options: {self.sobject_options}, Error message: '{error_data.get('message')}'"
f"sobject options: {self.sobject_options}, Error message: '{error_message}'"
)
elif error.response.status_code == codes.BAD_REQUEST and error_message.endswith("does not support query"):
self.logger.error(
@@ -437,9 +451,7 @@ def wait_for_job(self, url: str) -> str:
try:
job_info = self._send_http_request("GET", url=url).json()
except exceptions.HTTPError as error:
error_data = error.response.json()[0]
error_code = error_data.get("errorCode")
error_message = error_data.get("message", "")
error_code, error_message = self._extract_error_code_and_message(error.response)
if (
"We can't complete the action because enabled transaction security policies took too long to complete." in error_message
and error_code == "TXN_SECURITY_METERING_ERROR"
@@ -473,6 +485,19 @@ def wait_for_job(self, url: str) -> str:
self.logger.warning(f"Did not receive {self.name} data after waiting {self.DEFAULT_WAIT_TIMEOUT_SECONDS} seconds, data: {job_info}!!")
return job_status

def _extract_error_code_and_message(self, response: requests.Response) -> tuple[Optional[str], str]:
try:
error_data = response.json()[0]
return error_data.get("errorCode"), error_data.get("message", "")
except exceptions.JSONDecodeError:
self.logger.warning(f"The response for `{response.request.url}` is not a JSON but was `{response.content}`")
except IndexError:
self.logger.warning(
f"The response for `{response.request.url}` was expected to be a list with at least one element but was `{response.content}`"
)

return None, ""
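
For reference, Salesforce error responses typically arrive as a JSON array of objects carrying `errorCode` and `message`, which is what this helper unpacks; anything else falls through to the warning branches. A small illustration (the payload values are representative examples, not taken from this PR):

# Typical Salesforce error body shape:
error_body = [{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}]
error_code, error_message = error_body[0].get("errorCode"), error_body[0].get("message", "")
assert (error_code, error_message) == ("REQUEST_LIMIT_EXCEEDED", "TotalRequests Limit exceeded.")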

def execute_job(self, query: str, url: str) -> Tuple[Optional[str], Optional[str]]:
job_status = "Failed"
for i in range(0, self.MAX_RETRY_NUMBER):
@@ -520,6 +545,7 @@ def get_response_encoding(self, headers) -> str:

return self.encoding

@default_backoff_handler(max_tries=5, backoff_method=backoff.constant, backoff_params={"interval": 5})
def download_data(self, url: str, chunk_size: int = 1024) -> tuple[str, str, dict]:
"""
Retrieves the binary data result from a successfully executed job, using chunks, to avoid local memory limitations.
@@ -529,7 +555,7 @@ def download_data(self, url: str, chunk_size: int = 1024) -> tuple[str, str, dic
"""
# set filepath for binary data from response
tmp_file = str(uuid.uuid4())
with closing(self._send_http_request("GET", url, headers={"Accept-Encoding": "gzip"}, stream=True)) as response, open(
with closing(self._non_retryable_send_http_request("GET", url, headers={"Accept-Encoding": "gzip"}, stream=True)) as response, open(
tmp_file, "wb"
) as data_file:
response_headers = response.headers
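
This is why `download_data` now calls `_non_retryable_send_http_request` while carrying its own `@default_backoff_handler`: with `stream=True`, the GET can succeed and the failure only appears while iterating the body, so retrying just the initial request would not help, and keeping retries on both layers would multiply attempts (up to 5 × 5). A rough sketch of where the failure actually occurs (hypothetical names):

import requests

def stream_to_file(session: requests.Session, url: str, path: str) -> None:
    # The request below may succeed while iter_content still raises a
    # ChunkedEncodingError/ProtocolError mid-stream; only a retry wrapping
    # this whole function re-downloads the data.
    with session.get(url, stream=True) as response, open(path, "wb") as data_file:
        for chunk in response.iter_content(chunk_size=1024):
            data_file.write(chunk)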
@@ -615,6 +641,7 @@ def read_records(
cursor_field: List[str] = None,
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
call_count: int = 0,
) -> Iterable[Mapping[str, Any]]:
stream_state = stream_state or {}
next_page_token = None
@@ -643,7 +670,17 @@
while True:
req = PreparedRequest()
req.prepare_url(f"{job_full_url}/results", {"locator": salesforce_bulk_api_locator})
tmp_file, response_encoding, response_headers = self.download_data(url=req.url)
try:
tmp_file, response_encoding, response_headers = self.download_data(url=req.url)
except TRANSIENT_EXCEPTIONS as exception:
if call_count >= _JOB_TRANSIENT_ERRORS_MAX_RETRY:
self.logger.error(f"Downloading data failed even after {call_count} retries. Stopping retry and raising exception")
raise exception
self.logger.warning(f"Downloading data failed after {call_count} retries. Retrying the whole job...")
call_count += 1
yield from self.read_records(sync_mode, cursor_field, stream_slice, stream_state, call_count=call_count)
return

for record in self.read_with_chunks(tmp_file, response_encoding):
yield record
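
The `call_count` parameter adds a second retry layer on top of `download_data`'s per-request retries: if a transient error survives all five download attempts, the entire bulk job is recreated and re-read at most `_JOB_TRANSIENT_ERRORS_MAX_RETRY` (one) extra time, via the recursive `read_records` call. A condensed sketch of that control flow (hypothetical names, simplified to a return value instead of a generator):

from source_salesforce.rate_limiting import TRANSIENT_EXCEPTIONS

def read_with_job_retry(run_job, download, max_job_retries: int = 1, call_count: int = 0):
    # `download` already retries individual requests; this wrapper re-runs
    # the whole job when those retries are exhausted.
    try:
        return download(run_job())
    except TRANSIENT_EXCEPTIONS:
        if call_count >= max_job_retries:
            raise
        return read_with_job_retry(run_job, download, max_job_retries, call_count + 1)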

@@ -29,7 +29,7 @@
from airbyte_cdk.test.entrypoint_wrapper import read
from airbyte_cdk.utils import AirbyteTracedException
from conftest import encoding_symbols_parameters, generate_stream
from requests.exceptions import HTTPError
from requests.exceptions import ChunkedEncodingError, HTTPError
from source_salesforce.api import Salesforce
from source_salesforce.exceptions import AUTHENTICATION_ERROR_MESSAGE_MAPPING
from source_salesforce.source import SourceSalesforce
@@ -38,12 +38,19 @@
BulkIncrementalSalesforceStream,
BulkSalesforceStream,
BulkSalesforceSubStream,
Describe,
IncrementalRestSalesforceStream,
RestSalesforceStream,
SalesforceStream,
)

_A_CHUNKED_RESPONSE = [b"first chunk", b"second chunk"]
_A_JSON_RESPONSE = {"id": "any id"}
_A_SUCCESSFUL_JOB_CREATION_RESPONSE = {"state": "JobComplete"}
_A_PK = "a_pk"
_A_STREAM_NAME = "a_stream_name"

_NUMBER_OF_DOWNLOAD_TRIES = 5
_FIRST_CALL_FROM_JOB_CREATION = 1

_ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
_ANY_CONFIG = {}
_ANY_STATE = None
@@ -589,6 +596,52 @@ def test_csv_reader_dialect_unix():
assert result == data


@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_retryable_error_when_download_data_then_retry(send_http_request_patch):
send_http_request_patch.return_value.iter_content.side_effect = [HTTPError(), _A_CHUNKED_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).download_data(url="any url")
assert send_http_request_patch.call_count == 2


@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_first_download_fail_when_download_data_then_retry_job_only_once(send_http_request_patch):
sf_api = Mock()
sf_api.generate_schema.return_value = {}
sf_api.instance_url = "http://test_given_first_download_fail_when_download_data_then_retry_job.com"
job_creation_return_values = [_A_JSON_RESPONSE, _A_SUCCESSFUL_JOB_CREATION_RESPONSE]
send_http_request_patch.return_value.json.side_effect = job_creation_return_values * 2
send_http_request_patch.return_value.iter_content.side_effect = HTTPError()

with pytest.raises(Exception):
list(BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=sf_api, pk=_A_PK).read_records(SyncMode.full_refresh))

assert send_http_request_patch.call_count == (len(job_creation_return_values) + _NUMBER_OF_DOWNLOAD_TRIES) * 2


@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
send_http_request_patch.return_value.json.side_effect = [HTTPError(), _A_JSON_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
assert send_http_request_patch.call_count == 2


@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_fail_with_http_errors_when_create_stream_job_then_handle_error(send_http_request_patch):
mocked_response = Mock()
mocked_response.status_code = 666
send_http_request_patch.return_value.json.side_effect = HTTPError(response=mocked_response)

with pytest.raises(HTTPError):
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")


@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_retryable_error_that_are_not_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
send_http_request_patch.return_value.json.side_effect = [ChunkedEncodingError(), _A_JSON_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
assert send_http_request_patch.call_count == 2
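
These tests lean on a `unittest.mock` idiom worth spelling out: when `side_effect` is a list, each call consumes the next item, and an item that is an exception instance is raised rather than returned. So `[HTTPError(), _A_JSON_RESPONSE]` simulates one failed read followed by a successful one. A standalone sketch:

from unittest.mock import Mock
from requests.exceptions import HTTPError

json_mock = Mock(side_effect=[HTTPError(), {"id": "any id"}])
try:
    json_mock()  # first call raises, like a transient request failure
except HTTPError:
    pass
assert json_mock() == {"id": "any id"}  # second call succeeds
assert json_mock.call_count == 2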


@pytest.mark.parametrize(
"stream_names,catalog_stream_names,",
(
@@ -3,6 +3,7 @@
#

import json
import pathlib
from typing import List
from unittest.mock import Mock

@@ -22,14 +23,14 @@ def time_sleep_mock(mocker):

@pytest.fixture(scope="module")
def bulk_catalog():
with open("unit_tests/bulk_catalog.json") as f:
with (pathlib.Path(__file__).parent / "bulk_catalog.json").open() as f:
data = json.loads(f.read())
return ConfiguredAirbyteCatalog.parse_obj(data)


@pytest.fixture(scope="module")
def rest_catalog():
with open("unit_tests/rest_catalog.json") as f:
with (pathlib.Path(__file__).parent / "rest_catalog.json").open() as f:
data = json.loads(f.read())
return ConfiguredAirbyteCatalog.parse_obj(data)
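
The conftest change replaces CWD-relative paths with paths anchored to the test module, so the fixtures resolve regardless of where pytest is invoked. The general pattern, as a small sketch (the helper name is hypothetical):

import json
import pathlib

def load_fixture(name: str) -> dict:
    # Resolve relative to this file instead of the process working directory.
    return json.loads((pathlib.Path(__file__).parent / name).read_text())

# e.g. load_fixture("bulk_catalog.json")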
