Skip to content

Commit 9603c3c

Browse files
artem1205maxi297
andauthored
feat(Airbyte-ci): add command generate-erd-schema (#43310)
Signed-off-by: Artem Inzhyyants <[email protected]> Co-authored-by: maxi297 <[email protected]> Co-authored-by: Maxime Carbonneau-Leclerc <[email protected]>
1 parent 7ba3e2d commit 9603c3c

File tree

21 files changed

+2906
-0
lines changed

21 files changed

+2906
-0
lines changed

.github/workflows/airbyte-ci-tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ jobs:
3939
- airbyte-ci/connectors/connector_ops/**
4040
- airbyte-ci/connectors/connectors_qa/**
4141
- airbyte-ci/connectors/ci_credentials/**
42+
- airbyte-ci/connectors/erd/**
4243
- airbyte-ci/connectors/metadata_service/lib/**
4344
- airbyte-ci/connectors/metadata_service/orchestrator/**
4445
- airbyte-cdk/python/**

airbyte-ci/connectors/erd/README.md

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# erd
2+
3+
A collection of utilities for generating ERDs.
4+
5+
# Setup
6+
7+
## Installation
8+
9+
`erd` tools use [Poetry](https://github.com/python-poetry/poetry) to manage dependencies,
10+
and targets Python 3.10 and higher.
11+
12+
Assuming you're in Airbyte repo root:
13+
14+
```bash
15+
cd airbyte-ci/connectors/erd
16+
poetry install
17+
```
18+
19+
## Usage
20+
21+
Pre-requisites:
22+
* Env variable `GENAI_API_KEY`. Can be found at URL https://aistudio.google.com/app/apikey
23+
24+
`poetry run erd --source-path <source path> --source-technical-name <for example, 'source-facebook-marketing'>`
25+
26+
The script supports skipping the LLM-based relationship generation by passing the parameter `--skip-llm-relationships`
27+
28+
## Contributing to `erd`
29+
30+
### Running tests
31+
32+
To run tests locally:
33+
34+
```bash
35+
poetry run pytest
36+
```
37+
38+
## Changelog
39+
- 0.1.0: Initial commit

airbyte-ci/connectors/erd/poetry.lock

+2,001
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
[build-system]
2+
requires = ["poetry-core>=1.0.0"]
3+
build-backend = "poetry.core.masonry.api"
4+
5+
[tool.poetry]
6+
name = "erd"
7+
version = "0.1.0"
8+
description = "Contains utilities for generating ERDs."
9+
authors = ["Airbyte <[email protected]>"]
10+
license = "MIT"
11+
homepage = "https://github.com/airbytehq/airbyte"
12+
readme = "README.md"
13+
packages = [
14+
{ include = "erd", from = "src" },
15+
]
16+
17+
[tool.poetry.dependencies]
18+
python = "^3.10,<3.12"
19+
airbyte-cdk = "*"
20+
click = "^8.1.3"
21+
dpath = "^2.1.6"
22+
google-generativeai = "^0.7.2"
23+
markdown-it-py = ">=2.2.0"
24+
pydbml = "^1.1.0"
25+
pytest = "^8.1.1"
26+
pyyaml = "^6.0"
27+
28+
[tool.poetry.group.dev.dependencies]
29+
ruff = "^0.3.0"
30+
mypy = "^1.8.0"
31+
types-pyyaml = "^6.0.12.20240311"
32+
33+
[tool.ruff.lint]
34+
select = ["I", "F"]
35+
36+
[tool.ruff.lint.isort]
37+
known-first-party = ["erd"]
38+
39+
[tool.poe.tasks]
40+
test = "pytest tests"
41+
type_check = "mypy src --disallow-untyped-defs"
42+
pre-push = []
43+
44+
[tool.poetry.scripts]
45+
erd = "erd.cli:main"
46+
47+
[tool.airbyte_ci]
48+
python_versions = ["3.10"]
49+
poe_tasks = ["type_check", "test"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
from pathlib import Path
4+
from typing import List, Set, Union
5+
6+
import yaml
7+
from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import ManifestReferenceResolver
8+
from airbyte_protocol.models import AirbyteCatalog, AirbyteStream # type: ignore # missing library stubs or py.typed marker
9+
from erd.relationships import Relationships
10+
from pydbml import Database # type: ignore # missing library stubs or py.typed marker
11+
from pydbml.classes import Column, Index, Reference, Table # type: ignore # missing library stubs or py.typed marker
12+
13+
14+
class Source:
    """Represents a connector source folder on disk and answers whether a given stream is dynamic."""

    def __init__(self, source_folder: Path, source_technical_name: str) -> None:
        """
        :param source_folder: root folder of the connector (e.g. `.../source-facebook-marketing`)
        :param source_technical_name: connector technical name, e.g. `source-facebook-marketing`
        """
        self._source_folder = source_folder
        self._source_technical_name = source_technical_name

    def is_dynamic(self, stream_name: str) -> bool:
        """
        This method is a very flaky heuristic to know if a stream is dynamic or not. A stream will be considered static if:
        * The stream name is in the schemas folder
        * The stream is within the manifest and the schema definition is `InlineSchemaLoader`
        """
        manifest_static_streams = set()
        if self._has_manifest():
            with open(self._get_manifest_path()) as manifest_file:
                resolved_manifest = ManifestReferenceResolver().preprocess_manifest(yaml.safe_load(manifest_file))
                for stream in resolved_manifest["streams"]:
                    if "schema_loader" not in stream:
                        # stream is assumed to have `DefaultSchemaLoader` which will show in the schemas folder so we can skip
                        continue
                    if stream["schema_loader"]["type"] == "InlineSchemaLoader":
                        # Default `$parameters` to `{}`: `stream.get("$parameters")` is None when the key is
                        # absent, and `.get` on None would raise AttributeError.
                        name = stream["name"] if "name" in stream else stream.get("$parameters", {}).get("name", None)
                        if not name:
                            print(f"Could not retrieve name for this stream: {stream}")
                            continue
                        # Reuse the already-validated `name` instead of re-evaluating the unsafe expression.
                        manifest_static_streams.add(name)

        return stream_name not in manifest_static_streams | self._get_streams_from_schemas_folder()

    def _get_streams_from_schemas_folder(self) -> Set[str]:
        """Return the stream names implied by the JSON schema files in the connector's `schemas` folder."""
        schemas_folder = self._source_folder / self._source_technical_name.replace("-", "_") / "schemas"
        return {p.name.replace(".json", "") for p in schemas_folder.iterdir() if p.is_file()} if schemas_folder.exists() else set()

    def _get_manifest_path(self) -> Path:
        """Return the expected path of the low-code connector manifest."""
        return self._source_folder / self._source_technical_name.replace("-", "_") / "manifest.yaml"

    def _has_manifest(self) -> bool:
        """Return True if this connector is a low-code connector with a manifest file."""
        return self._get_manifest_path().exists()
51+
52+
53+
class DbmlAssembler:
    """Builds a DBML `Database` from an Airbyte discovered catalog and a set of stream relationships."""

    def assemble(self, source: Source, discovered_catalog: AirbyteCatalog, relationships: Relationships) -> Database:
        """
        Create a DBML database with one table per static stream and one reference per relationship.

        Dynamic streams (see `Source.is_dynamic`) are skipped entirely.
        """
        database = Database()
        for stream in discovered_catalog.streams:
            if source.is_dynamic(stream.name):
                print(f"Skipping stream {stream.name} as it is dynamic")
                continue

            database.add(self._create_table(stream))

        self._add_references(source, database, relationships)

        return database

    def _create_table(self, stream: AirbyteStream) -> Table:
        """
        Create a DBML table for `stream`, marking single-column PKs on the column and
        composite PKs through an index.

        :raises ValueError: on nested composite keys or a PK column missing from the schema
        """
        dbml_table = Table(stream.name)
        # Default to an empty mapping so a schema without declared properties does not crash
        # with `AttributeError: 'NoneType' object has no attribute 'items'`.
        for property_name, property_information in stream.json_schema.get("properties", {}).items():
            try:
                dbml_table.add_column(
                    Column(
                        name=property_name,
                        type=self._extract_type(property_information["type"]),
                        pk=self._is_pk(stream, property_name),
                    )
                )
            except (KeyError, ValueError) as exception:
                # Best-effort: a field with a missing or ambiguous type is dropped, not fatal.
                print(f"Ignoring field {property_name}: {exception}")
                continue

        if stream.source_defined_primary_key and len(stream.source_defined_primary_key) > 1:
            if any(map(lambda key: len(key) != 1, stream.source_defined_primary_key)):
                raise ValueError(f"Does not support nested key as part of primary key `{stream.source_defined_primary_key}`")

            composite_key_columns = [
                column for key in stream.source_defined_primary_key for column in dbml_table.columns if column.name in key
            ]
            if len(composite_key_columns) < len(stream.source_defined_primary_key):
                raise ValueError("Unexpected error: missing PK column from dbml table")

            dbml_table.add_index(
                Index(
                    subjects=composite_key_columns,
                    pk=True,
                )
            )
        return dbml_table

    def _add_references(self, source: Source, database: Database, relationships: Relationships) -> None:
        """Add a many-to-many reference for each relationship whose source and target streams are static."""
        for stream in relationships["streams"]:
            # Hoisted out of the per-column loop: the answer only depends on the stream, so checking
            # once avoids redundant `is_dynamic` calls (and duplicated log lines) for every column.
            if source.is_dynamic(stream["name"]):
                print(f"Skipping relationship as stream {stream['name']} from relationship is dynamic")
                continue

            for column_name, relationship in stream["relations"].items():
                try:
                    target_table_name, target_column_name = relationship.split(
                        ".", 1
                    )  # we support the field names having dots but not stream name hence we split on the first dot only
                except ValueError as exception:
                    raise ValueError(f"Could not handle relationship {relationship}") from exception

                if source.is_dynamic(target_table_name):
                    print(f"Skipping relationship as target stream {target_table_name} is dynamic")
                    continue

                try:
                    database.add_reference(
                        Reference(
                            type="<>",  # we don't have the information of which relationship type it is so we assume many-to-many for now
                            col1=self._get_column(database, stream["name"], column_name),
                            col2=self._get_column(database, target_table_name, target_column_name),
                        )
                    )
                except ValueError as exception:
                    print(f"Skipping relationship: {exception}")

    def _extract_type(self, property_type: Union[str, List[str]]) -> str:
        """
        Return the single non-null JSON schema type for a property.

        :raises ValueError: if more than one non-null type (or none at all) is declared
        """
        if isinstance(property_type, str):
            return property_type

        types = list(property_type)
        if "null" in types:
            # As we flag everything as nullable (except PK and cursor field), there is little value in keeping the information in order to
            # show this in DBML
            types.remove("null")
        if len(types) != 1:
            raise ValueError(f"Expected only one type apart from `null` but got {len(types)}: {property_type}")
        return types[0]

    def _is_pk(self, stream: AirbyteStream, property_name: str) -> bool:
        """Return True if `property_name` is the single, non-composite source-defined primary key."""
        return stream.source_defined_primary_key == [[property_name]]

    def _get_column(self, database: Database, table_name: str, column_name: str) -> Column:
        """
        Look up a column by table name and column name.

        :raises ValueError: if the table or column cannot be found, or the name is ambiguous
        """
        matching_tables = list(filter(lambda dbml_table: dbml_table.name == table_name, database.tables))
        if len(matching_tables) == 0:
            raise ValueError(f"Could not find table {table_name}")
        elif len(matching_tables) > 1:
            raise ValueError(f"Unexpected error: many tables found with name {table_name}")

        table: Table = matching_tables[0]
        matching_columns = list(filter(lambda column: column.name == column_name, table.columns))
        if len(matching_columns) == 0:
            raise ValueError(f"Could not find column {column_name} in table {table_name}. Columns are: {table.columns}")
        elif len(matching_columns) > 1:
            raise ValueError(f"Unexpected error: many columns found with name {column_name} for table {table_name}")

        return matching_columns[0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
3+
import copy
4+
import json
5+
from pathlib import Path
6+
from typing import Any
7+
8+
import dpath
9+
import google.generativeai as genai # type: ignore # missing library stubs or py.typed marker
10+
from airbyte_protocol.models import AirbyteCatalog # type: ignore # missing library stubs or py.typed marker
11+
from erd.dbml_assembler import DbmlAssembler, Source
12+
from erd.relationships import Relationships, RelationshipsMerger
13+
from markdown_it import MarkdownIt
14+
from pydbml.renderer.dbml.default import DefaultDBMLRenderer # type: ignore # missing library stubs or py.typed marker
15+
16+
17+
class ErdService:
    """Generates ERD artifacts for a connector: LLM-estimated relationships and a DBML file."""

    def __init__(self, source_technical_name: str, source_path: Path) -> None:
        """
        :param source_technical_name: connector technical name, e.g. `source-facebook-marketing`
        :param source_path: connector folder; must contain `erd/discovered_catalog.json`
        :raises ValueError: if the discovered catalog file is missing
        """
        self._source_technical_name = source_technical_name
        self._source_path = source_path

        # Fail fast before instantiating the LLM client: without the discovered catalog no useful
        # work can be done.
        if not self._discovered_catalog_path.exists():
            raise ValueError(f"Could not find discovered catalog at path {self._discovered_catalog_path}")

        self._model = genai.GenerativeModel("gemini-1.5-flash")

    def generate_estimated_relationships(self) -> None:
        """Ask Gemini for likely foreign-key relationships and persist them as JSON."""
        normalized_catalog = self._normalize_schema_catalog(self._get_catalog())
        estimated_relationships = self._get_relations_from_gemini(source_name=self._source_path.name, catalog=normalized_catalog)
        with open(self._estimated_relationships_file, "w") as estimated_relationship_file:
            json.dump(estimated_relationships, estimated_relationship_file, indent=4)

    def write_dbml_file(self) -> None:
        """Assemble the DBML database from the catalog and the merged (estimated + confirmed) relationships."""
        database = DbmlAssembler().assemble(
            Source(self._source_path, self._source_technical_name),
            self._get_catalog(),
            RelationshipsMerger().merge(
                self._get_relationships(self._estimated_relationships_file), self._get_relationships(self._confirmed_relationships_file)
            ),
        )

        with open(self._erd_folder / "source.dbml", "w") as f:
            f.write(DefaultDBMLRenderer.render_db(database))

    @staticmethod
    def _normalize_schema_catalog(catalog: AirbyteCatalog) -> dict[str, Any]:
        """
        Foreign key cannot be of type object or array, therefore, we can remove these properties.
        :param catalog: discovered catalog whose streams' json_schema are in draft7
        :return: streams with json_schema in draft7 with TOP level properties only.
        """
        streams = copy.deepcopy(catalog.model_dump())["streams"]
        for stream in streams:
            # Collect every top-level property whose type mentions "array" or "object" and drop it.
            to_rem = dpath.search(
                stream["json_schema"]["properties"],
                ["**"],
                afilter=lambda x: isinstance(x, dict) and ("array" in str(x.get("type", "")) or "object" in str(x.get("type", ""))),
            )
            for key in to_rem:
                stream["json_schema"]["properties"].pop(key)
        return streams  # type: ignore # as this comes from an AirbyteCatalog dump, the format should be fine

    def _get_relations_from_gemini(self, source_name: str, catalog: dict[str, Any]) -> Relationships:
        """
        Prompt Gemini to infer foreign-key relationships between the streams of the catalog.

        :param source_name: name of the API service, used to contextualize the prompt
        :param catalog: normalized catalog (top-level scalar properties only)
        :return: {"streams":[{'name': 'ads', 'relations': {'account_id': 'ad_account.id', 'campaign_id': 'campaigns.id', 'adset_id': 'ad_sets.id'}}, ...]}
        """
        system = "You are an Database developer in charge of communicating well to your users."

        source_desc = """
You are working on the {source_name} API service.

The current JSON Schema format is as follows:
{current_schema}, where "streams" has a list of streams, which represents database tables, and list of properties in each, which in turn, represent DB columns. Streams presented in list are the only available ones.
Generate and add a `foreign_key` with reference for each field in top level of properties that is helpful in understanding what the data represents and how are streams related to each other. Pay attention to fields ends with '_id'.
""".format(
            source_name=source_name, current_schema=catalog
        )
        task = """
Please provide answer in the following format:
{streams: [{"name": "<stream_name>", "relations": {"<foreign_key>": "<ref_table.column_name>"} }]}
Pay extra attention that in <ref_table.column_name>" "ref_table" should be one of the list of streams, and "column_name" should be one of the property in respective reference stream.
Limitations:
- Not all tables should have relations
- Reference should point to 1 table only.
- table cannot reference on itself, on other words, e.g. `ad_account` cannot have relations with "ad_account" as a "ref_table"
"""
        response = self._model.generate_content(f"{system} {source_desc} {task}")
        # Gemini wraps the JSON answer in markdown; parse it and take the first token's content.
        md = MarkdownIt("commonmark")
        tokens = md.parse(response.text)
        response_json = json.loads(tokens[0].content)
        return response_json  # type: ignore # we blindly assume Gemini returns a response with the Relationships format as asked

    @staticmethod
    def _get_relationships(path: Path) -> Relationships:
        """Load a relationships JSON file, defaulting to no streams when the file does not exist."""
        if not path.exists():
            return {"streams": []}

        with open(path, "r") as file:
            return json.load(file)  # type: ignore # we assume the content of the file matches Relationships

    def _get_catalog(self) -> AirbyteCatalog:
        """
        Load and validate the discovered catalog from disk.

        :raises ValueError: if the file does not contain valid JSON
        """
        with open(self._discovered_catalog_path, "r") as file:
            try:
                return AirbyteCatalog.model_validate(json.loads(file.read()))
            except json.JSONDecodeError as error:
                # Chain the decode error so the root cause stays visible in the traceback.
                raise ValueError(
                    f"Could not read json file {self._discovered_catalog_path}: {error}. Please ensure that it is a valid JSON."
                ) from error

    @property
    def _erd_folder(self) -> Path:
        """
        Note: if this folder change, make sure to update the exported folder in the pipeline
        """
        path = self._source_path / "erd"
        if not path.exists():
            path.mkdir()
        return path

    @property
    def _estimated_relationships_file(self) -> Path:
        # LLM-generated relationships, regenerated by `generate_estimated_relationships`.
        return self._erd_folder / "estimated_relationships.json"

    @property
    def _confirmed_relationships_file(self) -> Path:
        # Human-confirmed relationships; take precedence when merged with the estimated ones.
        return self._erd_folder / "confirmed_relationships.json"

    @property
    def _discovered_catalog_path(self) -> Path:
        """
        Note: if this folder change, make sure to update the exported folder in the pipeline
        """
        return self._source_path / "erd" / "discovered_catalog.json"

0 commit comments

Comments
 (0)