|
| 1 | +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. |
| 2 | + |
| 3 | +import json |
| 4 | +from datetime import datetime, timezone |
| 5 | +from typing import Any, Dict, Optional |
| 6 | +from unittest import TestCase |
| 7 | + |
| 8 | +import freezegun |
| 9 | +from airbyte_cdk.sources.source import TState |
| 10 | +from airbyte_cdk.test.catalog_builder import CatalogBuilder |
| 11 | +from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read |
| 12 | +from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse |
| 13 | +from airbyte_cdk.test.mock_http.request import ANY_QUERY_PARAMS |
| 14 | +from airbyte_cdk.test.state_builder import StateBuilder |
| 15 | +from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode |
| 16 | +from config_builder import ConfigBuilder |
| 17 | +from source_salesforce import SourceSalesforce |
| 18 | +from source_salesforce.api import UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS |
| 19 | + |
| 20 | +_A_FIELD_NAME = "a_field" |
| 21 | +_ACCESS_TOKEN = "an_access_token" |
| 22 | +_API_VERSION = "v57.0" |
| 23 | +_CLIENT_ID = "a_client_id" |
| 24 | +_CLIENT_SECRET = "a_client_secret" |
| 25 | +_INSTANCE_URL = "https://instance.salesforce.com" |
| 26 | +_NOW = datetime.now(timezone.utc) |
| 27 | +_REFRESH_TOKEN = "a_refresh_token" |
| 28 | +_STREAM_NAME = UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS[0] |
| 29 | + |
| 30 | + |
| 31 | +def _catalog(sync_mode: SyncMode) -> ConfiguredAirbyteCatalog: |
| 32 | + return CatalogBuilder().with_stream(_STREAM_NAME, sync_mode).build() |
| 33 | + |
| 34 | + |
| 35 | +def _source(catalog: ConfiguredAirbyteCatalog, config: Dict[str, Any], state: Optional[TState]) -> SourceSalesforce: |
| 36 | + return SourceSalesforce(catalog, config, state) |
| 37 | + |
| 38 | + |
| 39 | +def _read( |
| 40 | + sync_mode: SyncMode, |
| 41 | + config_builder: Optional[ConfigBuilder] = None, |
| 42 | + expecting_exception: bool = False |
| 43 | +) -> EntrypointOutput: |
| 44 | + catalog = _catalog(sync_mode) |
| 45 | + config = config_builder.build() if config_builder else ConfigBuilder().build() |
| 46 | + state = StateBuilder().build() |
| 47 | + return read(_source(catalog, config, state), config, catalog, state, expecting_exception) |
| 48 | + |
| 49 | + |
| 50 | +def _given_authentication(http_mocker: HttpMocker, client_id: str, client_secret: str, refresh_token: str) -> None: |
| 51 | + http_mocker.post( |
| 52 | + HttpRequest( |
| 53 | + "https://login.salesforce.com/services/oauth2/token", |
| 54 | + query_params=ANY_QUERY_PARAMS, |
| 55 | + body=f"grant_type=refresh_token&client_id={client_id}&client_secret={client_secret}&refresh_token={refresh_token}" |
| 56 | + ), |
| 57 | + HttpResponse(json.dumps({"access_token": _ACCESS_TOKEN, "instance_url": _INSTANCE_URL})), |
| 58 | + ) |
| 59 | + |
| 60 | + |
| 61 | +def _given_stream(http_mocker: HttpMocker, stream_name: str, field_name: str) -> None: |
| 62 | + http_mocker.get( |
| 63 | + HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/sobjects"), |
| 64 | + HttpResponse(json.dumps({"sobjects": [{"name": stream_name, "queryable": True}]})), |
| 65 | + ) |
| 66 | + http_mocker.get( |
| 67 | + HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/sobjects/AcceptedEventRelation/describe"), |
| 68 | + HttpResponse(json.dumps({"fields": [{"name": field_name, "type": "string"}]})), |
| 69 | + ) |
| 70 | + |
| 71 | + |
| 72 | +@freezegun.freeze_time(_NOW.isoformat()) |
| 73 | +class FullRefreshTest(TestCase): |
| 74 | + |
| 75 | + def setUp(self) -> None: |
| 76 | + self._config = ConfigBuilder().client_id(_CLIENT_ID).client_secret(_CLIENT_SECRET).refresh_token(_REFRESH_TOKEN) |
| 77 | + |
| 78 | + @HttpMocker() |
| 79 | + def test_given_error_on_fetch_chunk_when_read_then_retry(self, http_mocker: HttpMocker) -> None: |
| 80 | + _given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN) |
| 81 | + _given_stream(http_mocker, _STREAM_NAME, _A_FIELD_NAME) |
| 82 | + http_mocker.get( |
| 83 | + HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME}+FROM+{_STREAM_NAME}+"), |
| 84 | + [ |
| 85 | + HttpResponse("", status_code=406), |
| 86 | + HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})), |
| 87 | + ] |
| 88 | + ) |
| 89 | + |
| 90 | + output = _read(SyncMode.full_refresh, self._config) |
| 91 | + |
| 92 | + assert len(output.records) == 1 |
0 commit comments