Skip to content

Commit af842b1

Browse files
authored
feat: allow pandas.cut 'labels' parameter to accept a list of string (#1549)
1 parent 9f10541 commit af842b1

File tree

6 files changed

+169
-37
lines changed

6 files changed

+169
-37
lines changed

bigframes/core/compile/aggregate_compiler.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,13 @@ def _(
366366
for this_bin in range(op.bins):
367367
if op.labels is False:
368368
value = compile_ibis_types.literal_to_ibis_scalar(
369-
this_bin, force_dtype=pd.Int64Dtype()
369+
this_bin,
370+
force_dtype=pd.Int64Dtype(),
371+
)
372+
elif isinstance(op.labels, typing.Iterable):
373+
value = compile_ibis_types.literal_to_ibis_scalar(
374+
list(op.labels)[this_bin],
375+
force_dtype=pd.StringDtype(storage="pyarrow"),
370376
)
371377
else:
372378
left_adj = adj if this_bin == 0 and op.right else 0
@@ -402,7 +408,13 @@ def _(
402408

403409
if op.labels is False:
404410
value = compile_ibis_types.literal_to_ibis_scalar(
405-
this_bin, force_dtype=pd.Int64Dtype()
411+
this_bin,
412+
force_dtype=pd.Int64Dtype(),
413+
)
414+
elif isinstance(op.labels, typing.Iterable):
415+
value = compile_ibis_types.literal_to_ibis_scalar(
416+
list(op.labels)[this_bin],
417+
force_dtype=pd.StringDtype(storage="pyarrow"),
406418
)
407419
else:
408420
if op.right:

bigframes/core/reshape/tile.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import bigframes_vendored.pandas.core.reshape.tile as vendored_pandas_tile
2121
import pandas as pd
2222

23+
import bigframes.constants
2324
import bigframes.core.expression as ex
2425
import bigframes.core.ordering as order
2526
import bigframes.core.utils as utils
@@ -41,17 +42,37 @@ def cut(
4142
right: typing.Optional[bool] = True,
4243
labels: typing.Union[typing.Iterable[str], bool, None] = None,
4344
) -> bigframes.series.Series:
44-
if labels is not None and labels is not False:
45+
if (
46+
labels is not None
47+
and labels is not False
48+
and not isinstance(labels, typing.Iterable)
49+
):
50+
raise ValueError(
51+
"Bin labels must either be False, None or passed in as a list-like argument"
52+
)
53+
if (
54+
isinstance(labels, typing.Iterable)
55+
and len(list(labels)) > 0
56+
and not isinstance(list(labels)[0], str)
57+
):
4558
raise NotImplementedError(
46-
"The 'labels' parameter must be either False or None. "
47-
"Please provide a valid value for 'labels'."
59+
"When using an iterable for labels, only iterables of strings are supported "
60+
f"but found {type(list(labels)[0])}. {constants.FEEDBACK_LINK}"
4861
)
62+
4963
if x.size == 0:
5064
raise ValueError("Cannot cut empty array.")
5165

5266
if isinstance(bins, int):
5367
if bins <= 0:
5468
raise ValueError("`bins` should be a positive integer.")
69+
if isinstance(labels, typing.Iterable):
70+
labels = tuple(labels)
71+
if len(labels) != bins:
72+
raise ValueError(
73+
f"Bin labels({len(labels)}) must be same as the value of bins({bins})"
74+
)
75+
5576
op = agg_ops.CutOp(bins, right=right, labels=labels)
5677
return x._apply_window_op(op, window_spec=window_specs.unbound())
5778
elif isinstance(bins, typing.Iterable):
@@ -64,9 +85,6 @@ def cut(
6485
elif len(list(bins)) == 0:
6586
as_index = pd.IntervalIndex.from_tuples(list(bins))
6687
bins = tuple()
67-
# To maintain consistency with pandas' behavior
68-
right = True
69-
labels = None
7088
elif isinstance(list(bins)[0], tuple):
7189
as_index = pd.IntervalIndex.from_tuples(list(bins))
7290
bins = tuple(bins)
@@ -88,8 +106,17 @@ def cut(
88106
raise ValueError("`bins` iterable should contain tuples or numerics.")
89107

90108
if as_index.is_overlapping:
91-
raise ValueError("Overlapping IntervalIndex is not accepted.")
92-
elif len(as_index) == 0:
109+
raise ValueError("Overlapping IntervalIndex is not accepted.") # TODO: test
110+
111+
if isinstance(labels, typing.Iterable):
112+
labels = tuple(labels)
113+
if len(labels) != len(as_index):
114+
raise ValueError(
115+
f"Bin labels({len(labels)}) must be same as the number of bin edges"
116+
f"({len(as_index)})"
117+
)
118+
119+
if len(as_index) == 0:
93120
dtype = agg_ops.CutOp(bins, right=right, labels=labels).output_type()
94121
return bigframes.series.Series(
95122
[pd.NA] * len(x),

bigframes/operations/aggregations.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ class CutOp(UnaryWindowOp):
340340
# TODO: Unintuitive, refactor into multiple ops?
341341
bins: typing.Union[int, Iterable]
342342
right: Optional[bool]
343-
labels: Optional[bool]
343+
labels: typing.Union[bool, Iterable[str], None]
344344

345345
@property
346346
def skips_nulls(self):
@@ -349,6 +349,8 @@ def skips_nulls(self):
349349
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
350350
if self.labels is False:
351351
return dtypes.INT_DTYPE
352+
elif isinstance(self.labels, Iterable):
353+
return dtypes.STRING_DTYPE
352354
else:
353355
# Assumption: buckets use same numeric type
354356
if isinstance(self.bins, int):

tests/system/small/test_pandas.py

+52-13
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,14 @@ def _convert_pandas_category(pd_s: pd.Series):
404404
f"Input must be a pandas Series with categorical data: {pd_s.dtype}"
405405
)
406406

407+
if pd.api.types.is_object_dtype(pd_s.cat.categories.dtype):
408+
return pd_s.astype(pd.StringDtype(storage="pyarrow"))
409+
410+
if not isinstance(pd_s.cat.categories.dtype, pd.IntervalDtype):
411+
raise ValueError(
412+
f"Must be a IntervalDtype with categorical data: {pd_s.cat.categories.dtype}"
413+
)
414+
407415
if pd_s.cat.categories.dtype.closed == "left": # type: ignore
408416
left_key = "left_inclusive"
409417
right_key = "right_exclusive"
@@ -465,6 +473,17 @@ def test_cut_by_int_bins(scalars_dfs, labels, right):
465473
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
466474

467475

476+
def test_cut_by_int_bins_w_labels(scalars_dfs):
477+
scalars_df, scalars_pandas_df = scalars_dfs
478+
479+
labels = ["A", "B", "C", "D", "E"]
480+
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, labels=labels)
481+
bf_result = bpd.cut(scalars_df["float64_col"], 5, labels=labels)
482+
483+
pd_result = _convert_pandas_category(pd_result)
484+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
485+
486+
468487
@pytest.mark.parametrize(
469488
("breaks", "right", "labels"),
470489
[
@@ -494,7 +513,7 @@ def test_cut_by_int_bins(scalars_dfs, labels, right):
494513
),
495514
],
496515
)
497-
def test_cut_numeric_breaks(scalars_dfs, breaks, right, labels):
516+
def test_cut_by_numeric_breaks(scalars_dfs, breaks, right, labels):
498517
scalars_df, scalars_pandas_df = scalars_dfs
499518

500519
pd_result = pd.cut(
@@ -508,6 +527,18 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right, labels):
508527
pd.testing.assert_series_equal(bf_result, pd_result_converted)
509528

510529

530+
def test_cut_by_numeric_breaks_w_labels(scalars_dfs):
531+
scalars_df, scalars_pandas_df = scalars_dfs
532+
533+
bins = [0, 5, 10, 15, 20]
534+
labels = ["A", "B", "C", "D"]
535+
pd_result = pd.cut(scalars_pandas_df["float64_col"], bins, labels=labels)
536+
bf_result = bpd.cut(scalars_df["float64_col"], bins, labels=labels)
537+
538+
pd_result = _convert_pandas_category(pd_result)
539+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
540+
541+
511542
@pytest.mark.parametrize(
512543
("bins", "right", "labels"),
513544
[
@@ -534,7 +565,7 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right, labels):
534565
),
535566
],
536567
)
537-
def test_cut_with_interval(scalars_dfs, bins, right, labels):
568+
def test_cut_by_interval_bins(scalars_dfs, bins, right, labels):
538569
scalars_df, scalars_pandas_df = scalars_dfs
539570
bf_result = bpd.cut(
540571
scalars_df["int64_too"], bins, labels=labels, right=right
@@ -548,22 +579,30 @@ def test_cut_with_interval(scalars_dfs, bins, right, labels):
548579
pd.testing.assert_series_equal(bf_result, pd_result_converted)
549580

550581

582+
def test_cut_by_interval_bins_w_labels(scalars_dfs):
583+
scalars_df, scalars_pandas_df = scalars_dfs
584+
585+
bins = pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)])
586+
labels = ["A", "B", "C", "D", "E"]
587+
pd_result = pd.cut(scalars_pandas_df["float64_col"], bins, labels=labels)
588+
bf_result = bpd.cut(scalars_df["float64_col"], bins, labels=labels)
589+
590+
pd_result = _convert_pandas_category(pd_result)
591+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
592+
593+
551594
@pytest.mark.parametrize(
552-
"bins",
595+
("bins", "labels"),
553596
[
554-
pytest.param([], id="empty_breaks"),
555-
pytest.param(
556-
[1], id="single_int_breaks", marks=pytest.mark.skip(reason="b/404338651")
557-
),
558-
pytest.param(pd.IntervalIndex.from_tuples([]), id="empty_interval_index"),
597+
pytest.param([], None, id="empty_breaks"),
598+
pytest.param([1], False, id="single_int_breaks"),
599+
pytest.param(pd.IntervalIndex.from_tuples([]), None, id="empty_interval_index"),
559600
],
560601
)
561-
def test_cut_by_edge_cases_bins(scalars_dfs, bins):
602+
def test_cut_by_edge_cases_bins(scalars_dfs, bins, labels):
562603
scalars_df, scalars_pandas_df = scalars_dfs
563-
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False).to_pandas()
564-
if isinstance(bins, list):
565-
bins = pd.IntervalIndex.from_tuples(bins)
566-
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False)
604+
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=labels).to_pandas()
605+
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=labels)
567606

568607
pd_result_converted = _convert_pandas_category(pd_result)
569608
pd.testing.assert_series_equal(bf_result, pd_result_converted)

tests/unit/test_pandas.py

+51-6
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,48 @@ def test_method_matches_session(method_name: str):
9191
assert pandas_signature.return_annotation == session_signature.return_annotation
9292

9393

94-
def test_cut_raises_with_labels():
94+
@pytest.mark.parametrize(
95+
("bins", "labels", "error_message"),
96+
[
97+
pytest.param(
98+
5,
99+
True,
100+
"Bin labels must either be False, None or passed in as a list-like argument",
101+
id="true",
102+
),
103+
pytest.param(
104+
5,
105+
1.5,
106+
"Bin labels must either be False, None or passed in as a list-like argument",
107+
id="invalid_types",
108+
),
109+
pytest.param(
110+
2,
111+
["A"],
112+
"must be same as the value of bins",
113+
id="int_bins_mismatch",
114+
),
115+
pytest.param(
116+
[1, 2, 3],
117+
["A"],
118+
"must be same as the number of bin edges",
119+
id="iterator_bins_mismatch",
120+
),
121+
],
122+
)
123+
def test_cut_raises_with_invalid_labels(bins: int, labels, error_message: str):
124+
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
125+
with pytest.raises(ValueError, match=error_message):
126+
bigframes.pandas.cut(mock_series, bins, labels=labels)
127+
128+
129+
def test_cut_raises_with_unsupported_labels():
130+
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
131+
labels = [1, 2]
95132
with pytest.raises(
96-
NotImplementedError,
97-
match="The 'labels' parameter must be either False or None.",
133+
NotImplementedError, match=r".*only iterables of strings are supported.*"
98134
):
99-
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
100-
bigframes.pandas.cut(mock_series, 4, labels=["a", "b", "c", "d"])
135+
bigframes.pandas.cut(mock_series, 2, labels=labels) # type: ignore
101136

102137

103138
@pytest.mark.parametrize(
@@ -111,11 +146,21 @@ def test_cut_raises_with_labels():
111146
"`bins` iterable should contain tuples or numerics",
112147
id="iterable_w_wrong_type",
113148
),
149+
pytest.param(
150+
[10, 3],
151+
"left side of interval must be <= right side",
152+
id="decreased_breaks",
153+
),
154+
pytest.param(
155+
[(1, 10), (2, 25)],
156+
"Overlapping IntervalIndex is not accepted.",
157+
id="overlapping_intervals",
158+
),
114159
],
115160
)
116161
def test_cut_raises_with_invalid_bins(bins: int, error_message: str):
162+
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
117163
with pytest.raises(ValueError, match=error_message):
118-
mock_series = mock.create_autospec(bigframes.pandas.Series, instance=True)
119164
bigframes.pandas.cut(mock_series, bins, labels=False)
120165

121166

third_party/bigframes_vendored/pandas/core/reshape/tile.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ def cut(
3131
age ranges. Supports binning into an equal number of bins, or a
3232
pre-specified array of bins.
3333
34-
``labels=False`` implies you just want the bins back.
35-
3634
**Examples:**
3735
3836
>>> import bigframes.pandas as bpd
@@ -55,7 +53,16 @@ def cut(
5553
3 {'left_exclusive': 7.5, 'right_inclusive': 10.0}
5654
dtype: struct<left_exclusive: double, right_inclusive: double>[pyarrow]
5755
58-
Cut with an integer (equal-width bins) and labels=False:
56+
Cut with the same bins, but assign them specific labels:
57+
58+
>>> bpd.cut(s, bins=3, labels=["bad", "medium", "good"])
59+
0 bad
60+
1 bad
61+
2 medium
62+
3 good
63+
dtype: string
64+
65+
`labels=False` implies you want the bins back.
5966
6067
>>> bpd.cut(s, bins=4, labels=False)
6168
0 0
@@ -67,7 +74,6 @@ def cut(
6774
Cut with pd.IntervalIndex, requires importing pandas for IntervalIndex:
6875
6976
>>> import pandas as pd
70-
7177
>>> interval_index = pd.IntervalIndex.from_tuples([(0, 1), (1, 5), (5, 20)])
7278
>>> bpd.cut(s, bins=interval_index)
7379
0 <NA>
@@ -107,7 +113,7 @@ def cut(
107113
dtype: struct<left_inclusive: int64, right_exclusive: int64>[pyarrow]
108114
109115
Args:
110-
x (Series):
116+
x (bigframes.pandas.Series):
111117
The input Series to be binned. Must be 1-dimensional.
112118
bins (int, pd.IntervalIndex, Iterable):
113119
The criteria to bin by.
@@ -127,10 +133,11 @@ def cut(
127133
``right == True`` (the default), then the `bins` ``[1, 2, 3, 4]``
128134
indicate (1,2], (2,3], (3,4]. This argument is ignored when
129135
`bins` is an IntervalIndex.
130-
labels (default None):
136+
labels (bool, Iterable, default None):
131137
Specifies the labels for the returned bins. Must be the same length as
132138
the resulting bins. If False, returns only integer indicators of the
133-
bins. This affects the type of the output container.
139+
bins. This affects the type of the output container. This argument is
140+
ignored when `bins` is an IntervalIndex. If True, raises an error.
134141
135142
Returns:
136143
bigframes.pandas.Series:

0 commit comments

Comments
 (0)