# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A BigQuery implementation of the cache."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, final

import sqlalchemy
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.oauth2 import service_account
from overrides import overrides

from airbyte import exceptions as exc
from airbyte._processors.file.jsonl import JsonlWriter
from airbyte._processors.sql.base import SqlProcessorBase
from airbyte.telemetry import CacheTelemetryInfo
from airbyte.types import SQLTypeConverter


if TYPE_CHECKING:
    from sqlalchemy.engine.reflection import Inspector

    from airbyte._processors.file.base import FileWriterBase
    from airbyte.caches.base import CacheBase
    from airbyte.caches.bigquery import BigQueryCache


class BigQueryTypeConverter(SQLTypeConverter):
    """A class to convert types for BigQuery."""

    @overrides
    def to_sql_type(
        self,
        json_schema_property_def: dict[str, str | dict | list],
    ) -> sqlalchemy.types.TypeEngine:
        """Convert a JSON Schema property definition to a SQL type.

        We first call the parent class method to get the type. Then, if the type is VARCHAR or
        BIGINT, we replace it with the respective BigQuery type.
        """
        sql_type = super().to_sql_type(json_schema_property_def)
        # TODO: Replace the hardcoded return types with a BigQuery equivalent of Snowflake's
        # VARIANT type.
        if isinstance(sql_type, sqlalchemy.types.VARCHAR):
            return "String"
        if isinstance(sql_type, sqlalchemy.types.BIGINT):
            return "INT64"

        return sql_type.__class__.__name__
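
    # Illustrative example (not executed here): assuming the parent `SQLTypeConverter` maps the
    # JSON Schema "string" type to VARCHAR and "integer" to BIGINT, this converter would behave
    # roughly as follows. The property definitions below are hypothetical.
    #
    #     converter = BigQueryTypeConverter()
    #     converter.to_sql_type({"type": "string"})   # -> "String"
    #     converter.to_sql_type({"type": "integer"})  # -> "INT64"
    #     converter.to_sql_type({"type": "boolean"})  # -> SQLAlchemy class name, e.g. "BOOLEAN"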


class BigQuerySqlProcessor(SqlProcessorBase):
    """A BigQuery implementation of the SQL processor."""

    file_writer_class = JsonlWriter
    type_converter_class = BigQueryTypeConverter
    supports_merge_insert = True

    cache: BigQueryCache

    def __init__(self, cache: CacheBase, file_writer: FileWriterBase | None = None) -> None:
        self._credentials: service_account.Credentials | None = None
        self._schema_exists: bool | None = None
        super().__init__(cache, file_writer)

    @final
    @overrides
    def _fully_qualified(
        self,
        table_name: str,
    ) -> str:
        """Return the fully qualified name of the given table."""
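        # e.g. (hypothetical): table "users" in dataset "my_dataset" -> `my_dataset`.`users`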
        return f"`{self.cache.schema_name}`.`{table_name!s}`"

    @final
    @overrides
    def _quote_identifier(self, identifier: str) -> str:
        """Return the identifier name as-is; BigQuery does not require quoting identifiers."""
        return f"{identifier}"

    @final
    @overrides
    def _get_telemetry_info(self) -> CacheTelemetryInfo:
        return CacheTelemetryInfo("bigquery")

    def _write_files_to_new_table(
        self,
        files: list[Path],
        stream_name: str,
        batch_id: str,
    ) -> str:
        """Write one or more JSONL files to a new temporary table.

        This BigQuery-specific implementation loads the files directly into a temporary table
        using BigQuery load jobs, rather than relying on the generic implementation in the
        base class.
        """
        temp_table_name = self._create_table_for_loading(stream_name, batch_id)

        # Specify the table ID (in the format `project_id.dataset_id.table_id`)
        table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{temp_table_name}"

        # Initialize a BigQuery client
        client = bigquery.Client(credentials=self._get_credentials())

        for file_path in files:
            with file_path.open("rb") as source_file:
                load_job = client.load_table_from_file(  # Make an API request
                    file_obj=source_file,
                    destination=table_id,
                    job_config=bigquery.LoadJobConfig(
                        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
                        schema=[
                            bigquery.SchemaField(name, field_type=str(type_))
                            for name, type_ in self._get_sql_column_definitions(
                                stream_name=stream_name
                            ).items()
                        ],
                    ),
                )
                _ = load_job.result()  # Wait for the job to complete

        return temp_table_name
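
    # Illustrative example (hypothetical stream and columns): for a stream named "users" whose
    # SQL column definitions resolve to {"id": "INT64", "name": "String"}, the load job above
    # would be configured with:
    #
    #     schema=[
    #         bigquery.SchemaField("id", field_type="INT64"),
    #         bigquery.SchemaField("name", field_type="String"),
    #     ]
    #
    # and each staged JSONL file would contain one JSON object per line, e.g.:
    #
    #     {"id": 1, "name": "alice"}
    #     {"id": 2, "name": "bob"}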

    def _ensure_schema_exists(
        self,
    ) -> None:
        """Ensure the target schema exists.

        We override the default implementation because BigQuery is very slow at scanning schemas.

        This implementation simply issues "CREATE SCHEMA IF NOT EXISTS" and ignores any
        "already exists" errors.
        """
        if self._schema_exists:
            return

        sql = f"CREATE SCHEMA IF NOT EXISTS {self.cache.schema_name}"
        try:
            self._execute_sql(sql)
        except Exception as ex:
            # Ignore schema exists errors.
            if "already exists" not in str(ex):
                raise

        self._schema_exists = True

    def _get_credentials(self) -> service_account.Credentials:
        """Return the GCP credentials."""
        if self._credentials is None:
            self._credentials = service_account.Credentials.from_service_account_file(
                self.cache.credentials_path
            )

        return self._credentials

    def _table_exists(
        self,
        table_name: str,
    ) -> bool:
        """Return True if the given table exists.

        We override the default implementation because BigQuery is very slow at scanning tables.
        """
        client = bigquery.Client(credentials=self._get_credentials())
        table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{table_name}"
        try:
            client.get_table(table_id)
        except NotFound:
            return False

        except ValueError as ex:
            raise exc.AirbyteLibInputError(
                message="Invalid project name or dataset name.",
                context={
                    "table_id": table_id,
                    "table_name": table_name,
                    "project_name": self.cache.project_name,
                    "dataset_name": self.cache.dataset_name,
                },
            ) from ex

        return True
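
    # Illustrative example (hypothetical names): with project "my-gcp-project" and dataset
    # "my_dataset", `self._table_exists("users")` performs a single metadata lookup against
    # "my-gcp-project.my_dataset.users" and returns False (rather than raising) if the table
    # is missing.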

    @final
    @overrides
    def _get_tables_list(
        self,
    ) -> list[str]:
        """Get the list of available tables in the schema.

        For BigQuery, the inspector returns names in the form `{schema_name}.{table_name}`, so we
        strip the schema prefix from each table name if it is present.

        Warning: This method is slow for BigQuery, as it needs to scan all tables in the dataset.
        It has been observed to take 30+ seconds in some cases.
        """
        with self.get_sql_connection() as conn:
            inspector: Inspector = sqlalchemy.inspect(conn)
            tables = inspector.get_table_names(schema=self.cache.schema_name)
            schema_prefix = f"{self.cache.schema_name}."
            return [
                table.replace(schema_prefix, "", 1) if table.startswith(schema_prefix) else table
                for table in tables
            ]
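

# Illustrative usage sketch. The field names below are assumptions inferred from the attributes
# this processor reads from its cache (`project_name`, `dataset_name`, `credentials_path`); the
# actual public API may differ.
#
#     from airbyte.caches.bigquery import BigQueryCache
#
#     cache = BigQueryCache(
#         project_name="my-gcp-project",  # hypothetical GCP project
#         dataset_name="my_dataset",  # hypothetical BigQuery dataset
#         credentials_path="/path/to/service_account.json",  # service-account key file
#     )
#
# A source read into this cache would then stage records as JSONL files, load them into temporary
# BigQuery tables via `_write_files_to_new_table()`, and merge them into the final tables (this
# class sets `supports_merge_insert = True`).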