|
| 1 | +from json import loads |
| 2 | +from urllib.parse import urlparse |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | +from app.model import ConnectionUrl, TrinoConnectionInfo |
| 7 | +from app.model.data_source import DataSource |
| 8 | +from app.model.metadata.dto import ( |
| 9 | + Column, |
| 10 | + Constraint, |
| 11 | + Table, |
| 12 | + TableProperties, |
| 13 | + WrenEngineColumnType, |
| 14 | +) |
| 15 | +from app.model.metadata.metadata import Metadata |
| 16 | + |
| 17 | + |
| 18 | +class TrinoMetadata(Metadata): |
| 19 | + def __init__(self, connection_info: TrinoConnectionInfo | ConnectionUrl): |
| 20 | + super().__init__(connection_info) |
| 21 | + |
| 22 | + def get_table_list(self) -> list[Table]: |
| 23 | + sql = """SELECT |
| 24 | + t.table_catalog, |
| 25 | + t.table_schema, |
| 26 | + t.table_name, |
| 27 | + c.column_name, |
| 28 | + c.data_type, |
| 29 | + c.is_nullable |
| 30 | + FROM |
| 31 | + information_schema.tables t |
| 32 | + JOIN |
| 33 | + information_schema.columns c |
| 34 | + ON t.table_schema = c.table_schema |
| 35 | + AND t.table_name = c.table_name |
| 36 | + WHERE |
| 37 | + t.table_type IN ('BASE TABLE', 'VIEW') |
| 38 | + AND t.table_schema NOT IN ('information_schema', 'pg_catalog')""" |
| 39 | + |
| 40 | + sql_cursor = DataSource.trino.get_connection(self.connection_info).raw_sql(sql) |
| 41 | + column_names = [col[0] for col in sql_cursor.description] |
| 42 | + response = loads( |
| 43 | + pd.DataFrame(sql_cursor.fetchall(), columns=column_names).to_json( |
| 44 | + orient="records" |
| 45 | + ) |
| 46 | + ) |
| 47 | + unique_tables = {} |
| 48 | + for row in response: |
| 49 | + # generate unique table name |
| 50 | + schema_table = self._format_trino_compact_table_name( |
| 51 | + row["table_catalog"], row["table_schema"], row["table_name"] |
| 52 | + ) |
| 53 | + # init table if not exists |
| 54 | + if schema_table not in unique_tables: |
| 55 | + unique_tables[schema_table] = Table( |
| 56 | + name=schema_table, |
| 57 | + description="", |
| 58 | + columns=[], |
| 59 | + properties=TableProperties( |
| 60 | + schema=row["table_schema"], |
| 61 | + catalog=row["table_catalog"], |
| 62 | + table=row["table_name"], |
| 63 | + ), |
| 64 | + primaryKey="", |
| 65 | + ) |
| 66 | + |
| 67 | + # table exists, and add column to the table |
| 68 | + unique_tables[schema_table].columns.append( |
| 69 | + Column( |
| 70 | + name=row["column_name"], |
| 71 | + type=self._transform_column_type(row["data_type"]), |
| 72 | + notNull=row["is_nullable"].lower() == "no", |
| 73 | + description="", |
| 74 | + properties=None, |
| 75 | + ) |
| 76 | + ) |
| 77 | + return list(unique_tables.values()) |
| 78 | + |
| 79 | + def get_constraints(self) -> list[Constraint]: |
| 80 | + return [] |
| 81 | + |
| 82 | + def _format_trino_compact_table_name( |
| 83 | + self, catalog: str, schema: str, table: str |
| 84 | + ) -> str: |
| 85 | + return f"{catalog}.{schema}.{table}" |
| 86 | + |
| 87 | + def _get_schema_name(self): |
| 88 | + if hasattr(self.connection_info, "connection_url"): |
| 89 | + return urlparse( |
| 90 | + self.connection_info.connection_url.get_secret_value() |
| 91 | + ).path.split("/")[-1] |
| 92 | + else: |
| 93 | + return self.connection_info.trino_schema.get_secret_value() |
| 94 | + |
| 95 | + def _transform_column_type(self, data_type): |
| 96 | + # all possible types listed here: https://trino.io/docs/current/language/types.html |
| 97 | + switcher = { |
| 98 | + # String Types (ignore Binary and Spatial Types for now) |
| 99 | + "char": WrenEngineColumnType.CHAR, |
| 100 | + "varchar": WrenEngineColumnType.VARCHAR, |
| 101 | + "tinytext": WrenEngineColumnType.TEXT, |
| 102 | + "text": WrenEngineColumnType.TEXT, |
| 103 | + "mediumtext": WrenEngineColumnType.TEXT, |
| 104 | + "longtext": WrenEngineColumnType.TEXT, |
| 105 | + "enum": WrenEngineColumnType.VARCHAR, |
| 106 | + "set": WrenEngineColumnType.VARCHAR, |
| 107 | + # Numeric Types(https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html) |
| 108 | + "bit": WrenEngineColumnType.TINYINT, |
| 109 | + "tinyint": WrenEngineColumnType.TINYINT, |
| 110 | + "smallint": WrenEngineColumnType.SMALLINT, |
| 111 | + "mediumint": WrenEngineColumnType.INTEGER, |
| 112 | + "int": WrenEngineColumnType.INTEGER, |
| 113 | + "integer": WrenEngineColumnType.INTEGER, |
| 114 | + "bigint": WrenEngineColumnType.BIGINT, |
| 115 | + # boolean |
| 116 | + "bool": WrenEngineColumnType.BOOLEAN, |
| 117 | + "boolean": WrenEngineColumnType.BOOLEAN, |
| 118 | + # Decimal |
| 119 | + "float": WrenEngineColumnType.FLOAT8, |
| 120 | + "double": WrenEngineColumnType.DOUBLE, |
| 121 | + "decimal": WrenEngineColumnType.DECIMAL, |
| 122 | + "numeric": WrenEngineColumnType.NUMERIC, |
| 123 | + # Date and Time Types(https://dev.mysql.com/doc/refman/8.4/en/date-and-time-types.html) |
| 124 | + "date": WrenEngineColumnType.DATE, |
| 125 | + "datetime": WrenEngineColumnType.TIMESTAMP, |
| 126 | + "timestamp": WrenEngineColumnType.TIMESTAMPTZ, |
| 127 | + # JSON Type |
| 128 | + "json": WrenEngineColumnType.JSON, |
| 129 | + } |
| 130 | + |
| 131 | + return switcher.get(data_type.lower(), WrenEngineColumnType.UNKNOWN) |
0 commit comments