
Commit 9dc8ed5

Merge branch 'main' into main
2 parents 370086c + ca82618

16 files changed, +837 -149 lines

.github/workflows/ci-tests.yml

+8 -2

@@ -100,14 +100,20 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        os: ["ubuntu-latest", "macos-latest", "windows-latest"]
+        os:
+          - ubuntu-latest
+          - windows-latest
+          - macos-13
+          # - macos-latest  # see: https://github.com/actions/setup-python/issues/696
         python-version: ["3.8", "3.9", "3.10", "3.11"]
         pandas-version: ["1.5.3", "2.0.3", "2.2.0"]
         pydantic-version: ["1.10.11", "2.3.0"]
         include:
           - os: ubuntu-latest
             pip-cache: ~/.cache/pip
-          - os: macos-latest
+          # - os: macos-latest
+          #   pip-cache: ~/Library/Caches/pip
+          - os: macos-13
             pip-cache: ~/Library/Caches/pip
           - os: windows-latest
             pip-cache: ~/AppData/Local/pip/Cache

pandera/api/polars/model.py

+27 -30

@@ -1,7 +1,6 @@
 """Class-based api for polars models."""

 from typing import (
-    Any,
     Dict,
     List,
     Tuple,
@@ -19,8 +18,10 @@
 from pandera.api.polars.container import DataFrameSchema
 from pandera.api.polars.components import Column
 from pandera.api.polars.model_config import BaseConfig
+from pandera.engines import polars_engine as pe
 from pandera.errors import SchemaInitError
 from pandera.typing import AnnotationInfo
+from pandera.typing.polars import Series


 class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
@@ -52,24 +53,30 @@ def _build_columns(  # pylint:disable=too-many-locals
             field_name = field.name
             check_name = getattr(field, "check_name", None)

-            if annotation.metadata:
-                if field.dtype_kwargs:
-                    raise TypeError(
-                        "Cannot specify redundant 'dtype_kwargs' "
-                        + f"for {annotation.raw_annotation}."
-                        + "\n Usage Tip: Drop 'typing.Annotated'."
-                    )
-                dtype_kwargs = get_dtype_kwargs(annotation)
-                dtype = annotation.arg(**dtype_kwargs)  # type: ignore
-            elif annotation.default_dtype:
-                dtype = annotation.default_dtype
-            else:
-                dtype = annotation.arg
-
-            dtype = None if dtype is Any else dtype
-
-            if annotation.origin is None or isinstance(
-                annotation.origin, pl.datatypes.DataTypeClass
+            engine_dtype = None
+            try:
+                engine_dtype = pe.Engine.dtype(annotation.raw_annotation)
+                dtype = engine_dtype.type
+            except TypeError as exc:
+                if annotation.metadata:
+                    if field.dtype_kwargs:
+                        raise TypeError(
+                            "Cannot specify redundant 'dtype_kwargs' "
+                            + f"for {annotation.raw_annotation}."
+                            + "\n Usage Tip: Drop 'typing.Annotated'."
+                        ) from exc
+                    dtype_kwargs = get_dtype_kwargs(annotation)
+                    dtype = annotation.arg(**dtype_kwargs)  # type: ignore
+                elif annotation.default_dtype:
+                    dtype = annotation.default_dtype
+                else:
+                    dtype = annotation.arg
+
+            if (
+                annotation.origin is None
+                or isinstance(annotation.origin, pl.datatypes.DataTypeClass)
+                or annotation.origin is Series
+                or engine_dtype
             ):
                 if check_name is False:
                     raise SchemaInitError(
@@ -89,19 +96,9 @@ def _build_columns(  # pylint:disable=too-many-locals
                 columns[field_name] = Column(**column_kwargs)

             else:
-                origin_name = (
-                    f"{annotation.origin.__module__}."
-                    f"{annotation.origin.__name__}"
-                )
-                msg = (
-                    " Series[TYPE] annotations are not supported for polars. "
-                    "Use the bare TYPE directly"
-                    if origin_name == "pandera.typing.pandas.Series"
-                    else ""
-                )
                 raise SchemaInitError(
                     f"Invalid annotation '{field_name}: "
-                    f"{annotation.raw_annotation}'.{msg}"
+                    f"{annotation.raw_annotation}'."
                 )

         return columns
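
Note on this hunk: `_build_columns` now tries to resolve the raw annotation through the polars engine first, falling back to the metadata/default-dtype logic only on a TypeError, which also makes `pandera.typing.polars.Series` annotations valid column types. A minimal sketch of the resulting usage (the model and column names are illustrative, not taken from the diff):

import polars as pl
import pandera.polars as pa
from pandera.typing.polars import Series

class Model(pa.DataFrameModel):
    a: pl.Int64         # bare polars dtype, as before
    b: Series[pl.Utf8]  # Series[TYPE] annotations now also resolve

lf = pl.LazyFrame({"a": [1, 2], "b": ["x", "y"]})
print(Model.validate(lf).collect())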

pandera/api/pyspark/container.py

+30 -1

@@ -8,7 +8,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union, cast, overload

-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StructType, StructField

 from pandera import errors
 from pandera.api.base.schema import BaseSchema
@@ -563,6 +564,34 @@ def to_json(

         return pandera.io.to_json(self, target, **kwargs)

+    def to_structtype(self) -> StructType:
+        """Recover fields of DataFrameSchema as a Pyspark StructType object.
+
+        Because the output of this method is used to specify a read schema in
+        Pyspark (avoiding automatic schema inference), any `nullable=False`
+        properties are simply ignored here; that check is executed by the
+        Pandera validations after a dataset is read.
+
+        :returns: StructType object with current schema fields.
+        """
+        fields = [
+            StructField(column, self.columns[column].dtype.type, True)
+            for column in self.columns
+        ]
+        return StructType(fields)
+
+    def to_ddl(self) -> str:
+        """Recover fields of DataFrameSchema as a Pyspark DDL string.
+
+        :returns: String with current schema fields, in compact DDL format.
+        """
+        # `StructType.toDDL()` is only available in internal Java classes
+        spark = SparkSession.builder.getOrCreate()
+        # Create a base dataframe from which we access the underlying Java classes
+        empty_df_with_schema = spark.createDataFrame([], self.to_structtype())
+
+        return empty_df_with_schema._jdf.schema().toDDL()
+

 def _validate_columns(
     column_dict: dict[Any, "pandera.api.pyspark.components.Column"],  # type: ignore [name-defined]
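
For illustration, a short usage sketch of the two new methods (the schema below is hypothetical, and the DDL string in the comment shows the expected compact format rather than verified output):

import pandera.pyspark as pa
import pyspark.sql.types as T

schema = pa.DataFrameSchema(
    {
        "name": pa.Column(T.StringType()),
        "age": pa.Column(T.IntegerType()),
    }
)

read_schema = schema.to_structtype()  # StructType, usable as spark.read.schema(...)
ddl = schema.to_ddl()                 # e.g. "name STRING,age INT"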

pandera/api/pyspark/model.py

+18

@@ -1,4 +1,5 @@
 """Class-based api for pyspark models."""
+
 # pylint:disable=abstract-method
 import copy
 import inspect
@@ -22,6 +23,7 @@
 )

 import pyspark.sql as ps
+from pyspark.sql.types import StructType

 from pandera.api.base.model import BaseModel
 from pandera.api.checks import Check
@@ -271,6 +273,22 @@ def to_yaml(cls, stream: Optional[os.PathLike] = None):
         """
         return cls.to_schema().to_yaml(stream)

+    @classmethod
+    def to_structtype(cls) -> StructType:
+        """Recover fields of DataFrameModel as a Pyspark StructType object.
+
+        :returns: StructType object with current model fields.
+        """
+        return cls.to_schema().to_structtype()
+
+    @classmethod
+    def to_ddl(cls) -> str:
+        """Recover fields of DataFrameModel as a Pyspark DDL string.
+
+        :returns: String with current model fields, in compact DDL format.
+        """
+        return cls.to_schema().to_ddl()
+
     @classmethod
     @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__)
     def validate(
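
Since the model-level methods simply delegate to DataFrameSchema, they can be called straight off a DataFrameModel subclass; a hypothetical example (the model fields and file path are made up):

import pandera.pyspark as pa
import pyspark.sql.types as T

class Purchase(pa.DataFrameModel):
    item: T.StringType()
    price: T.IntegerType()

# read with an explicit schema instead of relying on inference,
# then let pandera run nullability and other checks afterwards:
# df = spark.read.schema(Purchase.to_structtype()).json("purchases.json")
# validated = Purchase.validate(df)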

pandera/engines/polars_engine.py

+40 -3

@@ -5,7 +5,16 @@
 import decimal
 import inspect
 import warnings
-from typing import Any, Union, Optional, Iterable, Literal, Sequence, Tuple
+from typing import (
+    Any,
+    Union,
+    Optional,
+    Iterable,
+    Literal,
+    Sequence,
+    Tuple,
+    Type,
+)


 import polars as pl
@@ -416,16 +425,26 @@ class Date(DataType, dtypes.Date):
 class DateTime(DataType, dtypes.DateTime):
     """Polars datetime data type."""

-    type = pl.Datetime
+    type: Type[pl.Datetime] = pl.Datetime
+    time_zone_agnostic: bool = False

     def __init__(  # pylint:disable=super-init-not-called
         self,
         time_zone: Optional[str] = None,
         time_unit: Optional[str] = None,
+        time_zone_agnostic: bool = False,
     ) -> None:
+
+        _kwargs = {}
+        if time_unit is not None:
+            # avoid deprecation warning when initializing pl.Datetime:
+            # passing time_unit=None is deprecated.
+            _kwargs["time_unit"] = time_unit
+
         object.__setattr__(
-            self, "type", pl.Datetime(time_zone=time_zone, time_unit=time_unit)
+            self, "type", pl.Datetime(time_zone=time_zone, **_kwargs)
         )
+        object.__setattr__(self, "time_zone_agnostic", time_zone_agnostic)

     @classmethod
     def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
@@ -435,6 +454,24 @@ def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
             time_zone=polars_dtype.time_zone, time_unit=polars_dtype.time_unit
         )

+    def check(
+        self,
+        pandera_dtype: dtypes.DataType,
+        data_container: Optional[PolarsDataContainer] = None,
+    ) -> Union[bool, Iterable[bool]]:
+        try:
+            pandera_dtype = Engine.dtype(pandera_dtype)
+        except TypeError:
+            return False
+
+        if self.time_zone_agnostic:
+            return (
+                isinstance(pandera_dtype.type, pl.Datetime)
+                and pandera_dtype.type.time_unit == self.type.time_unit
+            )
+
+        return self.type == pandera_dtype.type and super().check(pandera_dtype)
+

 @Engine.register_dtype(
     equivalents=[
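
To make the new flag concrete: with `time_zone_agnostic=True`, the check accepts any (or no) time zone as long as the time unit matches. A sketch of the intended behavior, constructing the engine dtype directly (inferred from the diff above, not from documentation):

import polars as pl
from pandera.engines import polars_engine as pe

dt = pe.DateTime(time_unit="us", time_zone_agnostic=True)

print(dt.check(pe.Engine.dtype(pl.Datetime("us", "UTC"))))  # True: tz ignored
print(dt.check(pe.Engine.dtype(pl.Datetime("us"))))         # True: naive ok
print(dt.check(pe.Engine.dtype(pl.Datetime("ns", "UTC"))))  # False: unit differs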

pandera/external_config.py

+2

@@ -21,6 +21,8 @@
     os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"

     import pyspark.pandas
+except (ImportError, ModuleNotFoundError):
+    pass
 finally:
     if is_spark_local_ip_dirty:
         os.environ.pop("SPARK_LOCAL_IP")

pandera/io/pandas_io.py

+1 -1

@@ -740,7 +740,7 @@ def from_frictionless_schema(
     schema: Union[str, Path, Dict, FrictionlessSchema]
 ) -> DataFrameSchema:
     # pylint: disable=line-too-long,anomalous-backslash-in-string
-    """Create a :class:`~pandera.api.pandas.container.DataFrameSchema` from either a
+    r"""Create a :class:`~pandera.api.pandas.container.DataFrameSchema` from either a
     frictionless json/yaml schema file saved on disk, or from a frictionless
     schema already loaded into memory.
