
Commit d55a5ee

fix(databricks): use AS JSON for programmatic output of schema information
1 parent 0356caf

File tree

4 files changed: +147 -15 lines changed

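In short: Backend.get_schema previously ran a DESCRIBE TABLE statement and re-parsed each column's type string; it now asks Databricks for the whole table description as a single JSON document and walks that structure instead, which lets nested, decimal, and nullability information survive the round trip (see the new tests below). A minimal sketch of the new flow, assuming `cur` is an open databricks-sql-connector cursor and `table` is an already-quoted table reference; the helper name here is illustrative, not part of the commit:

import json

def describe_as_json(cur, table: str) -> dict:
    # DESCRIBE EXTENDED ... AS JSON returns a single row with a single cell
    # holding the full table description as a JSON string.
    [(out,)] = cur.execute(f"DESCRIBE EXTENDED {table} AS JSON").fetchall()
    return json.loads(out)

# get_schema then builds an ibis schema from describe_as_json(...)["columns"],
# as shown in the diff below.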

ibis/backends/databricks/__init__.py

Lines changed: 60 additions & 14 deletions
@@ -5,6 +5,7 @@
 import contextlib
 import functools
 import getpass
+import json
 import os
 import sys
 import tempfile
@@ -20,23 +21,73 @@
 import ibis
 import ibis.backends.sql.compilers as sc
 import ibis.common.exceptions as exc
+import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as ir
 from ibis import util
 from ibis.backends import CanCreateDatabase, PyArrowExampleLoader, UrlFromPath
 from ibis.backends.sql import SQLBackend
 from ibis.backends.sql.compilers.base import STAR, AlterTable, RenameTable
+from ibis.backends.sql.datatypes import DatabricksType

 if TYPE_CHECKING:
-    from collections.abc import Mapping
+    from collections.abc import Iterable, Mapping

     import pandas as pd
     import polars as pl

     from ibis.expr.schema import SchemaLike


+def _databricks_type_to_ibis(typ, nullable: bool = True) -> dt.DataType:
+    """Convert a Databricks type to an Ibis type."""
+    typname = typ["name"]
+    if typname == "array":
+        return dt.Array(
+            _databricks_type_to_ibis(
+                typ["element_type"], nullable=typ["element_nullable"]
+            ),
+            nullable=nullable,
+        )
+    elif typname == "map":
+        return dt.Map(
+            key_type=_databricks_type_to_ibis(typ["key_type"]),
+            value_type=_databricks_type_to_ibis(
+                typ["value_type"], nullable=typ["value_nullable"]
+            ),
+            nullable=nullable,
+        )
+    elif typname == "struct":
+        return dt.Struct(
+            {
+                field["name"]: _databricks_type_to_ibis(
+                    field["type"], nullable=field["nullable"]
+                )
+                for field in typ["fields"]
+            },
+            nullable=nullable,
+        )
+    elif typname == "decimal":
+        return dt.Decimal(
+            precision=typ["precision"], scale=typ["scale"], nullable=nullable
+        )
+    else:
+        return DatabricksType.from_string(typname, nullable=nullable)
+
+
+def _databricks_schema_to_ibis(schema: Iterable[Mapping[str, Any]]) -> sch.Schema:
+    """Convert a Databricks schema to an Ibis schema."""
+    return sch.Schema(
+        {
+            item["name"]: _databricks_type_to_ibis(
+                item["type"], nullable=item["nullable"]
+            )
+            for item in schema
+        }
+    )
+
+
 class Backend(SQLBackend, CanCreateDatabase, UrlFromPath, PyArrowExampleLoader):
     name = "databricks"
     compiler = sc.databricks.compiler
@@ -143,6 +194,9 @@ def create_table(
         else:
             table = obj

+        if not schema:
+            schema = table.schema()
+
         self._run_pre_execute_hooks(table)

         query = self.compiler.to_sqlglot(table)
@@ -158,10 +212,7 @@
         dialect = self.dialect

         initial_table = sg.table(temp_name, catalog=catalog, db=database, quoted=quoted)
-        target = sge.Schema(
-            this=initial_table,
-            expressions=(schema or table.schema()).to_sqlglot(dialect),
-        )
+        target = sge.Schema(this=initial_table, expressions=schema.to_sqlglot(dialect))

         properties = sge.Properties(expressions=properties)
         create_stmt = sge.Create(kind="TABLE", this=target, properties=properties)
@@ -256,23 +307,18 @@ def get_schema(
         """
         table = sg.table(
             table_name, db=database, catalog=catalog, quoted=self.compiler.quoted
-        )
-        sql = sge.Describe(kind="TABLE", this=table).sql(self.dialect)
+        ).sql(self.dialect)
         try:
             with self.con.cursor() as cur:
-                out = cur.execute(sql).fetchall_arrow()
+                [(out,)] = cur.execute(f"DESCRIBE EXTENDED {table} AS JSON").fetchall()
         except databricks.sql.exc.ServerOperationError as e:
             raise exc.TableNotFound(
                 f"Table {table_name!r} not found in "
                 f"{catalog or self.current_catalog}.{database or self.current_database}"
             ) from e

-        names = out["col_name"].to_pylist()
-        types = out["data_type"].to_pylist()
-
-        return sch.Schema(
-            dict(zip(names, map(self.compiler.type_mapper.from_string, types)))
-        )
+        js = json.loads(out)
+        return _databricks_schema_to_ibis(js["columns"])

     @contextlib.contextmanager
     def _safe_raw_sql(self, query, *args, **kwargs):
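For reference, a minimal sketch of the JSON shape the new helpers consume, inferred from the parser above rather than from Databricks documentation -- the real DESCRIBE EXTENDED ... AS JSON payload carries additional table metadata that get_schema simply ignores. Running it requires the databricks extras so that ibis.backends.databricks imports cleanly:

import ibis.expr.datatypes as dt
from ibis.backends.databricks import _databricks_schema_to_ibis

# Illustrative "columns" entries: each column has a name, a nullable flag, and
# a structured type; nested types carry element/key/value/field descriptions.
columns = [
    {"name": "id", "type": {"name": "bigint"}, "nullable": False},
    {
        "name": "amount",
        "type": {"name": "decimal", "precision": 33, "scale": 2},
        "nullable": True,
    },
    {
        "name": "tags",
        "type": {
            "name": "array",
            "element_type": {"name": "string"},
            "element_nullable": True,
        },
        "nullable": True,
    },
    {
        "name": "props",
        "type": {
            "name": "struct",
            "fields": [
                {"name": "city", "type": {"name": "string"}, "nullable": True},
                {
                    "name": "scores",
                    "type": {
                        "name": "map",
                        "key_type": {"name": "string"},
                        "value_type": {"name": "int"},
                        "value_nullable": True,
                    },
                    "nullable": True,
                },
            ],
        },
        "nullable": True,
    },
]

schema = _databricks_schema_to_ibis(columns)
print(schema)
# Expected (roughly):
# ibis.Schema {
#   id      !int64
#   amount  decimal(33, 2)
#   tags    array<string>
#   props   struct<city: string, scores: map<string, int32>>
# }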

ibis/backends/databricks/tests/conftest.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,8 @@
 from os import environ as env
 from typing import TYPE_CHECKING, Any

+import pytest
+
 import ibis
 from ibis.backends.tests.base import BackendTest

@@ -73,3 +75,9 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
             schema="default",
             **kw,
         )
+
+
+@pytest.fixture(scope="session")
+def con(tmp_path_factory, data_dir, worker_id):
+    with TestConf.load_data(data_dir, tmp_path_factory, worker_id) as be:
+        yield be.connection
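With this session-scoped fixture in place, databricks-only tests can request a live connection directly instead of going through the shared backend suite; a minimal sketch for illustration (the real usage is the nested-schema test in the new file below, and the test name here is hypothetical):

def test_smoke(con):
    # `con` is the session-scoped Databricks backend connection from conftest.py
    assert isinstance(con.list_tables(), list)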
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import pytest
+
+import ibis.expr.datatypes as dt
+import ibis.expr.schema as sch
+from ibis.util import gen_name
+
+
+@pytest.fixture
+def tmp_table(con):
+    name = gen_name("databricks_tmp_table")
+    yield name
+    con.drop_table(name, force=True)
+
+
+def test_nested(con, tmp_table):
+    schema = sch.Schema(
+        {
+            "nums": "decimal(33, 2)",
+            "time": "timestamp('UTC')",
+            "operationName": "string",
+            "category": "string",
+            "tenantId": "string",
+            "properties": dt.Struct(
+                {
+                    "Timestamp": "timestamp('UTC')",
+                    "ActionType": "string",
+                    "Application": "string",
+                    "ApplicationId": "int32",
+                    "AppInstanceId": "int32",
+                    "AccountObjectId": "string",
+                    "AccountId": "string",
+                    "AccountDisplayName": "string",
+                    "IsAdminOperation": "boolean",
+                    "DeviceType": "string",
+                    "OSPlatform": "string",
+                    "IPAddress": "string",
+                    "IsAnonymousProxy": "boolean",
+                    "CountryCode": "string",
+                    "City": "string",
+                    "ISP": "string",
+                    "UserAgent": "string",
+                    "ActivityType": "string",
+                    "ActivityObjects": "string",
+                    "ObjectName": "string",
+                    "ObjectType": "string",
+                    "ObjectId": "string",
+                    "ReportId": "string",
+                    "AccountType": "string",
+                    "IsExternalUser": "boolean",
+                    "IsImpersonated": "boolean",
+                    "IPTags": "string",
+                    "IPCategory": "string",
+                    "UserAgentTags": "string",
+                    "RawEventData": "string",
+                    "AdditionalFields": "string",
+                }
+            ),
+            "Tenant": "string",
+            "_rescued_data": "string",
+            "timestamp": "timestamp('UTC')",
+            "parse_details": dt.Struct(
+                {
+                    "status": "string",
+                    "at": "timestamp('UTC')",
+                    "info": dt.Struct({"input-file-name": "string"}),
+                }
+            ),
+            "p_date": "string",
+            "foo": "array<struct<a: map<string, array<struct<b: string>>>>>",
+            "a": "array<int>",
+            "b": "map<string, int>",
+        }
+    )
+
+    t = con.create_table(tmp_table, schema=schema)
+    assert t.schema() == schema

ibis/backends/tests/test_client.py

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ def test_rename_table(con, temp_table, temp_table_orig):


 @mark.notimpl(["polars", "druid", "athena"])
-@mark.never(["impala", "pyspark", "databricks"], reason="No non-nullable datatypes")
+@mark.never(["impala", "pyspark"], reason="No non-nullable datatypes")
 @pytest.mark.notimpl(
     ["flink"],
     raises=com.IbisError,
