Skip to content

Commit 3780a13

Browse files
grieve54706cpcloud
andauthored
fix(mssql): remove sort key to keep order (#9848)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent c99cb4b commit 3780a13

File tree

2 files changed

+65
-29
lines changed

2 files changed

+65
-29
lines changed

ibis/backends/mssql/__init__.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,24 +244,21 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
244244
# us to pre-filter the columns we want back.
245245
# The syntax is:
246246
# `sys.dm_exec_describe_first_result_set(@tsql, @params, @include_browse_information)`
247-
query = f"""SELECT name,
248-
is_nullable AS nullable,
249-
system_type_name,
250-
precision,
251-
scale
252-
FROM
253-
sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)"""
247+
query = f"""
248+
SELECT
249+
name,
250+
is_nullable,
251+
system_type_name,
252+
precision,
253+
scale
254+
FROM sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)
255+
ORDER BY column_ordinal
256+
"""
254257
with self._safe_raw_sql(query) as cur:
255258
rows = cur.fetchall()
256259

257260
schema = {}
258-
for (
259-
name,
260-
nullable,
261-
system_type_name,
262-
precision,
263-
scale,
264-
) in sorted(rows, key=itemgetter(1)):
261+
for name, nullable, system_type_name, precision, scale in rows:
265262
newtyp = self.compiler.type_mapper.from_string(
266263
system_type_name, nullable=nullable
267264
)

ibis/backends/mssql/tests/test_client.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
import pytest
4+
import sqlglot as sg
5+
import sqlglot.expressions as sge
46
from pytest import param
57

68
import ibis
79
import ibis.expr.datatypes as dt
810
from ibis import udf
911

10-
DB_TYPES = [
12+
RAW_DB_TYPES = [
1113
# Exact numbers
1214
("BIGINT", dt.int64),
1315
("BIT", dt.boolean),
@@ -36,23 +38,9 @@
3638
("DATETIME", dt.Timestamp(scale=3)),
3739
# Characters strings
3840
("CHAR", dt.string),
39-
param(
40-
"TEXT",
41-
dt.string,
42-
marks=pytest.mark.notyet(
43-
["mssql"], reason="Not supported by UTF-8 aware collations"
44-
),
45-
),
4641
("VARCHAR", dt.string),
4742
# Unicode character strings
4843
("NCHAR", dt.string),
49-
param(
50-
"NTEXT",
51-
dt.string,
52-
marks=pytest.mark.notyet(
53-
["mssql"], reason="Not supported by UTF-8 aware collations"
54-
),
55-
),
5644
("NVARCHAR", dt.string),
5745
# Binary strings
5846
("BINARY", dt.binary),
@@ -67,6 +55,23 @@
6755
("GEOGRAPHY", dt.geography),
6856
("HIERARCHYID", dt.string),
6957
]
58+
PARAM_TYPES = [
59+
param(
60+
"TEXT",
61+
dt.string,
62+
marks=pytest.mark.notyet(
63+
["mssql"], reason="Not supported by UTF-8 aware collations"
64+
),
65+
),
66+
param(
67+
"NTEXT",
68+
dt.string,
69+
marks=pytest.mark.notyet(
70+
["mssql"], reason="Not supported by UTF-8 aware collations"
71+
),
72+
),
73+
]
74+
DB_TYPES = RAW_DB_TYPES + PARAM_TYPES
7075

7176

7277
@pytest.mark.parametrize(("server_type", "expected_type"), DB_TYPES, ids=str)
@@ -81,6 +86,40 @@ def test_get_schema(con, server_type, expected_type, temp_table):
8186
assert con.sql(f"SELECT * FROM [{temp_table}]").schema() == expected_schema
8287

8388

89+
def test_schema_type_order(con, temp_table):
90+
columns = []
91+
pairs = {}
92+
93+
quoted = con.compiler.quoted
94+
dialect = con.dialect
95+
table_id = sg.to_identifier(temp_table, quoted=quoted)
96+
97+
for i, (server_type, expected_type) in enumerate(RAW_DB_TYPES):
98+
column_name = f"col_{i}"
99+
columns.append(
100+
sge.ColumnDef(
101+
this=sg.to_identifier(column_name, quoted=quoted), kind=server_type
102+
)
103+
)
104+
pairs[column_name] = expected_type
105+
106+
query = sge.Create(
107+
kind="TABLE", this=sge.Schema(this=table_id, expressions=columns)
108+
)
109+
stmt = query.sql(dialect)
110+
111+
with con.begin() as c:
112+
c.execute(stmt)
113+
114+
expected_schema = ibis.schema(pairs)
115+
116+
assert con.get_schema(temp_table) == expected_schema
117+
assert con.table(temp_table).schema() == expected_schema
118+
119+
raw_sql = sg.select("*").from_(table_id).sql(dialect)
120+
assert con.sql(raw_sql).schema() == expected_schema
121+
122+
84123
def test_builtin_scalar_udf(con):
85124
@udf.scalar.builtin
86125
def difference(a: str, b: str) -> int:

0 commit comments

Comments
 (0)