@@ -404,6 +404,14 @@ def _convert_pandas_category(pd_s: pd.Series):
404
404
f"Input must be a pandas Series with categorical data: { pd_s .dtype } "
405
405
)
406
406
407
+ if pd .api .types .is_object_dtype (pd_s .cat .categories .dtype ):
408
+ return pd_s .astype (pd .StringDtype (storage = "pyarrow" ))
409
+
410
+ if not isinstance (pd_s .cat .categories .dtype , pd .IntervalDtype ):
411
+ raise ValueError (
412
+ f"Must be a IntervalDtype with categorical data: { pd_s .cat .categories .dtype } "
413
+ )
414
+
407
415
if pd_s .cat .categories .dtype .closed == "left" : # type: ignore
408
416
left_key = "left_inclusive"
409
417
right_key = "right_exclusive"
@@ -465,6 +473,17 @@ def test_cut_by_int_bins(scalars_dfs, labels, right):
465
473
pd .testing .assert_series_equal (bf_result .to_pandas (), pd_result )
466
474
467
475
476
+ def test_cut_by_int_bins_w_labels (scalars_dfs ):
477
+ scalars_df , scalars_pandas_df = scalars_dfs
478
+
479
+ labels = ["A" , "B" , "C" , "D" , "E" ]
480
+ pd_result = pd .cut (scalars_pandas_df ["float64_col" ], 5 , labels = labels )
481
+ bf_result = bpd .cut (scalars_df ["float64_col" ], 5 , labels = labels )
482
+
483
+ pd_result = _convert_pandas_category (pd_result )
484
+ pd .testing .assert_series_equal (bf_result .to_pandas (), pd_result )
485
+
486
+
468
487
@pytest .mark .parametrize (
469
488
("breaks" , "right" , "labels" ),
470
489
[
@@ -494,7 +513,7 @@ def test_cut_by_int_bins(scalars_dfs, labels, right):
494
513
),
495
514
],
496
515
)
497
- def test_cut_numeric_breaks (scalars_dfs , breaks , right , labels ):
516
+ def test_cut_by_numeric_breaks (scalars_dfs , breaks , right , labels ):
498
517
scalars_df , scalars_pandas_df = scalars_dfs
499
518
500
519
pd_result = pd .cut (
@@ -508,6 +527,18 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right, labels):
508
527
pd .testing .assert_series_equal (bf_result , pd_result_converted )
509
528
510
529
530
+ def test_cut_by_numeric_breaks_w_labels (scalars_dfs ):
531
+ scalars_df , scalars_pandas_df = scalars_dfs
532
+
533
+ bins = [0 , 5 , 10 , 15 , 20 ]
534
+ labels = ["A" , "B" , "C" , "D" ]
535
+ pd_result = pd .cut (scalars_pandas_df ["float64_col" ], bins , labels = labels )
536
+ bf_result = bpd .cut (scalars_df ["float64_col" ], bins , labels = labels )
537
+
538
+ pd_result = _convert_pandas_category (pd_result )
539
+ pd .testing .assert_series_equal (bf_result .to_pandas (), pd_result )
540
+
541
+
511
542
@pytest .mark .parametrize (
512
543
("bins" , "right" , "labels" ),
513
544
[
@@ -534,7 +565,7 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right, labels):
534
565
),
535
566
],
536
567
)
537
- def test_cut_with_interval (scalars_dfs , bins , right , labels ):
568
+ def test_cut_by_interval_bins (scalars_dfs , bins , right , labels ):
538
569
scalars_df , scalars_pandas_df = scalars_dfs
539
570
bf_result = bpd .cut (
540
571
scalars_df ["int64_too" ], bins , labels = labels , right = right
@@ -548,22 +579,30 @@ def test_cut_with_interval(scalars_dfs, bins, right, labels):
548
579
pd .testing .assert_series_equal (bf_result , pd_result_converted )
549
580
550
581
582
+ def test_cut_by_interval_bins_w_labels (scalars_dfs ):
583
+ scalars_df , scalars_pandas_df = scalars_dfs
584
+
585
+ bins = pd .IntervalIndex .from_tuples ([(1 , 2 ), (2 , 3 ), (4 , 5 )])
586
+ labels = ["A" , "B" , "C" , "D" , "E" ]
587
+ pd_result = pd .cut (scalars_pandas_df ["float64_col" ], bins , labels = labels )
588
+ bf_result = bpd .cut (scalars_df ["float64_col" ], bins , labels = labels )
589
+
590
+ pd_result = _convert_pandas_category (pd_result )
591
+ pd .testing .assert_series_equal (bf_result .to_pandas (), pd_result )
592
+
593
+
551
594
@pytest .mark .parametrize (
552
- "bins" ,
595
+ ( "bins" , "labels" ) ,
553
596
[
554
- pytest .param ([], id = "empty_breaks" ),
555
- pytest .param (
556
- [1 ], id = "single_int_breaks" , marks = pytest .mark .skip (reason = "b/404338651" )
557
- ),
558
- pytest .param (pd .IntervalIndex .from_tuples ([]), id = "empty_interval_index" ),
597
+ pytest .param ([], None , id = "empty_breaks" ),
598
+ pytest .param ([1 ], False , id = "single_int_breaks" ),
599
+ pytest .param (pd .IntervalIndex .from_tuples ([]), None , id = "empty_interval_index" ),
559
600
],
560
601
)
561
- def test_cut_by_edge_cases_bins (scalars_dfs , bins ):
602
+ def test_cut_by_edge_cases_bins (scalars_dfs , bins , labels ):
562
603
scalars_df , scalars_pandas_df = scalars_dfs
563
- bf_result = bpd .cut (scalars_df ["int64_too" ], bins , labels = False ).to_pandas ()
564
- if isinstance (bins , list ):
565
- bins = pd .IntervalIndex .from_tuples (bins )
566
- pd_result = pd .cut (scalars_pandas_df ["int64_too" ], bins , labels = False )
604
+ bf_result = bpd .cut (scalars_df ["int64_too" ], bins , labels = labels ).to_pandas ()
605
+ pd_result = pd .cut (scalars_pandas_df ["int64_too" ], bins , labels = labels )
567
606
568
607
pd_result_converted = _convert_pandas_category (pd_result )
569
608
pd .testing .assert_series_equal (bf_result , pd_result_converted )
0 commit comments