Skip to content

Commit 84e78f8

Browse files
committed
refine things
Signed-off-by: Matt Richards <[email protected]>
1 parent e902e59 commit 84e78f8

File tree

3 files changed

+42
-14
lines changed

3 files changed

+42
-14
lines changed

pandera/backends/polars/checks.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,9 @@ def preprocess(self, check_obj: pl.LazyFrame, key: Optional[str]):
4545
# for the index to groupby on. Right now grouping by the index is not allowed.
4646
return check_obj
4747

48-
def apply(self, check_obj: PolarsData | AllColumnsPolarsCheckData):
48+
def apply(
49+
self, check_obj: PolarsData | AllColumnsPolarsCheckData
50+
) -> bool | pl.LazyFrame:
4951
"""Apply the check function to a check object."""
5052
if self.check.element_wise:
5153
selector = pl.col(check_obj.key or "*")
@@ -75,11 +77,13 @@ def apply(self, check_obj: PolarsData | AllColumnsPolarsCheckData):
7577

7678
return out
7779

78-
def postprocess(self, check_obj, check_output):
80+
def postprocess(
81+
self, check_obj: PolarsData | AllColumnsPolarsCheckData, check_output
82+
):
7983
"""Postprocesses the result of applying the check function."""
80-
if isinstance(check_obj, PolarsData) and isinstance(
81-
check_output, pl.LazyFrame
82-
):
84+
if isinstance(
85+
check_obj, PolarsData | AllColumnsPolarsCheckData
86+
) and isinstance(check_output, pl.LazyFrame):
8387
return self.postprocess_lazyframe_output(check_obj, check_output)
8488
elif isinstance(check_output, bool):
8589
return self.postprocess_bool_output(check_obj, check_output)
@@ -89,7 +93,7 @@ def postprocess(self, check_obj, check_output):
8993

9094
def postprocess_lazyframe_output(
9195
self,
92-
check_obj: PolarsData,
96+
check_obj: PolarsData | AllColumnsPolarsCheckData,
9397
check_output: pl.LazyFrame,
9498
) -> CheckResult:
9599
"""Postprocesses the result of applying the check function."""
@@ -114,7 +118,7 @@ def postprocess_lazyframe_output(
114118

115119
def postprocess_bool_output(
116120
self,
117-
check_obj: PolarsData,
121+
check_obj: PolarsData | AllColumnsPolarsCheckData,
118122
check_output: bool,
119123
) -> CheckResult:
120124
"""Postprocesses the result of applying the check function."""

pandera/engines/polars_engine.py

+20-2
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
Type,
1717
Union,
1818
TypedDict,
19+
overload,
1920
)
2021

2122
import polars as pl
@@ -578,6 +579,24 @@ class Array(DataType):
578579

579580
type = pl.Array
580581

582+
@overload
583+
def __init__(
584+
self,
585+
inner: Literal[None] = ...,
586+
shape: Literal[None] = ...,
587+
*,
588+
width: Literal[None] = ...,
589+
) -> None: ...
590+
591+
@overload
592+
def __init__(
593+
self,
594+
inner: PolarsDataType = ...,
595+
shape: Union[int, Tuple[int, ...], None] = ...,
596+
*,
597+
width: Optional[int] = ...,
598+
) -> None: ...
599+
581600
def __init__( # pylint:disable=super-init-not-called
582601
self,
583602
inner: Optional[PolarsDataType] = None,
@@ -593,8 +612,7 @@ def __init__( # pylint:disable=super-init-not-called
593612
kwargs["shape"] = width
594613
elif shape is not None:
595614
kwargs["shape"] = shape
596-
597-
if inner or shape or width:
615+
if inner:
598616
object.__setattr__(self, "type", pl.Array(inner=inner, **kwargs))
599617

600618
@classmethod

tests/polars/test_polars_dtypes.py

+11-5
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from polars.testing.parametric import dataframes
1414

1515
import pandera.errors
16-
from pandera.api.polars.types import PolarsData
16+
from pandera.api.polars.types import AllColumnsPolarsCheckData
1717
from pandera.api.polars.utils import get_lazyframe_column_dtypes
1818
from pandera.constants import CHECK_OUTPUT_KEY
1919
from pandera.engines import polars_engine as pe
@@ -98,7 +98,9 @@ def test_coerce_no_cast(dtype, data):
9898
pl.enable_string_cache()
9999
pandera_dtype = dtype()
100100
df = data.draw(get_dataframe_strategy(type_=pandera_dtype.type))
101-
coerced = pandera_dtype.coerce(data_container=PolarsData(df))
101+
coerced = pandera_dtype.coerce(
102+
data_container=AllColumnsPolarsCheckData(df)
103+
)
102104
assert_frame_equal(df, coerced)
103105

104106

@@ -317,7 +319,9 @@ def test_polars_object_coercible(to_dtype, container, result):
317319
Test that polars_object_coercible can detect that a polars object is
318320
coercible or not.
319321
"""
320-
is_coercible = polars_object_coercible(PolarsData(container), to_dtype)
322+
is_coercible = polars_object_coercible(
323+
AllColumnsPolarsCheckData(container), to_dtype
324+
)
321325
assert_frame_equal(is_coercible, result)
322326

323327

@@ -439,12 +443,14 @@ def test_polars_nested_dtypes_try_coercion(
439443
data,
440444
):
441445
pandera_dtype = pe.Engine.dtype(coercible_dtype)
442-
coerced_data = pandera_dtype.try_coerce(PolarsData(data))
446+
coerced_data = pandera_dtype.try_coerce(AllColumnsPolarsCheckData(data))
443447
assert coerced_data.collect().equals(data.collect())
444448

445449
# coercing data with invalid type should raise an error
446450
try:
447-
pe.Engine.dtype(noncoercible_dtype).try_coerce(PolarsData(data))
451+
pe.Engine.dtype(noncoercible_dtype).try_coerce(
452+
AllColumnsPolarsCheckData(data)
453+
)
448454
except pandera.errors.ParserError as exc:
449455
col = pl.col(exc.failure_cases.columns[0])
450456
assert exc.failure_cases.select(col).equals(data.collect())

0 commit comments

Comments
 (0)