Skip to content

Commit ff34c7b

Browse files
committed
refactor(datatype): use a mapping to store StructType fields rather than separate `names` and `types` tuples
Also schedule `Struct.from_dict()`, `Struct.pairs`, and the `Struct(names, types)` constructor for removal.
1 parent c162750 commit ff34c7b

23 files changed

+123
-80
lines changed

ibis/backends/base/sql/alchemy/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def _pg_map(dialect, itype):
182182
@to_sqla_type.register(Dialect, dt.Struct)
183183
def _struct(dialect, itype):
184184
return StructType(
185-
[(name, to_sqla_type(dialect, type)) for name, type in itype.pairs.items()]
185+
[(name, to_sqla_type(dialect, type)) for name, type in itype.fields.items()]
186186
)
187187

188188

ibis/backends/clickhouse/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def _(ty: dt.Map) -> str:
260260
@serialize_raw.register(dt.Struct)
261261
def _(ty: dt.Struct) -> str:
262262
fields = ", ".join(
263-
f"{name} {serialize(field_ty)}" for name, field_ty in ty.pairs.items()
263+
f"{name} {serialize(field_ty)}" for name, field_ty in ty.fields.items()
264264
)
265265
return f"Tuple({fields})"
266266

ibis/backends/clickhouse/tests/test_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_columns_types_with_additional_argument(con):
161161
param("Decimal(10, 3)", dt.Decimal(10, 3, nullable=False), id="decimal"),
162162
param(
163163
"Tuple(a String, b Array(Nullable(Float64)))",
164-
dt.Struct.from_dict(
164+
dt.Struct(
165165
dict(
166166
a=dt.String(nullable=False),
167167
b=dt.Array(dt.float64, nullable=False),
@@ -172,7 +172,7 @@ def test_columns_types_with_additional_argument(con):
172172
),
173173
param(
174174
"Tuple(String, Array(Nullable(Float64)))",
175-
dt.Struct.from_dict(
175+
dt.Struct(
176176
dict(
177177
f0=dt.String(nullable=False),
178178
f1=dt.Array(dt.float64, nullable=False),
@@ -183,7 +183,7 @@ def test_columns_types_with_additional_argument(con):
183183
),
184184
param(
185185
"Tuple(a String, Array(Nullable(Float64)))",
186-
dt.Struct.from_dict(
186+
dt.Struct(
187187
dict(
188188
a=dt.String(nullable=False),
189189
f1=dt.Array(dt.float64, nullable=False),
@@ -194,7 +194,7 @@ def test_columns_types_with_additional_argument(con):
194194
),
195195
param(
196196
"Nested(a String, b Array(Nullable(Float64)))",
197-
dt.Struct.from_dict(
197+
dt.Struct(
198198
dict(
199199
a=dt.Array(dt.String(nullable=False), nullable=False),
200200
b=dt.Array(dt.Array(dt.float64, nullable=False), nullable=False),

ibis/backends/duckdb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
455455
["column_name", "column_type", "null"], rows.mappings()
456456
):
457457
ibis_type = parse(type)
458-
yield name, ibis_type(nullable=null.lower() == "yes")
458+
yield name, ibis_type.copy(nullable=null.lower() == "yes")
459459

460460
def _register_in_memory_table(self, table_op):
461461
df = table_op.data.to_frame()

ibis/backends/duckdb/tests/test_datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
P=dt.string,
5050
Q=dt.Array(dt.int32),
5151
R=dt.Map(dt.string, dt.int64),
52-
S=dt.Struct.from_dict(
52+
S=dt.Struct(
5353
dict(
5454
a=dt.int32,
5555
b=dt.string,

ibis/backends/polars/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def from_ibis_interval(dtype):
5858
def from_ibis_struct(dtype):
5959
fields = [
6060
pl.Field(name=name, dtype=to_polars_type(dtype))
61-
for name, dtype in dtype.pairs.items()
61+
for name, dtype in dtype.fields.items()
6262
]
6363
return pl.Struct(fields)
6464

ibis/backends/pyarrow/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def from_ibis_interval(dtype: dt.Interval):
5151
@to_pyarrow_type.register
5252
def from_ibis_struct(dtype: dt.Struct):
5353
return pa.struct(
54-
pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.pairs.items()
54+
pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.fields.items()
5555
)
5656

5757

ibis/backends/pyspark/datatypes.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import ibis.common.exceptions as com
88
import ibis.expr.datatypes as dt
9+
import ibis.expr.schema as sch
910
from ibis.backends.base.sql.registry import sql_type_names
10-
from ibis.expr.schema import Schema
1111

1212
_sql_type_names = dict(sql_type_names, date='date')
1313

@@ -72,10 +72,11 @@ def _spark_map(spark_dtype_obj, nullable=True):
7272

7373
@dt.dtype.register(pt.StructType)
7474
def _spark_struct(spark_dtype_obj, nullable=True):
75-
names = spark_dtype_obj.names
76-
fields = spark_dtype_obj.fields
77-
ibis_types = [dt.dtype(f.dataType, nullable=f.nullable) for f in fields]
78-
return dt.Struct(names, ibis_types, nullable=nullable)
75+
fields = {
76+
n: dt.dtype(f.dataType, nullable=f.nullable)
77+
for n, f in zip(spark_dtype_obj.names, spark_dtype_obj.fields)
78+
}
79+
return dt.Struct(fields, nullable=nullable)
7980

8081

8182
_IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
@@ -122,10 +123,17 @@ def _map(ibis_dtype_obj):
122123

123124

124125
@spark_dtype.register(dt.Struct)
125-
@spark_dtype.register(Schema)
126126
def _struct(ibis_dtype_obj):
127127
fields = [
128128
pt.StructField(n, spark_dtype(t), t.nullable)
129-
for n, t in zip(ibis_dtype_obj.names, ibis_dtype_obj.types)
129+
for n, t in ibis_dtype_obj.fields.items()
130+
]
131+
return pt.StructType(fields)
132+
133+
134+
@spark_dtype.register(sch.Schema)
135+
def _schema(ibis_schem_obj):
136+
fields = [
137+
pt.StructField(n, spark_dtype(t), t.nullable) for n, t in ibis_schem_obj.items()
130138
]
131139
return pt.StructType(fields)

ibis/backends/tests/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_aggregate_multikey_group_reduction_udf(backend, alltypes, df):
183183

184184
@reduction(
185185
input_type=[dt.double],
186-
output_type=dt.Struct(['mean', 'std'], [dt.double, dt.double]),
186+
output_type=dt.Struct({'mean': dt.double, 'std': dt.double}),
187187
)
188188
def mean_and_std(v):
189189
return v.mean(), v.std()

ibis/backends/tests/test_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_null_literal(con, field):
7979
def test_struct_column(alltypes, df):
8080
t = alltypes
8181
expr = ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)).name("s")
82-
assert expr.type() == dt.Struct.from_dict(dict(a=dt.string, b=dt.int8, c=dt.int64))
82+
assert expr.type() == dt.Struct(dict(a=dt.string, b=dt.int8, c=dt.int64))
8383
result = expr.execute()
8484
expected = pd.Series(
8585
(dict(a=a, b=1, c=c) for a, c in zip(df.string_col, df.bigint_col)),

ibis/backends/tests/test_vectorized_udf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_one_struct(v):
8585
def create_add_one_struct_udf(result_formatter):
8686
return elementwise(
8787
input_type=[dt.double],
88-
output_type=dt.Struct(['col1', 'col2'], [dt.double, dt.double]),
88+
output_type=dt.Struct({'col1': dt.double, 'col2': dt.double}),
8989
)(_format_struct_udf_return_type(add_one_struct, result_formatter))
9090

9191

@@ -127,7 +127,7 @@ def create_add_one_struct_udf(result_formatter):
127127

128128
@elementwise(
129129
input_type=[dt.double],
130-
output_type=dt.Struct(['double_col', 'col2'], [dt.double, dt.double]),
130+
output_type=dt.Struct({'double_col': dt.double, 'col2': dt.double}),
131131
)
132132
def overwrite_struct_elementwise(v):
133133
assert isinstance(v, pd.Series)
@@ -137,7 +137,7 @@ def overwrite_struct_elementwise(v):
137137
@elementwise(
138138
input_type=[dt.double],
139139
output_type=dt.Struct(
140-
['double_col', 'col2', 'float_col'], [dt.double, dt.double, dt.double]
140+
{'double_col': dt.double, 'col2': dt.double, 'float_col': dt.double}
141141
),
142142
)
143143
def multiple_overwrite_struct_elementwise(v):
@@ -147,7 +147,7 @@ def multiple_overwrite_struct_elementwise(v):
147147

148148
@analytic(
149149
input_type=[dt.double, dt.double],
150-
output_type=dt.Struct(['double_col', 'demean_weight'], [dt.double, dt.double]),
150+
output_type=dt.Struct({'double_col': dt.double, 'demean_weight': dt.double}),
151151
)
152152
def overwrite_struct_analytic(v, w):
153153
assert isinstance(v, pd.Series)
@@ -165,7 +165,7 @@ def demean_struct(v, w):
165165
def create_demean_struct_udf(result_formatter):
166166
return analytic(
167167
input_type=[dt.double, dt.double],
168-
output_type=dt.Struct(['demean', 'demean_weight'], [dt.double, dt.double]),
168+
output_type=dt.Struct({'demean': dt.double, 'demean_weight': dt.double}),
169169
)(_format_struct_udf_return_type(demean_struct, result_formatter))
170170

171171

@@ -203,7 +203,7 @@ def mean_struct(v, w):
203203
def create_mean_struct_udf(result_formatter):
204204
return reduction(
205205
input_type=[dt.double, dt.int64],
206-
output_type=dt.Struct(['mean', 'mean_weight'], [dt.double, dt.double]),
206+
output_type=dt.Struct({'mean': dt.double, 'mean_weight': dt.double}),
207207
)(_format_struct_udf_return_type(mean_struct, result_formatter))
208208

209209

@@ -220,7 +220,7 @@ def create_mean_struct_udf(result_formatter):
220220

221221
@reduction(
222222
input_type=[dt.double, dt.int64],
223-
output_type=dt.Struct(['double_col', 'mean_weight'], [dt.double, dt.double]),
223+
output_type=dt.Struct({'double_col': dt.double, 'mean_weight': dt.double}),
224224
)
225225
def overwrite_struct_reduction(v, w):
226226
assert isinstance(v, (np.ndarray, pd.Series))
@@ -495,7 +495,7 @@ def test_elementwise_udf_destructure_exact_once(
495495
):
496496
@elementwise(
497497
input_type=[dt.double],
498-
output_type=dt.Struct(['col1', 'col2'], [dt.double, dt.double]),
498+
output_type=dt.Struct({'col1': dt.double, 'col2': dt.double}),
499499
)
500500
def add_one_struct_exact_once(v):
501501
key = v.iloc[0]

ibis/backends/trino/datatypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _string(_, itype):
173173
@to_sqla_type.register(TrinoDialect, dt.Struct)
174174
def _struct(dialect, itype):
175175
return ROW(
176-
[(name, to_sqla_type(dialect, typ)) for name, typ in itype.pairs.items()]
176+
[(name, to_sqla_type(dialect, typ)) for name, typ in itype.fields.items()]
177177
)
178178

179179

ibis/expr/datatypes/core.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from ibis.common.grounds import Concrete, Singleton
1515
from ibis.common.validators import (
1616
all_of,
17+
frozendict_of,
1718
instance_of,
1819
isin,
1920
map_to,
20-
tuple_of,
2121
validator,
2222
)
23+
from ibis.util import deprecated, warn_deprecated
2324

2425
dtype = Dispatcher('dtype')
2526

@@ -642,18 +643,42 @@ def to_integer_type(self):
642643
class Struct(DataType):
643644
"""Structured values."""
644645

645-
names = tuple_of(instance_of(str))
646-
types = tuple_of(datatype)
646+
fields = frozendict_of(instance_of(str), datatype)
647647

648648
scalar = ir.StructScalar
649649
column = ir.StructColumn
650650

651-
def __init__(self, names, types, **kwargs):
652-
if len(names) != len(types):
653-
raise IbisTypeError(
654-
'Struct datatype names and types must have the same length'
651+
@classmethod
652+
def __create__(cls, names, types=None, nullable=True):
653+
if types is None:
654+
fields = names
655+
else:
656+
warn_deprecated(
657+
"Struct(names, types)",
658+
as_of="4.1",
659+
removed_in="5.0",
660+
instead=(
661+
"construct a Struct type using a mapping of names to types instead: "
662+
"Struct(dict(zip(names, types)))"
663+
),
655664
)
656-
super().__init__(names=names, types=types, **kwargs)
665+
if len(names) != len(types):
666+
raise IbisTypeError(
667+
'Struct datatype names and types must have the same length'
668+
)
669+
fields = dict(zip(names, types))
670+
671+
return super().__create__(fields=fields, nullable=nullable)
672+
673+
def __reduce__(self):
674+
return (self.__class__, (self.fields, None, self.nullable))
675+
676+
def copy(self, fields=None, nullable=None):
677+
if fields is None:
678+
fields = self.fields
679+
if nullable is None:
680+
nullable = self.nullable
681+
return type(self)(fields, nullable=nullable)
657682

658683
@classmethod
659684
def from_tuples(
@@ -673,10 +698,14 @@ def from_tuples(
673698
Struct
674699
Struct data type instance
675700
"""
676-
names, types = zip(*pairs)
677-
return cls(names, types, nullable=nullable)
701+
return cls(dict(pairs), nullable=nullable)
678702

679703
@classmethod
704+
@deprecated(
705+
as_of="4.1",
706+
removed_in="5.0",
707+
instead="directly construct a Struct type instead",
708+
)
680709
def from_dict(
681710
cls, pairs: Mapping[str, str | DataType], nullable: bool = True
682711
) -> Struct:
@@ -694,26 +723,33 @@ def from_dict(
694723
Struct
695724
Struct data type instance
696725
"""
697-
names, types = pairs.keys(), pairs.values()
698-
return cls(names, types, nullable=nullable)
726+
return cls(pairs, nullable=nullable)
699727

700728
@property
729+
@deprecated(
730+
as_of="4.1",
731+
removed_in="5.0",
732+
instead="use struct_type.fields attribute instead",
733+
)
701734
def pairs(self) -> Mapping[str, DataType]:
702-
"""Return a mapping from names to data type instances.
735+
return self.fields
703736

704-
Returns
705-
-------
706-
Mapping[str, DataType]
707-
Mapping of field name to data type
708-
"""
709-
return dict(zip(self.names, self.types))
737+
@property
738+
def names(self) -> tuple[str, ...]:
739+
"""Return the names of the struct's fields."""
740+
return tuple(self.fields.keys())
741+
742+
@property
743+
def types(self) -> tuple[DataType, ...]:
744+
"""Return the types of the struct's fields."""
745+
return tuple(self.fields.values())
710746

711747
def __getitem__(self, key: str) -> DataType:
712-
return self.pairs[key]
748+
return self.fields[key]
713749

714750
def __repr__(self) -> str:
715751
return '{}({}, nullable={})'.format(
716-
self.name, list(self.pairs.items()), self.nullable
752+
self.name, list(self.fields.items()), self.nullable
717753
)
718754

719755
@property

ibis/expr/datatypes/value.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@ def infer(value: Any) -> dt.DataType:
3232
raise InputTypeError(value)
3333

3434

35+
# TODO(kszucs): infer structs from NamedTuples and dataclasses rather than from
36+
# OrderedDict, which should trigger infer_map instead
3537
@infer.register(collections.OrderedDict)
3638
def infer_struct(value: Mapping[str, Any]) -> dt.Struct:
3739
"""Infer the [`Struct`][ibis.expr.datatypes.Struct] type of `value`."""
3840
if not value:
3941
raise TypeError('Empty struct type not supported')
40-
return dt.Struct(list(value.keys()), list(map(infer, value.values())))
42+
fields = {name: infer(val) for name, val in value.items()}
43+
return dt.Struct(fields)
4144

4245

4346
@infer.register(collections.abc.Mapping)
@@ -51,7 +54,7 @@ def infer_map(value: Mapping[Any, Any]) -> dt.Map:
5154
highest_precedence(map(infer, value.values())),
5255
)
5356
except IbisTypeError:
54-
return dt.Struct.from_dict(toolz.valmap(infer, value, factory=type(value)))
57+
return dt.Struct(toolz.valmap(infer, value, factory=type(value)))
5558

5659

5760
@infer.register((list, tuple))
@@ -303,7 +306,7 @@ def normalize(typ, value):
303306
return frozendict({k: normalize(typ.value_type, v) for k, v in value.items()})
304307
elif typ.is_struct():
305308
return frozendict(
306-
{k: normalize(typ[k], v) for k, v in value.items() if k in typ.pairs}
309+
{k: normalize(typ[k], v) for k, v in value.items() if k in typ.fields}
307310
)
308311
elif typ.is_geospatial():
309312
if isinstance(value, (tuple, list)):

0 commit comments

Comments
 (0)