Skip to content

🚨🚨 Low code CDK: Decouple SimpleRetriever and HttpStream #28657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
Expand Down Expand Up @@ -51,12 +51,14 @@ def create_source(config: Mapping[str, Any], limits: TestReadLimits) -> Manifest
emit_connector_builder_messages=True,
limit_pages_fetched_per_slice=limits.max_pages_per_slice,
limit_slices_fetched=limits.max_slices,
disable_retries=True
)
disable_retries=True,
),
)


def read_stream(source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, limits: TestReadLimits) -> AirbyteMessage:
def read_stream(
source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, limits: TestReadLimits
) -> AirbyteMessage:
try:
handler = MessageGrouper(limits.max_pages_per_slice, limits.max_slices)
stream_name = configured_catalog.streams[0].stream.name # The connector builder only supports a single stream
Expand Down Expand Up @@ -90,7 +92,13 @@ def resolve_manifest(source: ManifestDeclarativeSource) -> AirbyteMessage:
def list_streams(source: ManifestDeclarativeSource, config: Dict[str, Any]) -> AirbyteMessage:
try:
streams = [
{"name": http_stream.name, "url": urljoin(http_stream.url_base, http_stream.path())}
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long until we can remove the call to list streams? It breaks my heart every time I see us maintaining something that we want to remove.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @lmossman is about to get started on the frontend work that will allow us to not rely on the list API anymore. So a few more weeks at most.

"name": http_stream.name,
"url": urljoin(
http_stream.requester.get_url_base(),
http_stream.requester.get_path(stream_state=None, stream_slice=None, next_page_token=None),
),
}
for http_stream in _get_http_streams(source, config)
]
return AirbyteMessage(
Expand All @@ -105,20 +113,20 @@ def list_streams(source: ManifestDeclarativeSource, config: Dict[str, Any]) -> A
return AirbyteTracedException.from_exception(exc, message=f"Error listing streams: {str(exc)}").as_airbyte_message()


def _get_http_streams(source: ManifestDeclarativeSource, config: Dict[str, Any]) -> List[HttpStream]:
def _get_http_streams(source: ManifestDeclarativeSource, config: Dict[str, Any]) -> List[SimpleRetriever]:
http_streams = []
for stream in source.streams(config=config):
if isinstance(stream, DeclarativeStream):
if isinstance(stream.retriever, HttpStream):
if isinstance(stream.retriever, SimpleRetriever):
http_streams.append(stream.retriever)
else:
raise TypeError(
f"A declarative stream should only have a retriever of type HttpStream, but received: {stream.retriever.__class__}"
f"A declarative stream should only have a retriever of type SimpleRetriever, but received: {stream.retriever.__class__}"
)
else:
raise TypeError(f"A declarative source should only contain streams of type DeclarativeStream, but received: {stream.__class__}")
return http_streams


def _emitted_at():
def _emitted_at() -> int:
return int(datetime.now().timestamp()) * 1000
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ class DeclarativeStream(Stream):
schema_loader: Optional[SchemaLoader] = None
_name: str = field(init=False, repr=False, default="")
_primary_key: str = field(init=False, repr=False, default="")
_schema_loader: SchemaLoader = field(init=False, repr=False, default=None)
stream_cursor_field: Optional[Union[InterpolatedString, str]] = None

def __post_init__(self, parameters: Mapping[str, Any]):
self.stream_cursor_field = InterpolatedString.create(self.stream_cursor_field, parameters=parameters)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._stream_cursor_field = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change: `self.stream_cursor_field` was always an `InterpolatedString` before, whereas now it can stay a plain string, and `self._stream_cursor_field` is the field that is always an `InterpolatedString`. Is this an issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's used anywhere directly for access.

InterpolatedString.create(self.stream_cursor_field, parameters=parameters)
if isinstance(self.stream_cursor_field, str)
else self.stream_cursor_field
)
self._schema_loader = self.schema_loader if self.schema_loader else DefaultSchemaLoader(config=self.config, parameters=parameters)

@property
@property # type: ignore
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
return self._primary_key

Expand All @@ -53,7 +56,7 @@ def primary_key(self, value: str) -> None:
if not isinstance(value, property):
self._primary_key = value

@property
@property # type: ignore
def name(self) -> str:
"""
:return: Stream name. By default this is the implementing class name, but it can be overridden as needed.
Expand All @@ -67,14 +70,16 @@ def name(self, value: str) -> None:

@property
def state(self) -> MutableMapping[str, Any]:
return self.retriever.state
return self.retriever.state # type: ignore

@state.setter
def state(self, value: MutableMapping[str, Any]):
def state(self, value: MutableMapping[str, Any]) -> None:
"""State setter, accept state serialized by state getter."""
self.retriever.state = value

def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
def get_updated_state(
self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]
) -> MutableMapping[str, Any]:
return self.state

@property
Expand All @@ -83,22 +88,22 @@ def cursor_field(self) -> Union[str, List[str]]:
Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
:return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor.
"""
cursor = self.stream_cursor_field.eval(self.config)
cursor = self._stream_cursor_field.eval(self.config)
return cursor if cursor else []

def read_records(
self,
sync_mode: SyncMode,
cursor_field: List[str] = None,
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
cursor_field: Optional[List[str]] = None,
stream_slice: Optional[Mapping[str, Any]] = None,
stream_state: Optional[Mapping[str, Any]] = None,
) -> Iterable[Mapping[str, Any]]:
"""
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
"""
yield from self.retriever.read_records(sync_mode, cursor_field, stream_slice)
yield from self.retriever.read_records(stream_slice)

def get_json_schema(self) -> Mapping[str, Any]:
def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
"""
:return: A dict of the JSON schema representing this stream.

Expand All @@ -108,7 +113,7 @@ def get_json_schema(self) -> Mapping[str, Any]:
return self._schema_loader.get_json_schema()

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
"""
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def create_http_requester(self, model: HttpRequesterModel, config: Config, *, na
http_method=model_http_method,
request_options_provider=request_options_provider,
config=config,
disable_retries=self._disable_retries,
parameters=model.parameters or {},
)

Expand Down Expand Up @@ -912,7 +913,6 @@ def create_simple_retriever(
config=config,
maximum_number_of_slices=self._limit_slices_fetched or 5,
parameters=model.parameters or {},
disable_retries=self._disable_retries,
message_repository=self._message_repository,
)
return SimpleRetriever(
Expand All @@ -925,7 +925,6 @@ def create_simple_retriever(
cursor=cursor,
config=config,
parameters=model.parameters or {},
disable_retries=self._disable_retries,
message_repository=self._message_repository,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class HttpRequester(Requester):
request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None
error_handler: Optional[ErrorHandler] = None

disable_retries: bool = False
_DEFAULT_MAX_RETRY = 5
_DEFAULT_RETRY_FACTOR = 5

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._url_base = InterpolatedString.create(self.url_base, parameters=parameters)
self._path = InterpolatedString.create(self.path, parameters=parameters)
Expand Down Expand Up @@ -154,21 +158,6 @@ def get_request_body_json( # type: ignore
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)

def request_kwargs(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
# todo: there are a few integrations that override the request_kwargs() method, but the use case for why kwargs over existing
# constructs is a little unclear. We may revisit this, but for now lets leave it out of the DSL
return {}

disable_retries: bool = False
_DEFAULT_MAX_RETRY = 5
_DEFAULT_RETRY_FACTOR = 5

@property
def max_retries(self) -> Union[int, None]:
if self.disable_retries:
Expand Down Expand Up @@ -236,6 +225,7 @@ def _get_mapping(

def _get_request_options(
self,
stream_state: Optional[StreamState],
stream_slice: Optional[StreamSlice],
next_page_token: Optional[Mapping[str, Any]],
requester_method: Callable[..., Optional[Union[Mapping[str, Any], str]]],
Expand All @@ -247,16 +237,22 @@ def _get_request_options(
Raise a ValueError if there's a key collision
Returned merged mapping otherwise
"""
requester_mapping, requester_keys = self._get_mapping(requester_method, stream_slice=stream_slice, next_page_token=next_page_token)
requester_mapping, requester_keys = self._get_mapping(
requester_method, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)
auth_options_mapping, auth_options_keys = self._get_mapping(auth_options_method)
extra_options = extra_options or {}
extra_mapping, extra_keys = self._get_mapping(lambda: extra_options)

all_mappings = [requester_mapping, auth_options_mapping, extra_mapping]
all_keys = [requester_keys, auth_options_keys, extra_keys]

string_options = sum(isinstance(mapping, str) for mapping in all_mappings)
# If more than one mapping is a string, raise a ValueError
if sum(isinstance(mapping, str) for mapping in all_mappings) > 1:
if string_options > 1:
raise ValueError("Cannot combine multiple options if one is a string")

if string_options == 1 and sum(len(keys) for keys in all_keys) > 0:
raise ValueError("Cannot combine multiple options if one is a string")

# If any mapping is a string, return it
Expand All @@ -275,6 +271,7 @@ def _get_request_options(

def _request_headers(
self,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
extra_headers: Optional[Mapping[str, Any]] = None,
Expand All @@ -284,6 +281,7 @@ def _request_headers(
Authentication headers will overwrite any overlapping headers returned from this method.
"""
headers = self._get_request_options(
stream_state,
stream_slice,
next_page_token,
self.get_request_headers,
Expand All @@ -296,6 +294,7 @@ def _request_headers(

def _request_params(
self,
stream_state: Optional[StreamState],
stream_slice: Optional[StreamSlice],
next_page_token: Optional[Mapping[str, Any]],
extra_params: Optional[Mapping[str, Any]] = None,
Expand All @@ -306,14 +305,15 @@ def _request_params(
E.g: you might want to define query parameters for paging if next_page_token is not None.
"""
options = self._get_request_options(
stream_slice, next_page_token, self.get_request_params, self.get_authenticator().get_request_params, extra_params
stream_state, stream_slice, next_page_token, self.get_request_params, self.get_authenticator().get_request_params, extra_params
)
if isinstance(options, str):
raise ValueError("Request params cannot be a string")
return options

def _request_body_data(
self,
stream_state: Optional[StreamState],
stream_slice: Optional[StreamSlice],
next_page_token: Optional[Mapping[str, Any]],
extra_body_data: Optional[Union[Mapping[str, Any], str]] = None,
Expand All @@ -329,11 +329,17 @@ def _request_body_data(
"""
# Warning: use self.state instead of the stream_state passed as argument!
return self._get_request_options(
stream_slice, next_page_token, self.get_request_body_data, self.get_authenticator().get_request_body_data, extra_body_data
stream_state,
stream_slice,
next_page_token,
self.get_request_body_data,
self.get_authenticator().get_request_body_data,
extra_body_data,
)

def _request_body_json(
self,
stream_state: Optional[StreamState],
stream_slice: Optional[StreamSlice],
next_page_token: Optional[Mapping[str, Any]],
extra_body_json: Optional[Mapping[str, Any]] = None,
Expand All @@ -345,7 +351,12 @@ def _request_body_json(
"""
# Warning: use self.state instead of the stream_state passed as argument!
options = self._get_request_options(
stream_slice, next_page_token, self.get_request_body_json, self.get_authenticator().get_request_body_json, extra_body_json
stream_state,
stream_slice,
next_page_token,
self.get_request_body_json,
self.get_authenticator().get_request_body_json,
extra_body_json,
)
if isinstance(options, str):
raise ValueError("Request body json cannot be a string")
Expand Down Expand Up @@ -396,6 +407,7 @@ def _create_prepared_request(

def send_request(
self,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
path: Optional[str] = None,
Expand All @@ -405,11 +417,13 @@ def send_request(
request_body_json: Optional[Mapping[str, Any]] = None,
) -> Optional[requests.Response]:
request = self._create_prepared_request(
path=path if path is not None else self.get_path(stream_state=None, stream_slice=stream_slice, next_page_token=next_page_token),
headers=self._request_headers(stream_slice, next_page_token, request_headers),
params=self._request_params(stream_slice, next_page_token, request_params),
json=self._request_body_json(stream_slice, next_page_token, request_body_json),
data=self._request_body_data(stream_slice, next_page_token, request_body_data),
path=path
if path is not None
else self.get_path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
headers=self._request_headers(stream_state, stream_slice, next_page_token, request_headers),
params=self._request_params(stream_state, stream_slice, next_page_token, request_params),
json=self._request_body_json(stream_state, stream_slice, next_page_token, request_body_json),
data=self._request_body_data(stream_state, stream_slice, next_page_token, request_body_data),
)

response = self._send_with_retry(request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from dataclasses import InitVar, dataclass
from typing import Any, List, Mapping, Optional, Union
from typing import Any, List, Mapping, MutableMapping, Optional, Union

import requests
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
Expand All @@ -27,7 +27,7 @@ def get_request_params(
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
) -> MutableMapping[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we returning this as something mutable? Does anything mutate the request params? Immutable things are often preferred, as that ensures we don't update something in memory that someone else relies on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just fixing the types (it's coming from here:

)

I agree that we should make this immutable, but I would like to split that out of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be safe to change the RequestOptionsProvider's return type to Mapping, but it's fine to do this separately.

return {}

def get_request_headers(
Expand Down Expand Up @@ -60,6 +60,6 @@ def get_request_body_json(
def next_page_token(self, response: requests.Response, last_records: List[Record]) -> Mapping[str, Any]:
return {}

def reset(self):
def reset(self) -> None:
# No state to reset
pass
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Paginator(ABC, RequestOptionsProvider):
"""

@abstractmethod
def reset(self):
def reset(self) -> None:
"""
Reset the pagination's inner state
"""
Expand Down
Loading