1
1
import pandas as pd
2
2
import numpy as np
3
3
import matplotlib .pyplot as plt
4
+ import matplotlib .colors as mcolors
4
5
from abc import ABC , abstractmethod
5
6
from typing import List , Dict
7
+ import math
6
8
import warnings
7
9
8
10
@@ -52,7 +54,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
52
54
xs , ys = [], []
53
55
for l , r , w in zip (left_center_pts , right_center_pts , widths ):
54
56
x , y = np .zeros (4 ), np .zeros (4 )
55
- alpha = np .arctan (abs (l [1 ] - r [1 ]) / abs (l [0 ] - r [0 ]))
57
+ alpha = np .arctan (abs (l [1 ] - r [1 ]) / abs (l [0 ] - r [0 ])) if l [ 0 ] != r [ 0 ] else np . arctan ( np . inf )
56
58
vertical_w = w / np .cos (alpha )
57
59
58
60
x [0 :2 ], x [2 :4 ] = l [0 ], r [0 ]
@@ -73,7 +75,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
73
75
xs , ys = [], []
74
76
for l , r , w in zip (left_center_pts , right_center_pts , widths ):
75
77
x , y = np .zeros (4 ), np .zeros (4 )
76
- alpha = np .arctan (abs (l [1 ] - r [1 ]) / abs (l [0 ] - r [0 ]))
78
+ alpha = np .arctan (abs (l [1 ] - r [1 ]) / abs (l [0 ] - r [0 ])) if l [ 0 ] != r [ 0 ] else np . arctan ( np . inf )
77
79
vertical_w = w / np .cos (alpha )
78
80
79
81
ax , ay = l [0 ], l [1 ] + vertical_w / 2
@@ -91,7 +93,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
91
93
new_cy = cy + l [1 ] - ey
92
94
93
95
x [0 ], x [1 ] = new_ax , new_cx
94
- y [0 ], y [1 ] = (new_ay , new_cy ) if l [1 ] < r [1 ] else (new_cy , new_ay )
96
+ y [0 ], y [1 ] = (new_ay , new_cy ) if l [1 ] <= r [1 ] else (new_cy , new_ay )
95
97
96
98
x [2 ], x [3 ] = x [0 ] + r [0 ] - l [0 ], x [1 ] + r [0 ] - l [0 ]
97
99
y [2 ], y [3 ] = y [0 ] + r [1 ] - l [1 ], y [1 ] + r [1 ] - l [1 ]
@@ -129,6 +131,7 @@ def plot(self,
129
131
# Highlighting
130
132
hi_var : str = None ,
131
133
hi_value : List [str ] = None ,
134
+ hi_box : str = "vertical" ,
132
135
color : List [str ] = None ,
133
136
default_color = "lightskyblue" ,
134
137
# Manipulating Spacing and Layout
@@ -147,7 +150,7 @@ def plot(self,
147
150
148
151
var_lst = var
149
152
# parse the color input to enable intensity feature
150
- color_lst = color_lst = self ._parse_colors (color )
153
+ color_lst = self ._parse_colors (color )
151
154
self .data_df = self .data_df_origin .copy ()
152
155
for col in self .data_df :
153
156
if self .data_df [col ].dtype .name == "category" :
@@ -337,6 +340,9 @@ def plot(self,
337
340
label_rectangle_default_color = default_color
338
341
label_rectangle_widths = []
339
342
label_rectangle_total_obvs = {}
343
+
344
+ label_rectangle_vertical = True if hi_box == "vertical" else False
345
+
340
346
if label_rectangle :
341
347
label_rectangle_painter = Rectangle ()
342
348
label_rectangle_left_center_pts , label_rectangle_right_center_pts = [],[]
@@ -384,6 +390,7 @@ def plot(self,
384
390
label_rectangle_width_color_total = [0 ] * len (coordinates_dict )
385
391
xs , ys = figure_type .get_coordinates (left_center_pts , right_center_pts , widths )
386
392
if not space_univar :
393
+
387
394
for color in self .color_lst :
388
395
widths_color , ratio_color_centers = [], []
389
396
index = 0
@@ -409,22 +416,26 @@ def plot(self,
409
416
410
417
# always remember that color list was reversed, so the first color is the default color
411
418
if label_rectangle :
419
+ # if not label_rectangle_vertical:
412
420
label_rectangle_total_obvs_color = label_rectangle_total_obvs .copy ()
413
421
for i ,color in enumerate (reversed (self .color_lst )):
414
422
the_hi_value = self .hi_value [i ] if i != len (self .color_lst )- 1 else None
415
423
label_rectangle_widths_color , label_rectangle_ratio_color_centers = [], []
416
424
idx = 0
417
425
for k ,v in coordinates_dict .items ():
418
426
col_name = k [0 ].split (self .same_var_placeholder )[0 ]
427
+ # case1: hi_value and hi_var
419
428
if k [1 ] == the_hi_value and self .hi_var == col_name :
420
429
label_rectangle_width_temp = label_rectangle_widths [idx ]
421
430
label_rectangle_total_obvs_color [k ] = 0
431
+ # case2: for other colors
422
432
elif the_hi_value :
423
433
num_obv = self .data_df .groupby ([self .hi_var , col_name ]).size ().get ((the_hi_value , k [1 ]), 0 )
424
434
label_rectangle_width_temp = bar * num_obv
425
435
if self .min_bar_width and label_rectangle_width_temp <= self .min_bar_width and label_rectangle_width_temp != 0 :
426
436
label_rectangle_width_temp = self .min_bar_width
427
437
label_rectangle_total_obvs_color [k ] -= num_obv
438
+ # case3: for the last default color
428
439
else :
429
440
label_rectangle_width_temp = bar * label_rectangle_total_obvs_color [k ]
430
441
if self .min_bar_width and label_rectangle_width_temp <= self .min_bar_width and label_rectangle_width_temp != 0 :
@@ -434,11 +445,37 @@ def plot(self,
434
445
label_rectangle_width_color_total [idx ] += label_rectangle_width_temp
435
446
idx += 1
436
447
437
- xs , ys = label_rectangle_painter .get_coordinates (label_rectangle_left_center_pts , label_rectangle_right_center_pts , label_rectangle_widths )
438
- color_left_center_pts , color_right_center_pts = label_rectangle_painter .get_center_highlight (xs , ys ,
439
- label_rectangle_ratio_color_centers )
440
- ax = label_rectangle_painter .plot (ax , color_left_center_pts , color_right_center_pts , label_rectangle_widths_color , color )
441
- # ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
448
+ if not label_rectangle_vertical :
449
+ xs , ys = label_rectangle_painter .get_coordinates (label_rectangle_left_center_pts , label_rectangle_right_center_pts , label_rectangle_widths )
450
+ color_left_center_pts , color_right_center_pts = label_rectangle_painter .get_center_highlight (xs , ys ,
451
+ label_rectangle_ratio_color_centers )
452
+ ax = label_rectangle_painter .plot (ax , color_left_center_pts , color_right_center_pts , label_rectangle_widths_color , color )
453
+ # ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
454
+ else :
455
+ xs , ys = label_rectangle_painter .get_coordinates (label_rectangle_left_center_pts , label_rectangle_right_center_pts , label_rectangle_widths )
456
+
457
+ # switch the order of xs and ys temporarily for calculating the vertical coordinates
458
+ vertical_xs = [np .array ([x [1 ],x [3 ],x [0 ],x [2 ]]) for x in xs ]
459
+ vertical_ys = [np .array ([y [1 ],y [3 ],y [0 ],y [2 ]]) for y in ys ]
460
+
461
+ vertical_color_left_center_pts , vertical_color_right_center_pts = label_rectangle_painter .get_center_highlight (vertical_xs , vertical_ys ,
462
+ label_rectangle_ratio_color_centers )
463
+ # switch back to normal coordinates
464
+ vertical_label_rectangle_widths_color = [space * 0.8 * 2 * (w_color / w ) for w_color ,w in zip (label_rectangle_widths_color , label_rectangle_widths )]
465
+ vertical_color_left_center_pts , vertical_color_right_center_pts , vertical_label_rectangle_widths_color = self ._compute_left_right_centers (
466
+ vertical_color_left_center_pts , vertical_color_right_center_pts , vertical_label_rectangle_widths_color
467
+ )
468
+
469
+
470
+ ax = label_rectangle_painter .plot (ax , vertical_color_left_center_pts , vertical_color_right_center_pts ,vertical_label_rectangle_widths_color , color )
471
+
472
+
473
+ # else:
474
+ # vertical_label_rectangle_left_center_pts = [(pt2[0],pt1[0]) for pt1, pt2 in zip(label_rectangle_left_center_pts,label_rectangle_right_center_pts)]
475
+ # vertical_label_rectangle_right_center_pts = [(pt1[1],pt2[0]) for pt1, pt2 in zip(label_rectangle_left_center_pts,label_rectangle_right_center_pts)]
476
+
477
+ # label_rectangle_total_obvs_color = label_rectangle_total_obvs.copy()
478
+ # for i,color in enumerate(reversed(self.color_lst)):
442
479
443
480
if display_figure :
444
481
ax .get_figure ()
@@ -647,4 +684,64 @@ def _parse_colors(self, color_list):
647
684
648
685
return parsed_colors
649
686
687
+ def _compute_left_right_centers (self , midpoints_top , midpoints_bottom , lengths ):
688
+ """
689
+ Given the top and bottom midpoints of one or more aligned rectangles, along with their width,
690
+ compute the midpoints and lengths of the left and right centers.
691
+
692
+ Parameters:
693
+ --------
694
+ midpoints_top : list of (float, float)
695
+ List of top edge midpoints (Tx, Ty)
696
+ midpoints_bottom : list of (float, float)
697
+ List of bottom edge midpoints (Bx, By)
698
+ length : float
699
+ The width of the rectangles (assumed to be the same for all)
700
+
701
+ Returns:
702
+ --------
703
+ left_midpoints : list of (float, float)
704
+ Midpoints of the left edges
705
+ right_midpoints : list of (float, float)
706
+ Midpoints of the right edges
707
+ edge_length : float
708
+ The height of the rectangles (same for both left and right edges)
709
+ """
710
+ n = len (midpoints_top )
711
+
712
+ left_midpoints = []
713
+ right_midpoints = []
714
+ edge_lengths = []
715
+
716
+ for i in range (n ):
717
+ Tx , Ty = midpoints_top [i ]
718
+ Bx , By = midpoints_bottom [i ]
719
+
720
+ Ty , By = (By , Ty ) if Ty < By else (Ty , By )
721
+
722
+ length = lengths [i ]
723
+
724
+ # Compute the vertical edge length
725
+ edge_length = abs (By - Ty )
726
+
727
+ if length != 0 :
728
+ # Compute left and right edge midpoints
729
+ Lmx = Tx - length / 2
730
+ Lmy = Ty - edge_length / 2
731
+ Rmx = Tx + length / 2
732
+ Rmy = By + edge_length / 2
733
+ else :
734
+ Lmx = Tx - length / 2
735
+ Lmy = Ty - edge_length / 2
736
+ Rmx = Lmx
737
+ Rmy = By + edge_length / 2
738
+ edge_length = 0
739
+
740
+ left_midpoints .append ((Lmx , Lmy ))
741
+ right_midpoints .append ((Rmx , Rmy ))
742
+ edge_lengths .append (edge_length )
743
+
744
+ return left_midpoints , right_midpoints , edge_lengths
745
+
746
+
650
747
0 commit comments