Skip to content

Commit b2375de

Browse files
authored
fix: pandas.cut returns labels index for numeric breaks when labels=False (#1548)
* refactor tests to detect the bug and other edge-case bugs * fix: pandas.cut returns labels for numeric breaks when labels=False
1 parent bcac8c6 commit b2375de

File tree

4 files changed

+176
-132
lines changed

4 files changed

+176
-132
lines changed

bigframes/core/compile/aggregate_compiler.py

+37-45
Original file line numberDiff line numberDiff line change
@@ -360,69 +360,61 @@ def _(
360360
if isinstance(op.bins, int):
361361
col_min = _apply_window_if_present(x.min(), window)
362362
col_max = _apply_window_if_present(x.max(), window)
363+
adj = (col_max - col_min) * 0.001
363364
bin_width = (col_max - col_min) / op.bins
364365

365-
if op.labels is False:
366-
for this_bin in range(op.bins - 1):
367-
if op.right:
368-
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
369-
else:
370-
case_expr = x < (col_min + (this_bin + 1) * bin_width)
371-
out = out.when(
372-
case_expr,
373-
compile_ibis_types.literal_to_ibis_scalar(
374-
this_bin, force_dtype=pd.Int64Dtype()
375-
),
366+
for this_bin in range(op.bins):
367+
if op.labels is False:
368+
value = compile_ibis_types.literal_to_ibis_scalar(
369+
this_bin, force_dtype=pd.Int64Dtype()
376370
)
377-
out = out.when(x.notnull(), op.bins - 1)
378-
else:
379-
interval_struct = None
380-
adj = (col_max - col_min) * 0.001
381-
for this_bin in range(op.bins):
382-
left_edge_adj = adj if this_bin == 0 and op.right else 0
383-
right_edge_adj = adj if this_bin == op.bins - 1 and not op.right else 0
371+
else:
372+
left_adj = adj if this_bin == 0 and op.right else 0
373+
right_adj = adj if this_bin == op.bins - 1 and not op.right else 0
384374

385-
left_edge = col_min + this_bin * bin_width - left_edge_adj
386-
right_edge = col_min + (this_bin + 1) * bin_width + right_edge_adj
375+
left = col_min + this_bin * bin_width - left_adj
376+
right = col_min + (this_bin + 1) * bin_width + right_adj
387377

388378
if op.right:
389-
interval_struct = ibis_types.struct(
390-
{
391-
"left_exclusive": left_edge,
392-
"right_inclusive": right_edge,
393-
}
379+
value = ibis_types.struct(
380+
{"left_exclusive": left, "right_inclusive": right}
394381
)
395382
else:
396-
interval_struct = ibis_types.struct(
397-
{
398-
"left_inclusive": left_edge,
399-
"right_exclusive": right_edge,
400-
}
383+
value = ibis_types.struct(
384+
{"left_inclusive": left, "right_exclusive": right}
401385
)
402-
403-
if this_bin < op.bins - 1:
404-
if op.right:
405-
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
406-
else:
407-
case_expr = x < (col_min + (this_bin + 1) * bin_width)
408-
out = out.when(case_expr, interval_struct)
386+
if this_bin == op.bins - 1:
387+
case_expr = x.notnull()
388+
else:
389+
if op.right:
390+
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
409391
else:
410-
out = out.when(x.notnull(), interval_struct)
392+
case_expr = x < (col_min + (this_bin + 1) * bin_width)
393+
out = out.when(case_expr, value)
411394
else: # Interpret as intervals
412-
for interval in op.bins:
395+
for this_bin, interval in enumerate(op.bins):
413396
left = compile_ibis_types.literal_to_ibis_scalar(interval[0])
414397
right = compile_ibis_types.literal_to_ibis_scalar(interval[1])
415398
if op.right:
416399
condition = (x > left) & (x <= right)
417-
interval_struct = ibis_types.struct(
418-
{"left_exclusive": left, "right_inclusive": right}
419-
)
420400
else:
421401
condition = (x >= left) & (x < right)
422-
interval_struct = ibis_types.struct(
423-
{"left_inclusive": left, "right_exclusive": right}
402+
403+
if op.labels is False:
404+
value = compile_ibis_types.literal_to_ibis_scalar(
405+
this_bin, force_dtype=pd.Int64Dtype()
424406
)
425-
out = out.when(condition, interval_struct)
407+
else:
408+
if op.right:
409+
value = ibis_types.struct(
410+
{"left_exclusive": left, "right_inclusive": right}
411+
)
412+
else:
413+
value = ibis_types.struct(
414+
{"left_inclusive": left, "right_exclusive": right}
415+
)
416+
417+
out = out.when(condition, value)
426418
return out.end()
427419

428420

bigframes/core/reshape/tile.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def cut(
4646
"The 'labels' parameter must be either False or None. "
4747
"Please provide a valid value for 'labels'."
4848
)
49+
if x.size == 0:
50+
raise ValueError("Cannot cut empty array.")
4951

5052
if isinstance(bins, int):
5153
if bins <= 0:
@@ -58,14 +60,19 @@ def cut(
5860
bins = tuple((bin.left.item(), bin.right.item()) for bin in bins)
5961
# To maintain consistency with pandas' behavior
6062
right = True
63+
labels = None
6164
elif len(list(bins)) == 0:
6265
as_index = pd.IntervalIndex.from_tuples(list(bins))
6366
bins = tuple()
67+
# To maintain consistency with pandas' behavior
68+
right = True
69+
labels = None
6470
elif isinstance(list(bins)[0], tuple):
6571
as_index = pd.IntervalIndex.from_tuples(list(bins))
6672
bins = tuple(bins)
6773
# To maintain consistency with pandas' behavior
6874
right = True
75+
labels = None
6976
elif pd.api.types.is_number(list(bins)[0]):
7077
bins_list = list(bins)
7178
as_index = pd.IntervalIndex.from_breaks(bins_list)
@@ -83,9 +90,13 @@ def cut(
8390
if as_index.is_overlapping:
8491
raise ValueError("Overlapping IntervalIndex is not accepted.")
8592
elif len(as_index) == 0:
86-
op = agg_ops.CutOp(bins, right=right, labels=labels)
93+
dtype = agg_ops.CutOp(bins, right=right, labels=labels).output_type()
8794
return bigframes.series.Series(
88-
[pd.NA] * len(x), dtype=op.output_type(), name=x.name
95+
[pd.NA] * len(x),
96+
dtype=dtype,
97+
name=x.name,
98+
index=x.index,
99+
session=x._session,
89100
)
90101
else:
91102
op = agg_ops.CutOp(bins, right=right, labels=labels)

bigframes/operations/aggregations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def skips_nulls(self):
347347
return False
348348

349349
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
350-
if isinstance(self.bins, int) and (self.labels is False):
350+
if self.labels is False:
351351
return dtypes.INT_DTYPE
352352
else:
353353
# Assumption: buckets use same numeric type

0 commit comments

Comments
 (0)