Skip to content

Commit 6afbe24

Browse files
authored
feat(ibis): Implement Snowflake Metadata APIs (#895)
1 parent 7461b69 commit 6afbe24

File tree

4 files changed

+215
-11
lines changed

4 files changed

+215
-11
lines changed

ibis-server/app/model/metadata/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from app.model.metadata.mssql import MSSQLMetadata
77
from app.model.metadata.mysql import MySQLMetadata
88
from app.model.metadata.postgres import PostgresMetadata
9+
from app.model.metadata.snowflake import SnowflakeMetadata
910
from app.model.metadata.trino import TrinoMetadata
1011

1112
mapping = {
@@ -16,6 +17,7 @@
1617
DataSource.mysql: MySQLMetadata,
1718
DataSource.postgres: PostgresMetadata,
1819
DataSource.trino: TrinoMetadata,
20+
DataSource.snowflake: SnowflakeMetadata,
1921
}
2022

2123

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from contextlib import closing
2+
3+
import ibis
4+
5+
from app.model import SnowflakeConnectionInfo
6+
from app.model.data_source import DataSource
7+
from app.model.metadata.dto import (
8+
Column,
9+
Constraint,
10+
ConstraintType,
11+
Table,
12+
TableProperties,
13+
WrenEngineColumnType,
14+
)
15+
from app.model.metadata.metadata import Metadata
16+
17+
18+
class SnowflakeMetadata(Metadata):
    """Snowflake implementation of the metadata API.

    Provides the table list, foreign-key constraints and server version for
    the database/schema named in the connection info.
    """

    # Maps a lower-cased Snowflake DATA_TYPE string to the Wren engine column
    # type. Hoisted to a class attribute so the dict is built once instead of
    # on every _transform_column_type call.
    # All possible types listed here:
    # https://docs.snowflake.com/en/sql-reference/intro-summary-data-types
    _TYPE_MAPPING = {
        # Numeric Types
        "number": WrenEngineColumnType.NUMERIC,
        "decimal": WrenEngineColumnType.NUMERIC,
        "numeric": WrenEngineColumnType.NUMERIC,
        "int": WrenEngineColumnType.INTEGER,
        "integer": WrenEngineColumnType.INTEGER,
        "bigint": WrenEngineColumnType.BIGINT,
        "smallint": WrenEngineColumnType.SMALLINT,
        "tinyint": WrenEngineColumnType.TINYINT,
        "byteint": WrenEngineColumnType.TINYINT,
        # Float
        "float4": WrenEngineColumnType.FLOAT4,
        "float": WrenEngineColumnType.FLOAT8,
        "float8": WrenEngineColumnType.FLOAT8,
        "double": WrenEngineColumnType.DOUBLE,
        "double precision": WrenEngineColumnType.DOUBLE,
        "real": WrenEngineColumnType.REAL,
        # String Types
        "varchar": WrenEngineColumnType.VARCHAR,
        "char": WrenEngineColumnType.CHAR,
        "character": WrenEngineColumnType.CHAR,
        "string": WrenEngineColumnType.STRING,
        "text": WrenEngineColumnType.TEXT,
        # Boolean
        "boolean": WrenEngineColumnType.BOOLEAN,
        # Date and Time Types
        "date": WrenEngineColumnType.DATE,
        "datetime": WrenEngineColumnType.TIMESTAMP,
        "timestamp": WrenEngineColumnType.TIMESTAMP,
        "timestamp_ntz": WrenEngineColumnType.TIMESTAMP,
        "timestamp_tz": WrenEngineColumnType.TIMESTAMPTZ,
    }

    def __init__(self, connection_info: SnowflakeConnectionInfo):
        super().__init__(connection_info)
        self.connection = DataSource.snowflake.get_connection(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return one Table per table in the configured schema, with its columns."""
        schema = self._get_schema_name()
        # NOTE(review): the schema name is interpolated directly into the SQL
        # text. It comes from the connection info (trusted config), but quoting
        # or parameterizing it would be safer if that assumption ever changes.
        sql = f"""
            SELECT
                c.TABLE_CATALOG AS TABLE_CATALOG,
                c.TABLE_SCHEMA AS TABLE_SCHEMA,
                c.TABLE_NAME AS TABLE_NAME,
                c.COLUMN_NAME AS COLUMN_NAME,
                c.DATA_TYPE AS DATA_TYPE,
                c.IS_NULLABLE AS IS_NULLABLE,
                c.COMMENT AS COLUMN_COMMENT,
                t.COMMENT AS TABLE_COMMENT
            FROM
                INFORMATION_SCHEMA.COLUMNS c
            JOIN
                INFORMATION_SCHEMA.TABLES t
                ON c.TABLE_SCHEMA = t.TABLE_SCHEMA
                AND c.TABLE_NAME = t.TABLE_NAME
            WHERE
                c.TABLE_SCHEMA = '{schema}';
            """
        response = self.connection.sql(sql).to_pandas().to_dict(orient="records")

        # The query yields one row per column; group rows into tables keyed by
        # the compact "schema.table" name.
        unique_tables = {}
        for row in response:
            schema_table = self._format_compact_table_name(
                row["TABLE_SCHEMA"], row["TABLE_NAME"]
            )
            # init table if not exists
            if schema_table not in unique_tables:
                unique_tables[schema_table] = Table(
                    name=schema_table,
                    description=row["TABLE_COMMENT"],
                    columns=[],
                    properties=TableProperties(
                        schema=row["TABLE_SCHEMA"],
                        catalog=row["TABLE_CATALOG"],
                        table=row["TABLE_NAME"],
                    ),
                    # INFORMATION_SCHEMA does not expose the primary key in
                    # this query, so it is left empty.
                    primaryKey="",
                )

            # table exists; append this row's column to it
            unique_tables[schema_table].columns.append(
                Column(
                    name=row["COLUMN_NAME"],
                    type=self._transform_column_type(row["DATA_TYPE"]),
                    notNull=row["IS_NULLABLE"].lower() == "no",
                    description=row["COLUMN_COMMENT"],
                    properties=None,
                )
            )
        return list(unique_tables.values())

    def get_constraints(self) -> list[Constraint]:
        """Return the foreign-key constraints of the configured schema.

        Uses `SHOW IMPORTED KEYS`, which lists the foreign keys referencing
        each primary key in the schema.
        """
        database = self._get_database_name()
        schema = self._get_schema_name()
        sql = f"""
            SHOW IMPORTED KEYS IN SCHEMA {database}.{schema};
            """
        with closing(self.connection.raw_sql(sql)) as cur:
            fields = [field[0] for field in cur.description]
            result = [dict(zip(fields, row)) for row in cur.fetchall()]
        # `result` is already a list of plain dicts; the previous
        # ibis.memtable(...).to_pandas().to_dict(orient="records") round-trip
        # only copied it (coercing values to pandas types on the way), so the
        # rows are used directly here.
        return [
            Constraint(
                constraintName=self._format_constraint_name(
                    row["pk_table_name"],
                    row["pk_column_name"],
                    row["fk_table_name"],
                    row["fk_column_name"],
                ),
                constraintTable=self._format_compact_table_name(
                    row["pk_schema_name"], row["pk_table_name"]
                ),
                constraintColumn=row["pk_column_name"],
                constraintedTable=self._format_compact_table_name(
                    row["fk_schema_name"], row["fk_table_name"]
                ),
                constraintedColumn=row["fk_column_name"],
                constraintType=ConstraintType.FOREIGN_KEY,
            )
            for row in result
        ]

    def get_version(self) -> str:
        """Return the Snowflake server version (result of CURRENT_VERSION())."""
        return self.connection.sql("SELECT CURRENT_VERSION()").to_pandas().iloc[0, 0]

    def _get_database_name(self):
        # The database name is stored as a secret value; unwrap it.
        return self.connection_info.database.get_secret_value()

    def _get_schema_name(self):
        # NOTE(review): the field is `sf_schema` on SnowflakeConnectionInfo —
        # presumably named to avoid clashing with a reserved attribute; confirm.
        return self.connection_info.sf_schema.get_secret_value()

    def _format_compact_table_name(self, schema: str, table: str):
        """Build the unique "schema.table" identifier used as the table name."""
        return f"{schema}.{table}"

    def _format_constraint_name(
        self, table_name, column_name, referenced_table_name, referenced_column_name
    ):
        """Build a deterministic constraint name from the PK and FK endpoints."""
        return f"{table_name}_{column_name}_{referenced_table_name}_{referenced_column_name}"

    def _transform_column_type(self, data_type):
        """Map a Snowflake DATA_TYPE string to a WrenEngineColumnType.

        The lookup is case-insensitive; unmapped types fall back to
        WrenEngineColumnType.UNKNOWN.
        """
        return self._TYPE_MAPPING.get(data_type.lower(), WrenEngineColumnType.UNKNOWN)

ibis-server/tests/routers/v2/connector/test_snowflake.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_query(manifest_str):
9494
36901,
9595
"O",
9696
"173665.47",
97-
"1996-01-02 00:00:00.000000",
97+
"1996-01-02",
9898
"1_36901",
9999
"2024-01-01 23:59:59.000000",
100100
"2024-01-01 23:59:59.000000 UTC",
@@ -261,14 +261,48 @@ def test_validate_rule_column_is_valid_without_one_parameter(manifest_str):
261261
assert response.status_code == 422
262262
assert response.text == "Missing required parameter: `modelName`"
263263

264-
@pytest.mark.skip(reason="Not implemented")
265264
def test_metadata_list_tables():
    # Fetch table metadata for the configured Snowflake sample dataset.
    response = client.post(
        url=f"{base_url}/metadata/tables",
        json={"connectionInfo": connection_info},
    )
    assert response.status_code == 200

    tables = response.json()
    assert len(tables) == 8

    # Inspect the ORDERS table in detail.
    table = next(t for t in tables if t["name"] == "TPCH_SF1.ORDERS")
    assert table["name"] == "TPCH_SF1.ORDERS"
    assert table["primaryKey"] is not None
    assert table["description"] == "Orders data as defined by TPC-H"
    assert table["properties"] == {
        "catalog": "SNOWFLAKE_SAMPLE_DATA",
        "schema": "TPCH_SF1",
        "table": "ORDERS",
    }
    assert len(table["columns"]) == 9

    # Spot-check a single column of ORDERS.
    column = next(c for c in table["columns"] if c["name"] == "O_COMMENT")
    assert column == {
        "name": "O_COMMENT",
        "nestedColumns": None,
        "type": "TEXT",
        "notNull": True,
        "description": None,
        "properties": None,
    }
267291

268-
@pytest.mark.skip(reason="Not implemented")
269292
def test_metadata_list_constraints():
    response = client.post(
        url=f"{base_url}/metadata/constraints",
        json={"connectionInfo": connection_info},
    )
    assert response.status_code == 200

    # The sample dataset currently yields no foreign-key constraints.
    constraints = response.json()
    assert len(constraints) == 0
271301

272-
@pytest.mark.skip(reason="Not implemented")
273302
def test_metadata_get_version():
    response = client.post(
        url=f"{base_url}/metadata/version",
        json={"connectionInfo": connection_info},
    )
    assert response.status_code == 200
    # `response.text` is always a str, so `is not None` could never fail;
    # assert the body actually contains a version string instead.
    assert response.text != ""

ibis-server/tests/routers/v3/connector/snowflake/test_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def test_scalar_function(manifest_str: str, connection_info):
8686
assert response.status_code == 200
8787
result = response.json()
8888
assert result == {
89-
"columns": ["col"],
89+
"columns": ["COL"],
9090
"data": [[1]],
91-
"dtypes": {"col": "int32"},
91+
"dtypes": {"COL": "int64"},
9292
}
9393

9494
def test_aggregate_function(manifest_str: str, connection_info):
@@ -103,7 +103,7 @@ def test_aggregate_function(manifest_str: str, connection_info):
103103
assert response.status_code == 200
104104
result = response.json()
105105
assert result == {
106-
"columns": ["col"],
106+
"columns": ["COL"],
107107
"data": [[1]],
108-
"dtypes": {"col": "int64"},
108+
"dtypes": {"COL": "int64"},
109109
}

0 commit comments

Comments
 (0)