
Commit 1c5ac5b

🏗️ Python CDK: add schema transformer class (#6139)
* Python CDK: add schema transformer class
1 parent d386ed7 commit 1c5ac5b

File tree

8 files changed: +597 -4 lines changed

airbyte-cdk/python/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 # Changelog
 
+## 0.1.24
+Added Transform class to use for mutating record value types so they adhere to jsonschema definition.
+
 ## 0.1.23
 Added the ability to use caching for efficient synchronization of nested streams.
 
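
In practice, the new class rewrites field values in place so they match the declared schema types. A minimal sketch of the intended effect, using the TypeTransformer API added in this commit (the record and schema here are made up for illustration):

```python
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

# Hypothetical schema/record pair: "value" arrives as an integer,
# but the stream's json schema declares it a string.
schema = {"type": "object", "properties": {"value": {"type": "string"}}}
record = {"value": 23}

transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
transformer.transform(record, schema)  # mutates the record in place
print(record)  # expected: {'value': '23'}
```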

airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py

Lines changed: 26 additions & 1 deletion

@@ -26,7 +26,8 @@
 import copy
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple
+from functools import lru_cache
+from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple
 
 from airbyte_cdk.logger import AirbyteLogger
 from airbyte_cdk.models import (
@@ -35,6 +36,7 @@
     AirbyteMessage,
     AirbyteRecordMessage,
     AirbyteStateMessage,
+    AirbyteStream,
     ConfiguredAirbyteCatalog,
     ConfiguredAirbyteStream,
     Status,
@@ -45,6 +47,7 @@
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
 from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config
+from airbyte_cdk.sources.utils.transform import TypeTransformer
 
 
 class AbstractSource(Source, ABC):
@@ -70,6 +73,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
         :return: A list of the streams in this source connector.
         """
 
+    # Stream name to instance map for applying output object transformation
+    _stream_to_instance_map: Dict[str, AirbyteStream] = {}
+
     @property
     def name(self) -> str:
         """Source name"""
@@ -101,6 +107,7 @@ def read(
         # TODO assert all streams exist in the connector
         # get the streams once in case the connector needs to make any queries to generate them
         stream_instances = {s.name: s for s in self.streams(config)}
+        self._stream_to_instance_map = stream_instances
        for configured_stream in catalog.streams:
            stream_instance = stream_instances.get(configured_stream.stream.name)
            if not stream_instance:
@@ -227,7 +234,25 @@ def _checkpoint_state(self, stream_name, stream_state, connector_state, logger):
         connector_state[stream_name] = stream_state
         return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(data=connector_state))
 
+    @lru_cache(maxsize=None)
+    def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, dict]:
+        """
+        Look up a stream's transformer object and json schema by stream name.
+        This function is called frequently, so caching saves on the costly
+        get_json_schema operation.
+        :param stream_name: name of stream from catalog.
+        :return: tuple of the stream's transformer object and discovered json schema.
+        """
+        stream_instance = self._stream_to_instance_map.get(stream_name)
+        return stream_instance.transformer, stream_instance.get_json_schema()
+
     def _as_airbyte_record(self, stream_name: str, data: Mapping[str, Any]):
         now_millis = int(datetime.now().timestamp()) * 1000
+        transformer, schema = self._get_stream_transformer_and_schema(stream_name)
+        # Transform object fields according to config. Most likely you will
+        # need it to normalize values against the json schema. By default no
+        # action is taken unless configured. See
+        # docs/connector-development/cdk-python/schemas.md for details.
+        transformer.transform(data, schema)
         message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
         return AirbyteMessage(type=MessageType.RECORD, record=message)
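
The wiring above is easy to miss in the diff: every emitted record now passes through its stream's transformer, and lru_cache makes the (transformer, schema) lookup essentially free after the first record. A condensed, hypothetical rendering of that hot path (names simplified; not the actual class):

```python
from functools import lru_cache


class RecordEmitter:
    """Hypothetical stand-in for AbstractSource's record path."""

    def __init__(self, streams):
        self._stream_to_instance_map = {s.name: s for s in streams}

    @lru_cache(maxsize=None)
    def _get_stream_transformer_and_schema(self, stream_name):
        # get_json_schema() can be costly (file reads, schema resolution),
        # so it runs at most once per stream name; later records hit the cache.
        stream = self._stream_to_instance_map[stream_name]
        return stream.transformer, stream.get_json_schema()

    def as_record(self, stream_name, data):
        transformer, schema = self._get_stream_transformer_and_schema(stream_name)
        transformer.transform(data, schema)  # in place; a no-op under NoTransform
        return data
```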

airbyte-cdk/python/airbyte_cdk/sources/streams/core.py

Lines changed: 4 additions & 0 deletions

@@ -31,6 +31,7 @@
 from airbyte_cdk.logger import AirbyteLogger
 from airbyte_cdk.models import AirbyteStream, SyncMode
 from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader
+from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 
 
 def package_name_from_class(cls: object) -> str:
@@ -47,6 +48,9 @@ class Stream(ABC):
     # Use self.logger in subclasses to log any messages
     logger = AirbyteLogger()  # TODO use native "logging" loggers with custom handlers
 
+    # TypeTransformer object to perform output data transformation
+    transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform)
+
     @property
     def name(self) -> str:
         """
airbyte-cdk/python/airbyte_cdk/sources/utils/transform.py

Lines changed: 196 additions & 0 deletions

@@ -0,0 +1,196 @@
+#
+# MIT License
+#
+# Copyright (c) 2020 Airbyte
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+from distutils.util import strtobool
+from enum import Flag, auto
+from typing import Any, Callable, Dict
+
+from airbyte_cdk.logger import AirbyteLogger
+from jsonschema import Draft7Validator, validators
+
+logger = AirbyteLogger()
+
+
+class TransformConfig(Flag):
+    """
+    TypeTransformer class config. Configs can be combined using the bitwise OR operator, e.g.
+    ```
+    TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization
+    ```
+    """
+
+    # No action taken, default behaviour. Cannot be combined with any other options.
+    NoTransform = auto()
+    # Applies default type casting with the default_convert method, which converts
+    # values by applying simple type casting to the specified jsonschema type.
+    DefaultSchemaNormalization = auto()
+    # Allow registering a custom type transformation callback. Can be combined
+    # with DefaultSchemaNormalization. In this case default type casting is
+    # applied before the custom one.
+    CustomSchemaNormalization = auto()
+
+
+class TypeTransformer:
+    """
+    Class for transforming an object before output.
+    """
+
+    _custom_normalizer: Callable[[Any, Dict[str, Any]], Any] = None
+
+    def __init__(self, config: TransformConfig):
+        """
+        Initialize a TypeTransformer instance.
+        :param config: transform config to be applied to the object.
+        """
+        if TransformConfig.NoTransform in config and config != TransformConfig.NoTransform:
+            raise Exception("NoTransform option cannot be combined with other flags.")
+        self._config = config
+        all_validators = {
+            key: self.__get_normalizer(key, orig_validator)
+            for key, orig_validator in Draft7Validator.VALIDATORS.items()
+            # Do not validate fields we do not transform, for maximum performance.
+            if key in ["type", "array", "$ref", "properties", "items"]
+        }
+        self._normalizer = validators.create(meta_schema=Draft7Validator.META_SCHEMA, validators=all_validators)
+
+    def registerCustomTransform(self, normalization_callback: Callable[[Any, Dict[str, Any]], Any]) -> Callable:
+        """
+        Register a custom normalization callback.
+        :param normalization_callback: function to be used for value
+        normalization. Takes the original value and the relevant part of the
+        schema; should return the normalized value. See
+        docs/connector-development/cdk-python/schemas.md for details.
+        :return: the same callback, which is useful for using registerCustomTransform as a decorator.
+        """
+        if TransformConfig.CustomSchemaNormalization not in self._config:
+            raise Exception("Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer")
+        self._custom_normalizer = normalization_callback
+        return normalization_callback
+
+    def __normalize(self, original_item: Any, subschema: Dict[str, Any]) -> Any:
+        """
+        Apply the transform functions configured for an object's field.
+        :param original_item: original value of the field.
+        :param subschema: part of the jsonschema containing the field's type/format data.
+        :return: final field value.
+        """
+        if TransformConfig.DefaultSchemaNormalization in self._config:
+            original_item = self.default_convert(original_item, subschema)
+
+        if self._custom_normalizer:
+            original_item = self._custom_normalizer(original_item, subschema)
+        return original_item
+
+    @staticmethod
+    def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any:
+        """
+        Default transform function, used when the TransformConfig.DefaultSchemaNormalization flag is set.
+        :param original_item: original value of the field.
+        :param subschema: part of the jsonschema containing the field's type/format data.
+        :return: transformed field value.
+        """
+        target_type = subschema.get("type")
+        if original_item is None and "null" in target_type:
+            return None
+        if isinstance(target_type, list):
+            # A jsonschema type can be either a single string or an array of
+            # type strings. If the type is ambiguous, i.e. there is more than
+            # one type besides null, do no conversion and return the original
+            # value. If the type array has one type plus null, e.g. {"type":
+            # ["integer", "null"]}, convert the value to the specified type.
+            target_type = [t for t in target_type if t != "null"]
+            if len(target_type) != 1:
+                return original_item
+            target_type = target_type[0]
+        try:
+            if target_type == "string":
+                return str(original_item)
+            elif target_type == "number":
+                return float(original_item)
+            elif target_type == "integer":
+                return int(original_item)
+            elif target_type == "boolean":
+                if isinstance(original_item, str):
+                    return strtobool(original_item) == 1
+                return bool(original_item)
+        except ValueError:
+            return original_item
+        return original_item
+
+    def __get_normalizer(self, schema_key: str, original_validator: Callable):
+        """
+        Traverse an object's fields using the native jsonschema validator and apply the normalization function.
+        :param schema_key: jsonschema key currently being validated/normalized.
+        :param original_validator: native jsonschema validator callback.
+        """
+
+        def normalizator(validator_instance: Callable, val: Any, instance: Any, schema: Dict[str, Any]):
+            """
+            Jsonschema validator callable used for validating an instance. We
+            override the default Draft7Validator to perform value transformation
+            before validation takes place. We take no action except logging a
+            warning if the object does not conform to the json schema; the
+            jsonschema algorithm is only used to traverse the object's fields.
+            See the validators parameter at
+            https://python-jsonschema.readthedocs.io/en/stable/creating/?highlight=validators.create#jsonschema.validators.create
+            for a detailed description.
+            """
+
+            def resolve(subschema):
+                if "$ref" in subschema:
+                    _, resolved = validator_instance.resolver.resolve(subschema["$ref"])
+                    return resolved
+                return subschema
+
+            if schema_key == "type" and instance is not None:
+                if "object" in val and isinstance(instance, dict):
+                    for k, subschema in schema.get("properties", {}).items():
+                        if k in instance:
+                            subschema = resolve(subschema)
+                            instance[k] = self.__normalize(instance[k], subschema)
+                elif "array" in val and isinstance(instance, list):
+                    subschema = schema.get("items", {})
+                    subschema = resolve(subschema)
+                    for index, item in enumerate(instance):
+                        instance[index] = self.__normalize(item, subschema)
+            # Run the native jsonschema traverse algorithm after field normalization is done.
+            yield from original_validator(validator_instance, val, instance, schema)
+
+        return normalizator
+
+    def transform(self, record: Dict[str, Any], schema: Dict[str, Any]):
+        """
+        Normalize and validate according to config.
+        :param record: record instance for normalization/transformation. All modifications are done in place on the existing object.
+        :param schema: object's jsonschema for normalization.
+        """
+        if TransformConfig.NoTransform in self._config:
+            return
+        normalizer = self._normalizer(schema)
+        # Just calling normalizer.validate() would throw an exception on the
+        # first validation error and stop processing the rest of the schema.
+        for e in normalizer.iter_errors(record):
+            logger.warn(e.message)
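
Putting the pieces of this file together: flags combine with bitwise OR, and registerCustomTransform doubles as a decorator, with default casting applied before the custom callback. A hedged usage sketch (the date-trimming rule and field names are invented for illustration):

```python
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

# Default casting runs first, then the custom callback.
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization)


@transformer.registerCustomTransform
def trim_to_date(original_value, field_schema):
    # Hypothetical rule: cut timestamps down to a date when the schema asks for one.
    if field_schema.get("format") == "date" and isinstance(original_value, str):
        return original_value[:10]
    return original_value


schema = {"type": "object", "properties": {"hired_at": {"type": "string", "format": "date"}}}
record = {"hired_at": "2021-09-15T00:00:00"}
transformer.transform(record, schema)
print(record)  # expected: {'hired_at': '2021-09-15'}
```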

airbyte-cdk/python/setup.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@
 
 setup(
     name="airbyte-cdk",
-    version="0.1.23",
+    version="0.1.24",
     description="A framework for writing Airbyte Connectors.",
     long_description=README,
     long_description_content_type="text/markdown",

airbyte-cdk/python/unit_tests/sources/test_source.py

Lines changed: 45 additions & 2 deletions

@@ -34,6 +34,7 @@
 from airbyte_cdk.sources import AbstractSource, Source
 from airbyte_cdk.sources.streams.core import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
+from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 
 
 class MockSource(Source):
@@ -81,6 +82,7 @@ def abstract_source(mocker):
 class MockHttpStream(MagicMock, HttpStream):
     url_base = "http://example.com"
     path = "/dummy/path"
+    get_json_schema = MagicMock()
 
     def supports_incremental(self):
         return True
@@ -92,6 +94,7 @@ def __init__(self, *args, **kvargs):
 
 class MockStream(MagicMock, Stream):
     page_size = None
+    get_json_schema = MagicMock()
 
     def __init__(self, *args, **kvargs):
         MagicMock.__init__(self)
@@ -145,8 +148,7 @@ def test_read_catalog(source):
 def test_internal_config(abstract_source, catalog):
     streams = abstract_source.streams(None)
     assert len(streams) == 2
-    http_stream = streams[0]
-    non_http_stream = streams[1]
+    http_stream, non_http_stream = streams
     assert isinstance(http_stream, HttpStream)
     assert not isinstance(non_http_stream, HttpStream)
     http_stream.read_records.return_value = [{}] * 3
@@ -216,3 +218,44 @@ def test_internal_config_limit(abstract_source, catalog):
     logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
     read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
     assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
+
+
+SCHEMA = {"type": "object", "properties": {"value": {"type": "string"}}}
+
+
+def test_source_config_no_transform(abstract_source, catalog):
+    logger_mock = MagicMock()
+    streams = abstract_source.streams(None)
+    http_stream, non_http_stream = streams
+    http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
+    http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
+    records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
+    assert len(records) == 2 * 5
+    assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
+    assert http_stream.get_json_schema.call_count == 1
+    assert non_http_stream.get_json_schema.call_count == 1
+
+
+def test_source_config_transform(abstract_source, catalog):
+    logger_mock = MagicMock()
+    streams = abstract_source.streams(None)
+    http_stream, non_http_stream = streams
+    http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
+    non_http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
+    http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
+    http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
+    records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
+    assert len(records) == 2
+    assert [r.record.data for r in records] == [{"value": "23"}] * 2
+
+
+def test_source_config_transform_and_no_transform(abstract_source, catalog):
+    logger_mock = MagicMock()
+    streams = abstract_source.streams(None)
+    http_stream, non_http_stream = streams
+    http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
+    http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
+    http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
+    records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
+    assert len(records) == 2
+    assert [r.record.data for r in records] == [{"value": "23"}, {"value": 23}]
