Skip to content

Commit d17d53e

Browse files
authored
feat(datatypes): add support for fixed length arrays (#10729)
1 parent d0b2d4d commit d17d53e

File tree

8 files changed

+108
-19
lines changed

8 files changed

+108
-19
lines changed

ibis/backends/clickhouse/tests/test_datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,7 @@ def test_parse_type(ch_type, ibis_type):
296296
| its.date_dtype()
297297
| its.time_dtype()
298298
| its.timestamp_dtype(scale=st.integers(0, 9))
299-
| its.array_dtypes(roundtrippable_types, nullable=false)
299+
| its.array_dtypes(roundtrippable_types, nullable=false, length=st.none())
300300
| its.map_dtypes(map_key_types, roundtrippable_types, nullable=false)
301301
)
302302
)

ibis/backends/duckdb/tests/test_datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
("UUID", dt.uuid),
3737
("VARCHAR", dt.string),
3838
("INTEGER[]", dt.Array(dt.int32)),
39-
("INTEGER[3]", dt.Array(dt.int32)),
39+
("INTEGER[3]", dt.Array(dt.int32, length=3)),
4040
("MAP(VARCHAR, BIGINT)", dt.Map(dt.string, dt.int64)),
4141
(
4242
"STRUCT(a INTEGER, b VARCHAR, c MAP(VARCHAR, DOUBLE[])[])",

ibis/backends/duckdb/tests/test_udf.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import ibis
88
from ibis import udf
9+
from ibis.util import gen_name
910

1011

1112
@udf.scalar.builtin
@@ -149,3 +150,31 @@ def regexp_extract(s, pattern, groups): ...
149150
e = regexp_extract("2023-04-15", r"(\d+)-(\d+)-(\d+)", ["y", "m", "d"])
150151
sql = str(ibis.to_sql(e, dialect="duckdb"))
151152
assert r"REGEXP_EXTRACT('2023-04-15', '(\d+)-(\d+)-(\d+)', ['y', 'm', 'd'])" in sql
153+
154+
155+
@pytest.fixture(scope="module")
156+
def array_cosine_t(con):
157+
return con.create_table(
158+
gen_name("array_cosine_t"),
159+
obj={"fixed": [[1, 2, 3]], "varlen": [[1, 2, 3]]},
160+
schema={"fixed": "array<double, 3>", "varlen": "array<double>"},
161+
temp=True,
162+
)
163+
164+
165+
@pytest.mark.parametrize(
166+
("column", "expr_fn"),
167+
[
168+
("fixed", lambda c: c),
169+
("varlen", lambda c: c.cast("array<float, 3>")),
170+
],
171+
ids=["no-cast", "cast"],
172+
)
173+
def test_builtin_fixed_length_array_udf(array_cosine_t, column, expr_fn):
174+
@udf.scalar.builtin
175+
def array_cosine_similarity(a, b) -> float: ...
176+
177+
expr = expr_fn(array_cosine_t[column])
178+
expr = array_cosine_similarity(expr, expr)
179+
result = expr.execute()
180+
assert result.iat[0] == 1.0

ibis/backends/sql/datatypes.py

Lines changed: 39 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,14 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
180180
"nullable", nullable if nullable is not None else cls.default_nullable
181181
)
182182
if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None):
183-
dtype = method(*typ.expressions, nullable=nullable)
183+
if typecode == sge.DataType.Type.ARRAY:
184+
dtype = method(
185+
*typ.expressions,
186+
*(typ.args.get("values", ()) or ()),
187+
nullable=nullable,
188+
)
189+
else:
190+
dtype = method(*typ.expressions, nullable=nullable)
184191
elif (known_typ := _from_sqlglot_types.get(typecode)) is not None:
185192
dtype = known_typ(nullable=nullable)
186193
else:
@@ -222,9 +229,16 @@ def to_string(cls, dtype: dt.DataType) -> str:
222229

223230
@classmethod
224231
def _from_sqlglot_ARRAY(
225-
cls, value_type: sge.DataType, nullable: bool | None = None
232+
cls,
233+
value_type: sge.DataType,
234+
length: sge.Literal | None = None,
235+
nullable: bool | None = None,
226236
) -> dt.Array:
227-
return dt.Array(cls.to_ibis(value_type), nullable=nullable)
237+
return dt.Array(
238+
cls.to_ibis(value_type),
239+
length=None if length is None else int(length.this),
240+
nullable=nullable,
241+
)
228242

229243
@classmethod
230244
def _from_sqlglot_MAP(
@@ -380,7 +394,12 @@ def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
380394
@classmethod
381395
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
382396
value_type = cls.from_ibis(dtype.value_type)
383-
return sge.DataType(this=typecode.ARRAY, expressions=[value_type], nested=True)
397+
return sge.DataType(
398+
this=typecode.ARRAY,
399+
expressions=[value_type],
400+
values=None if dtype.length is None else [sge.convert(dtype.length)],
401+
nested=True,
402+
)
384403

385404
@classmethod
386405
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
@@ -775,7 +794,10 @@ def _from_sqlglot_DECIMAL(
775794

776795
@classmethod
777796
def _from_sqlglot_ARRAY(
778-
cls, value_type=None, nullable: bool | None = None
797+
cls,
798+
value_type: sge.DataType | None = None,
799+
length: sge.Literal | None = None,
800+
nullable: bool | None = None,
779801
) -> dt.Array:
780802
assert value_type is None
781803
return dt.Array(dt.json, nullable=nullable)
@@ -1050,7 +1072,12 @@ def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
10501072
return sge.DataType(this=code)
10511073

10521074
@classmethod
1053-
def _from_sqlglot_ARRAY(cls, value_type: sge.DataType) -> NoReturn:
1075+
def _from_sqlglot_ARRAY(
1076+
cls,
1077+
value_type: sge.DataType,
1078+
length: sge.Literal | None = None,
1079+
nullable: bool | None = None,
1080+
) -> NoReturn:
10541081
raise com.UnsupportedBackendType("Arrays not supported in Exasol")
10551082

10561083
@classmethod
@@ -1105,7 +1132,12 @@ def _from_ibis_Struct(cls, dtype: dt.String) -> sge.DataType:
11051132
raise com.UnsupportedBackendType("SQL Server does not support structs")
11061133

11071134
@classmethod
1108-
def _from_sqlglot_ARRAY(cls) -> sge.DataType:
1135+
def _from_sqlglot_ARRAY(
1136+
cls,
1137+
value_type: sge.DataType,
1138+
length: sge.Literal | None = None,
1139+
nullable: bool | None = None,
1140+
) -> NoReturn:
11091141
raise com.UnsupportedBackendType("SQL Server does not support arrays")
11101142

11111143
@classmethod

ibis/expr/datatypes/core.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Iterable, Iterator, Mapping, Sequence
99
from numbers import Integral, Real
1010
from typing import (
11+
Annotated,
1112
Any,
1213
Generic,
1314
Literal,
@@ -27,7 +28,7 @@
2728
from ibis.common.collections import FrozenOrderedDict, MapSet
2829
from ibis.common.dispatch import lazy_singledispatch
2930
from ibis.common.grounds import Concrete, Singleton
30-
from ibis.common.patterns import Coercible, CoercionError
31+
from ibis.common.patterns import Between, Coercible, CoercionError
3132
from ibis.common.temporal import IntervalUnit, TimestampUnit
3233

3334

@@ -50,14 +51,14 @@ def dtype(value: Any, nullable: bool = True) -> DataType:
5051
>>> ibis.dtype("int32")
5152
Int32(nullable=True)
5253
>>> ibis.dtype("array<float>")
53-
Array(value_type=Float64(nullable=True), nullable=True)
54+
Array(value_type=Float64(nullable=True), length=None, nullable=True)
5455
5556
DataType objects may also be created from Python types:
5657
5758
>>> ibis.dtype(int)
5859
Int64(nullable=True)
5960
>>> ibis.dtype(list[float])
60-
Array(value_type=Float64(nullable=True), nullable=True)
61+
Array(value_type=Float64(nullable=True), length=None, nullable=True)
6162
6263
Or other type systems, like numpy/pandas/pyarrow types:
6364
@@ -309,6 +310,10 @@ def is_enum(self) -> bool:
309310
"""Return True if an instance of an Enum type."""
310311
return isinstance(self, Enum)
311312

313+
def is_fixed_length_array(self) -> bool:
314+
"""Return True if an instance of an Array type and has a known length."""
315+
return isinstance(self, Array) and self.length is not None
316+
312317
def is_float16(self) -> bool:
313318
"""Return True if an instance of a Float16 type."""
314319
return isinstance(self, Float16)
@@ -904,13 +909,19 @@ class Array(Variadic, Parametric, Generic[T]):
904909
"""Array values."""
905910

906911
value_type: T
912+
"""Element type of the array."""
913+
length: Annotated[int, Between(lower=0)] | None = None
914+
"""The length of the array if known."""
907915

908916
scalar = "ArrayScalar"
909917
column = "ArrayColumn"
910918

911919
@property
912920
def _pretty_piece(self) -> str:
913-
return f"<{self.value_type}>"
921+
value_type = self.value_type
922+
if (length := self.length) is not None:
923+
return f"<{value_type}, {length:d}>"
924+
return f"<{value_type}>"
914925

915926

916927
K = TypeVar("K", bound=DataType, covariant=True)
@@ -922,7 +933,9 @@ class Map(Variadic, Parametric, Generic[K, V]):
922933
"""Associative array values."""
923934

924935
key_type: K
936+
"""Map key type."""
925937
value_type: V
938+
"""Map value type."""
926939

927940
scalar = "MapScalar"
928941
column = "MapColumn"

ibis/expr/datatypes/parse.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def parse(
7474
>>> import ibis
7575
>>> import ibis.expr.datatypes as dt
7676
>>> dt.parse("array<int64>")
77-
Array(value_type=Int64(nullable=True), nullable=True)
77+
Array(value_type=Int64(nullable=True), length=None, nullable=True)
7878
7979
You can avoid parsing altogether by constructing objects directly
8080
@@ -182,8 +182,13 @@ def geotype_parser(typ: type[dt.DataType]) -> dt.DataType:
182182
)
183183

184184
ty = parsy.forward_declaration()
185-
angle_type = LANGLE.then(ty).skip(RANGLE)
186-
array = spaceless_string("array").then(angle_type).map(dt.Array)
185+
186+
array = (
187+
spaceless_string("array")
188+
.then(LANGLE)
189+
.then(parsy.seq(ty, COMMA.then(LENGTH).optional()).combine(dt.Array))
190+
.skip(RANGLE)
191+
)
187192

188193
map = (
189194
spaceless_string("map")

ibis/expr/datatypes/tests/test_core.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -571,6 +571,14 @@ def test_is_methods(dtype_class):
571571
assert is_dtype is True
572572

573573

574+
def test_is_fixed_length_array():
575+
dtype = dt.Array(dt.int8)
576+
assert not dtype.is_fixed_length_array()
577+
578+
dtype = dt.Array(dt.int8, 10)
579+
assert dtype.is_fixed_length_array()
580+
581+
574582
def test_is_array():
575583
assert dt.Array(dt.string).is_array()
576584
assert not dt.string.is_array()

ibis/tests/strategies.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -149,9 +149,11 @@ def primitive_dtypes(nullable=_nullable):
149149

150150
_item_strategy = primitive_dtypes()
151151

152+
_length = st.one_of(st.none(), st.integers(min_value=0))
152153

153-
def array_dtypes(value_type=_item_strategy, nullable=_nullable):
154-
return st.builds(dt.Array, value_type=value_type, nullable=nullable)
154+
155+
def array_dtypes(value_type=_item_strategy, nullable=_nullable, length=_length):
156+
return st.builds(dt.Array, value_type=value_type, nullable=nullable, length=length)
155157

156158

157159
def map_dtypes(key_type=_item_strategy, value_type=_item_strategy, nullable=_nullable):
@@ -180,7 +182,7 @@ def struct_dtypes(
180182

181183
def geospatial_dtypes(nullable=_nullable):
182184
geotype = st.one_of(st.just("geography"), st.just("geometry"))
183-
srid = st.one_of(st.just(None), st.integers(min_value=0))
185+
srid = st.one_of(st.none(), st.integers(min_value=0))
184186
return st.one_of(
185187
st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid),
186188
st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid),

0 commit comments

Comments
 (0)