Skip to content

Commit df34893

Browse files
authored
feat(airbyte-cdk): replace pydantic BaseModel with dataclasses + serpyco-rs in protocol (#44444)
Signed-off-by: Artem Inzhyyants <[email protected]>
1 parent 21fddbd commit df34893

File tree

125 files changed

+2730
-2270
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

125 files changed

+2730
-2270
lines changed

airbyte-cdk/python/airbyte_cdk/config_observation.py

+10-2
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,15 @@
1010
from copy import copy
1111
from typing import Any, List, MutableMapping
1212

13-
from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, OrchestratorType, Type
13+
from airbyte_cdk.models import (
14+
AirbyteControlConnectorConfigMessage,
15+
AirbyteControlMessage,
16+
AirbyteMessage,
17+
AirbyteMessageSerializer,
18+
OrchestratorType,
19+
Type,
20+
)
21+
from orjson import orjson
1422

1523

1624
class ObservedDict(dict): # type: ignore # disallow_any_generics is set to True, and dict is equivalent to dict[Any]
@@ -76,7 +84,7 @@ def emit_configuration_as_airbyte_control_message(config: MutableMapping[str, An
7684
See the airbyte_cdk.sources.message package
7785
"""
7886
airbyte_message = create_connector_config_control_message(config)
79-
print(airbyte_message.model_dump_json(exclude_unset=True))
87+
print(orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode())
8088

8189

8290
def create_connector_config_control_message(config: MutableMapping[str, Any]) -> AirbyteMessage:

airbyte-cdk/python/airbyte_cdk/connector.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar
1212

1313
import yaml
14-
from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification
14+
from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification, ConnectorSpecificationSerializer
1515

1616

1717
def load_optional_package_file(package: str, filename: str) -> Optional[bytes]:
@@ -84,7 +84,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification:
8484
else:
8585
raise FileNotFoundError("Unable to find spec.yaml or spec.json in the package.")
8686

87-
return ConnectorSpecification.parse_obj(spec_obj)
87+
return ConnectorSpecificationSerializer.load(spec_obj)
8888

8989
@abstractmethod
9090
def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionStatus:

airbyte-cdk/python/airbyte_cdk/connector_builder/main.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,17 @@
99
from airbyte_cdk.connector import BaseConnector
1010
from airbyte_cdk.connector_builder.connector_builder_handler import TestReadLimits, create_source, get_limits, read_stream, resolve_manifest
1111
from airbyte_cdk.entrypoint import AirbyteEntrypoint
12-
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
12+
from airbyte_cdk.models import (
13+
AirbyteMessage,
14+
AirbyteMessageSerializer,
15+
AirbyteStateMessage,
16+
ConfiguredAirbyteCatalog,
17+
ConfiguredAirbyteCatalogSerializer,
18+
)
1319
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
1420
from airbyte_cdk.sources.source import Source
1521
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
22+
from orjson import orjson
1623

1724

1825
def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]:
@@ -32,7 +39,7 @@ def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str,
3239

3340
command = config["__command"]
3441
if command == "test_read":
35-
catalog = ConfiguredAirbyteCatalog.parse_obj(BaseConnector.read_config(catalog_path))
42+
catalog = ConfiguredAirbyteCatalogSerializer.load(BaseConnector.read_config(catalog_path))
3643
state = Source.read_state(state_path)
3744
else:
3845
catalog = None
@@ -67,7 +74,7 @@ def handle_request(args: List[str]) -> AirbyteMessage:
6774
command, config, catalog, state = get_config_and_catalog_from_args(args)
6875
limits = get_limits(config)
6976
source = create_source(config, limits)
70-
return handle_connector_builder_request(source, command, config, catalog, state, limits).json(exclude_unset=True)
77+
return AirbyteMessageSerializer.dump(handle_connector_builder_request(source, command, config, catalog, state, limits)) # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
7178

7279

7380
if __name__ == "__main__":
@@ -76,4 +83,4 @@ def handle_request(args: List[str]) -> AirbyteMessage:
7683
except Exception as exc:
7784
error = AirbyteTracedException.from_exception(exc, message=f"Error handling request: {str(exc)}")
7885
m = error.as_airbyte_message()
79-
print(error.as_airbyte_message().model_dump_json(exclude_unset=True))
86+
print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode())

airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py

+17-17
Original file line number | Diff line number | Diff line change
@@ -18,13 +18,7 @@
1818
StreamReadSlices,
1919
)
2020
from airbyte_cdk.entrypoint import AirbyteEntrypoint
21-
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
22-
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
23-
from airbyte_cdk.sources.utils.types import JsonType
24-
from airbyte_cdk.utils import AirbyteTracedException
25-
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
26-
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
27-
from airbyte_protocol.models.airbyte_protocol import (
21+
from airbyte_cdk.models import (
2822
AirbyteControlMessage,
2923
AirbyteLogMessage,
3024
AirbyteMessage,
@@ -34,7 +28,13 @@
3428
OrchestratorType,
3529
TraceType,
3630
)
37-
from airbyte_protocol.models.airbyte_protocol import Type as MessageType
31+
from airbyte_cdk.models import Type as MessageType
32+
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
33+
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
34+
from airbyte_cdk.sources.utils.types import JsonType
35+
from airbyte_cdk.utils import AirbyteTracedException
36+
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
37+
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
3838

3939

4040
class MessageGrouper:
@@ -182,19 +182,19 @@ def _get_message_groups(
182182
if (
183183
at_least_one_page_in_group
184184
and message.type == MessageType.LOG
185-
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX)
185+
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
186186
):
187187
yield StreamReadSlices(
188188
pages=current_slice_pages,
189189
slice_descriptor=current_slice_descriptor,
190190
state=[latest_state_message] if latest_state_message else [],
191191
)
192-
current_slice_descriptor = self._parse_slice_description(message.log.message)
192+
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
193193
current_slice_pages = []
194194
at_least_one_page_in_group = False
195-
elif message.type == MessageType.LOG and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX):
195+
elif message.type == MessageType.LOG and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX): # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
196196
# parsing the first slice
197-
current_slice_descriptor = self._parse_slice_description(message.log.message)
197+
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
198198
elif message.type == MessageType.LOG:
199199
if json_message is not None and self._is_http_log(json_message):
200200
if self._is_auxiliary_http_request(json_message):
@@ -221,17 +221,17 @@ def _get_message_groups(
221221
else:
222222
yield message.log
223223
elif message.type == MessageType.TRACE:
224-
if message.trace.type == TraceType.ERROR:
224+
if message.trace.type == TraceType.ERROR: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has trace.type
225225
yield message.trace
226226
elif message.type == MessageType.RECORD:
227-
current_page_records.append(message.record.data)
227+
current_page_records.append(message.record.data) # type: ignore[union-attr] # AirbyteMessage with MessageType.RECORD has record.data
228228
records_count += 1
229229
schema_inferrer.accumulate(message.record)
230230
datetime_format_inferrer.accumulate(message.record)
231-
elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG:
231+
elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG: # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type
232232
yield message.control
233233
elif message.type == MessageType.STATE:
234-
latest_state_message = message.state
234+
latest_state_message = message.state # type: ignore[assignment]
235235
else:
236236
if current_page_request or current_page_response or current_page_records:
237237
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
@@ -246,7 +246,7 @@ def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessag
246246
return (
247247
at_least_one_page_in_group
248248
and message.type == MessageType.LOG
249-
and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:"))
249+
and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:")) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
250250
)
251251

252252
@staticmethod

airbyte-cdk/python/airbyte_cdk/destinations/destination.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,10 @@
1111

1212
from airbyte_cdk.connector import Connector
1313
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
14-
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
14+
from airbyte_cdk.models import AirbyteMessage, AirbyteMessageSerializer, ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer, Type
1515
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit
1616
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
17-
from pydantic import ValidationError as V2ValidationError
17+
from orjson import orjson
1818

1919
logger = logging.getLogger("airbyte")
2020

@@ -36,14 +36,14 @@ def _parse_input_stream(self, input_stream: io.TextIOWrapper) -> Iterable[Airbyt
3636
"""Reads from stdin, converting to Airbyte messages"""
3737
for line in input_stream:
3838
try:
39-
yield AirbyteMessage.parse_raw(line)
40-
except V2ValidationError:
39+
yield AirbyteMessageSerializer.load(orjson.loads(line))
40+
except orjson.JSONDecodeError:
4141
logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {line}")
4242

4343
def _run_write(
4444
self, config: Mapping[str, Any], configured_catalog_path: str, input_stream: io.TextIOWrapper
4545
) -> Iterable[AirbyteMessage]:
46-
catalog = ConfiguredAirbyteCatalog.parse_file(configured_catalog_path)
46+
catalog = ConfiguredAirbyteCatalogSerializer.load(orjson.loads(open(configured_catalog_path).read()))
4747
input_messages = self._parse_input_stream(input_stream)
4848
logger.info("Begin writing to the destination...")
4949
yield from self.write(config=config, configured_catalog=catalog, input_messages=input_messages)
@@ -117,4 +117,4 @@ def run(self, args: List[str]) -> None:
117117
parsed_args = self.parse_args(args)
118118
output_messages = self.run_cmd(parsed_args)
119119
for message in output_messages:
120-
print(message.model_dump_json(exclude_unset=True))
120+
print(orjson.dumps(AirbyteMessageSerializer.dump(message)).decode())

airbyte-cdk/python/airbyte_cdk/entrypoint.py

+15-7
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,23 @@
1919
from airbyte_cdk.connector import TConfig
2020
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
2121
from airbyte_cdk.logger import init_logger
22-
from airbyte_cdk.models import AirbyteMessage, FailureType, Status, Type
23-
from airbyte_cdk.models.airbyte_protocol import AirbyteStateStats, ConnectorSpecification # type: ignore [attr-defined]
22+
from airbyte_cdk.models import ( # type: ignore [attr-defined]
23+
AirbyteMessage,
24+
AirbyteMessageSerializer,
25+
AirbyteStateStats,
26+
ConnectorSpecification,
27+
FailureType,
28+
Status,
29+
Type,
30+
)
2431
from airbyte_cdk.sources import Source
2532
from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
2633
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config
2734
from airbyte_cdk.utils import PrintBuffer, is_cloud_environment, message_utils
2835
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
2936
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
3037
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
38+
from orjson import orjson
3139
from requests import PreparedRequest, Response, Session
3240

3341
logger = init_logger("airbyte")
@@ -170,13 +178,13 @@ def read(self, source_spec: ConnectorSpecification, config: TConfig, catalog: An
170178
def handle_record_counts(message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float]) -> AirbyteMessage:
171179
match message.type:
172180
case Type.RECORD:
173-
stream_message_count[HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)] += 1.0
181+
stream_message_count[HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)] += 1.0 # type: ignore[union-attr] # record has `stream` and `namespace`
174182
case Type.STATE:
175183
stream_descriptor = message_utils.get_stream_descriptor(message)
176184

177185
# Set record count from the counter onto the state message
178-
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats()
179-
message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0.0)
186+
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats() # type: ignore[union-attr] # state has `sourceStats`
187+
message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0.0) # type: ignore[union-attr] # state has `sourceStats`
180188

181189
# Reset the counter
182190
stream_message_count[stream_descriptor] = 0.0
@@ -197,8 +205,8 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str,
197205
update_secrets(config_secrets)
198206

199207
@staticmethod
200-
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> Any:
201-
return airbyte_message.model_dump_json(exclude_unset=True)
208+
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
209+
return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string
202210

203211
@classmethod
204212
def extract_state(cls, args: List[str]) -> Optional[Any]:

airbyte-cdk/python/airbyte_cdk/logger.py

+9-8
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,9 @@
77
import logging.config
88
from typing import Any, Mapping, Optional, Tuple
99

10-
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage
10+
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteMessageSerializer, Level, Type
1111
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
12+
from orjson import orjson
1213

1314
LOGGING_CONFIG = {
1415
"version": 1,
@@ -42,11 +43,11 @@ class AirbyteLogFormatter(logging.Formatter):
4243

4344
# Transforming Python log levels to Airbyte protocol log levels
4445
level_mapping = {
45-
logging.FATAL: "FATAL",
46-
logging.ERROR: "ERROR",
47-
logging.WARNING: "WARN",
48-
logging.INFO: "INFO",
49-
logging.DEBUG: "DEBUG",
46+
logging.FATAL: Level.FATAL,
47+
logging.ERROR: Level.ERROR,
48+
logging.WARNING: Level.WARN,
49+
logging.INFO: Level.INFO,
50+
logging.DEBUG: Level.DEBUG,
5051
}
5152

5253
def format(self, record: logging.LogRecord) -> str:
@@ -59,8 +60,8 @@ def format(self, record: logging.LogRecord) -> str:
5960
else:
6061
message = super().format(record)
6162
message = filter_secrets(message)
62-
log_message = AirbyteMessage(type="LOG", log=AirbyteLogMessage(level=airbyte_level, message=message))
63-
return log_message.model_dump_json(exclude_unset=True) # type: ignore
63+
log_message = AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message))
64+
return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string
6465

6566
@staticmethod
6667
def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]:

airbyte-cdk/python/airbyte_cdk/models/__init__.py

+10
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
# of airbyte-cdk rather than a standalone package.
88
from .airbyte_protocol import (
99
AdvancedAuth,
10+
AirbyteStateStats,
1011
AirbyteAnalyticsTraceMessage,
1112
AirbyteCatalog,
1213
AirbyteConnectionStatus,
@@ -58,3 +59,12 @@
5859
TimeWithoutTimezone,
5960
TimeWithTimezone,
6061
)
62+
63+
from .airbyte_protocol_serializers import (
64+
AirbyteStreamStateSerializer,
65+
AirbyteStateMessageSerializer,
66+
AirbyteMessageSerializer,
67+
ConfiguredAirbyteCatalogSerializer,
68+
ConfiguredAirbyteStreamSerializer,
69+
ConnectorSpecificationSerializer,
70+
)

0 commit comments

Comments
 (0)