
Commit 6cfc3b5

Authored by bindipankhudi and aaronsteers
Feat: Add BigQuery cache support; Fix: IPython rendering bug when using a TTY (#65)
Co-authored-by: bindipankhudi <[email protected]>
Co-authored-by: Aaron Steers <[email protected]>
1 parent 797e657 commit 6cfc3b5

File tree: 9 files changed, +688 −72 lines

airbyte/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@

 from airbyte import caches, datasets, registry, secrets
 from airbyte._factories.connector_factories import get_source
+from airbyte.caches.bigquery import BigQueryCache
 from airbyte.caches.duckdb import DuckDBCache
 from airbyte.caches.factories import get_default_cache, new_local_cache
 from airbyte.datasets import CachedDataset
@@ -29,6 +30,7 @@
     "get_source",
     "new_local_cache",
     # Classes
+    "BigQueryCache",
     "CachedDataset",
     "DuckDBCache",
     "ReadResult",
```

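With `BigQueryCache` re-exported from the package root, callers no longer need the submodule path. A minimal usage sketch, assuming the `source-faker` connector is available; the project, dataset, and key-path values are placeholders, not part of this commit:

```python
import airbyte as ab

# Placeholder values for illustration; use your own GCP project
# and service-account key file.
cache = ab.BigQueryCache(
    project_name="my-gcp-project",
    dataset_name="airbyte_raw",
    credentials_path="/path/to/service_account.json",
)

source = ab.get_source("source-faker", config={"count": 100})
source.check()
result = source.read(cache=cache)  # Stream records into the BigQuery dataset.
```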
airbyte/_processors/sql/base.py

Lines changed: 6 additions & 13 deletions
```diff
@@ -79,7 +79,6 @@ class SqlProcessorBase(RecordProcessor):

     # Constructor:

-    @final  # We don't want subclasses to have to override the constructor.
     def __init__(
         self,
         cache: CacheBase,
@@ -435,9 +434,6 @@ def _create_table(
         column_definition_str: str,
         primary_keys: list[str] | None = None,
     ) -> None:
-        if DEBUG_MODE:
-            assert table_name not in self._get_tables_list(), f"Table {table_name} already exists."
-
         if primary_keys:
             pk_str = ", ".join(primary_keys)
             column_definition_str += f",\n PRIMARY KEY ({pk_str})"
@@ -448,11 +444,6 @@ def _create_table(
             )
             """
         _ = self._execute_sql(cmd)
-        if DEBUG_MODE:
-            tables_list = self._get_tables_list()
-            assert (
-                table_name in tables_list
-            ), f"Table {table_name} was not created. Found: {tables_list}"

     def _normalize_column_name(
         self,
@@ -804,8 +795,8 @@ def _merge_temp_table_to_final_table(
         columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
         pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)}
         non_pk_columns = columns - pk_columns
-        join_clause = "{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
-        set_clause = "{nl} ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
+        join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
+        set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
         self._execute_sql(
             f"""
             MERGE INTO {self._fully_qualified(final_table_name)} final
@@ -908,12 +899,14 @@ def _emulated_merge_temp_table_to_final_table(
             conn.execute(update_stmt)
             conn.execute(insert_new_records_stmt)

-    @final
     def _table_exists(
         self,
         table_name: str,
     ) -> bool:
-        """Return true if the given table exists."""
+        """Return true if the given table exists.
+
+        Subclasses may override this method to provide a more efficient implementation.
+        """
        return table_name in self._get_tables_list()

     @abc.abstractmethod
```
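The `join_clause`/`set_clause` change above fixes a quiet bug: without the `f` prefix, the literal text `{nl}` was used as the separator instead of the interpolated newline, producing malformed MERGE SQL. A standalone sketch of the difference:

```python
nl = "\n"
pk_columns = ["id", "updated_at"]

# Before the fix: "{nl}" is joined as four literal characters.
buggy = "{nl} AND ".join(f"tmp.{c} = final.{c}" for c in pk_columns)
# After the fix: the separator contains a real newline.
fixed = f"{nl} AND ".join(f"tmp.{c} = final.{c}" for c in pk_columns)

print(buggy)  # tmp.id = final.id{nl} AND tmp.updated_at = final.updated_at
print(fixed)  # tmp.id = final.id
              #  AND tmp.updated_at = final.updated_at
```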

airbyte/_processors/sql/bigquery.py

Lines changed: 204 additions & 0 deletions
New file:

```python
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A BigQuery implementation of the cache."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, final

import sqlalchemy
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.oauth2 import service_account
from overrides import overrides

from airbyte import exceptions as exc
from airbyte._processors.file.jsonl import JsonlWriter
from airbyte._processors.sql.base import SqlProcessorBase
from airbyte.telemetry import CacheTelemetryInfo
from airbyte.types import SQLTypeConverter


if TYPE_CHECKING:
    from sqlalchemy.engine.reflection import Inspector

    from airbyte._processors.file.base import FileWriterBase
    from airbyte.caches.base import CacheBase
    from airbyte.caches.bigquery import BigQueryCache


class BigQueryTypeConverter(SQLTypeConverter):
    """A class to convert types for BigQuery."""

    @overrides
    def to_sql_type(
        self,
        json_schema_property_def: dict[str, str | dict | list],
    ) -> sqlalchemy.types.TypeEngine:
        """Convert a value to a SQL type.

        We first call the parent class method to get the type. Then if the type is VARCHAR or
        BIGINT, we replace it with respective BigQuery types.
        """
        sql_type = super().to_sql_type(json_schema_property_def)
        # TODO: replace hardcoded return types with some sort of Snowflake Variant equivalent
        if isinstance(sql_type, sqlalchemy.types.VARCHAR):
            return "String"
        if isinstance(sql_type, sqlalchemy.types.BIGINT):
            return "INT64"

        return sql_type.__class__.__name__


class BigQuerySqlProcessor(SqlProcessorBase):
    """A BigQuery implementation of the cache."""

    file_writer_class = JsonlWriter
    type_converter_class = BigQueryTypeConverter
    supports_merge_insert = True

    cache: BigQueryCache

    def __init__(self, cache: CacheBase, file_writer: FileWriterBase | None = None) -> None:
        self._credentials: service_account.Credentials | None = None
        self._schema_exists: bool | None = None
        super().__init__(cache, file_writer)

    @final
    @overrides
    def _fully_qualified(
        self,
        table_name: str,
    ) -> str:
        """Return the fully qualified name of the given table."""
        return f"`{self.cache.schema_name}`.`{table_name!s}`"

    @final
    @overrides
    def _quote_identifier(self, identifier: str) -> str:
        """Return the identifier name as-is. BigQuery does not require quoting identifiers."""
        return f"{identifier}"

    @final
    @overrides
    def _get_telemetry_info(self) -> CacheTelemetryInfo:
        return CacheTelemetryInfo("bigquery")

    def _write_files_to_new_table(
        self,
        files: list[Path],
        stream_name: str,
        batch_id: str,
    ) -> str:
        """Write a file(s) to a new table.

        This is a generic implementation, which can be overridden by subclasses
        to improve performance.
        """
        temp_table_name = self._create_table_for_loading(stream_name, batch_id)

        # Specify the table ID (in the format `project_id.dataset_id.table_id`)
        table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{temp_table_name}"

        # Initialize a BigQuery client
        client = bigquery.Client(credentials=self._get_credentials())

        for file_path in files:
            with Path.open(file_path, "rb") as source_file:
                load_job = client.load_table_from_file(  # Make an API request
                    file_obj=source_file,
                    destination=table_id,
                    job_config=bigquery.LoadJobConfig(
                        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
                        schema=[
                            bigquery.SchemaField(name, field_type=str(type_))
                            for name, type_ in self._get_sql_column_definitions(
                                stream_name=stream_name
                            ).items()
                        ],
                    ),
                )
            _ = load_job.result()  # Wait for the job to complete

        return temp_table_name

    def _ensure_schema_exists(
        self,
    ) -> None:
        """Ensure the target schema exists.

        We override the default implementation because BigQuery is very slow at scanning schemas.

        This implementation simply calls "CREATE SCHEMA IF NOT EXISTS" and ignores any errors.
        """
        if self._schema_exists:
            return

        sql = f"CREATE SCHEMA IF NOT EXISTS {self.cache.schema_name}"
        try:
            self._execute_sql(sql)
        except Exception as ex:
            # Ignore "schema already exists" errors.
            if "already exists" not in str(ex):
                raise

        self._schema_exists = True

    def _get_credentials(self) -> service_account.Credentials:
        """Return the GCP credentials."""
        if self._credentials is None:
            self._credentials = service_account.Credentials.from_service_account_file(
                self.cache.credentials_path
            )

        return self._credentials

    def _table_exists(
        self,
        table_name: str,
    ) -> bool:
        """Return true if the given table exists.

        We override the default implementation because BigQuery is very slow at scanning tables.
        """
        client = bigquery.Client(credentials=self._get_credentials())
        table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{table_name}"
        try:
            client.get_table(table_id)
        except NotFound:
            return False

        except ValueError as ex:
            raise exc.AirbyteLibInputError(
                message="Invalid project name or dataset name.",
                context={
                    "table_id": table_id,
                    "table_name": table_name,
                    "project_name": self.cache.project_name,
                    "dataset_name": self.cache.dataset_name,
                },
            ) from ex

        return True

    @final
    @overrides
    def _get_tables_list(
        self,
    ) -> list[str]:
        """Get the list of available tables in the schema.

        For BigQuery, `{schema_name}.{table_name}` is returned, so we need to
        strip the schema name from the front of the table name, if present.

        Warning: This method is slow for BigQuery, as it needs to scan all tables in the dataset.
        It has been observed to take 30+ seconds in some cases.
        """
        with self.get_sql_connection() as conn:
            inspector: Inspector = sqlalchemy.inspect(conn)
            tables = inspector.get_table_names(schema=self.cache.schema_name)
            schema_prefix = f"{self.cache.schema_name}."
            return [
                table.replace(schema_prefix, "", 1) if table.startswith(schema_prefix) else table
                for table in tables
            ]
```
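The `_table_exists` override replaces a full catalog scan with a single `tables.get` metadata request. The same pattern, reduced to a standalone sketch (the project/dataset/table ID is a placeholder):

```python
from google.api_core.exceptions import NotFound
from google.cloud import bigquery


def table_exists(client: bigquery.Client, table_id: str) -> bool:
    """Return True via one metadata request instead of listing the whole dataset."""
    try:
        client.get_table(table_id)  # Issues a tables.get API call.
    except NotFound:
        return False
    return True


client = bigquery.Client()  # Falls back to application-default credentials.
print(table_exists(client, "my-project.my_dataset.my_table"))
```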

airbyte/caches/_catalog_manager.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -4,12 +4,13 @@
 from __future__ import annotations

 import json
+from datetime import datetime
 from typing import TYPE_CHECKING, Callable

+from pytz import utc
 from sqlalchemy import Column, DateTime, String
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session
-from sqlalchemy.sql import func

 from airbyte_protocol.models import (
     AirbyteStateMessage,
@@ -50,7 +51,9 @@ class StreamState(Base):  # type: ignore[valid-type,misc]
     stream_name = Column(String)
     table_name = Column(String, primary_key=True)
     state_json = Column(String)
-    last_updated = Column(DateTime(timezone=True), onupdate=func.now(), default=func.now())
+    last_updated = Column(
+        DateTime(timezone=True), onupdate=datetime.now(utc), default=datetime.now(utc)
+    )


 class CatalogManager:
```
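One SQLAlchemy nuance behind the `last_updated` change: `default=` and `onupdate=` accept either a plain value, evaluated once when the class body runs, or a callable, re-evaluated for each statement. A minimal sketch with a hypothetical model (not the committed `StreamState` class):

```python
from datetime import datetime

from pytz import utc
from sqlalchemy import Column, DateTime, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Example(Base):  # Hypothetical table, for illustration only.
    __tablename__ = "example"

    id = Column(Integer, primary_key=True)
    # Plain value: datetime.now(utc) runs once, at class-definition time.
    stamped_once = Column(DateTime(timezone=True), default=datetime.now(utc))
    # Callable: re-evaluated on every INSERT and UPDATE.
    stamped_fresh = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(utc),
        onupdate=lambda: datetime.now(utc),
    )
```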

airbyte/caches/bigquery.py

Lines changed: 38 additions & 0 deletions
New file:

```python
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A BigQuery implementation of the cache."""

from __future__ import annotations

import urllib

from overrides import overrides

from airbyte._processors.sql.bigquery import BigQuerySqlProcessor
from airbyte.caches.base import (
    CacheBase,
)


class BigQueryCache(CacheBase):
    """The BigQuery cache implementation."""

    project_name: str
    dataset_name: str = "airbyte_raw"
    credentials_path: str

    _sql_processor_class: type[BigQuerySqlProcessor] = BigQuerySqlProcessor

    def __post_init__(self) -> None:
        """Initialize the BigQuery cache."""
        self.schema_name = self.dataset_name

    @overrides
    def get_database_name(self) -> str:
        """Return the name of the database. For BigQuery, this is the project name."""
        return self.project_name

    @overrides
    def get_sql_alchemy_url(self) -> str:
        """Return the SQLAlchemy URL to use."""
        credentials_path_encoded = urllib.parse.quote(self.credentials_path)
        return f"bigquery://{self.project_name!s}?credentials_path={credentials_path_encoded}"
```
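`get_sql_alchemy_url` produces a `bigquery://` URL for the sqlalchemy-bigquery dialect, URL-encoding the key path as a query parameter (note that `urllib.parse.quote` leaves `/` unescaped by default). A small sketch with placeholder values:

```python
from airbyte.caches.bigquery import BigQueryCache

# Placeholder values for illustration.
cache = BigQueryCache(
    project_name="my-gcp-project",
    dataset_name="airbyte_raw",
    credentials_path="/path/to/service_account.json",
)

print(cache.get_sql_alchemy_url())
# bigquery://my-gcp-project?credentials_path=/path/to/service_account.json
```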

airbyte/progress.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import datetime
+import importlib
 import math
 import sys
 import time
@@ -25,10 +26,12 @@

 ipy_display: ModuleType | None
 try:
-    IS_NOTEBOOK = True
-    from IPython import display as ipy_display  # type: ignore # noqa: PGH003
+    # Default to IS_NOTEBOOK=False if a TTY is detected.
+    IS_NOTEBOOK = not sys.stdout.isatty()
+    ipy_display = importlib.import_module("IPython.display")

 except ImportError:
+    # If IPython is not installed, then we're definitely not in a notebook.
     ipy_display = None
     IS_NOTEBOOK = False
```
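The TTY check works as a notebook heuristic because of how stdout is wired in each environment. A quick sketch to observe it (hypothetical `check_tty.py` script):

```python
import sys

# Run `python check_tty.py` in a terminal: prints True (stdout is a TTY).
# Run `python check_tty.py | cat`: prints False (stdout is a pipe).
# In a Jupyter notebook the kernel captures stdout, so this is also False.
print(sys.stdout.isatty())
```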
