2
2
import numpy as np
3
3
import matplotlib .pyplot as plt
4
4
from abc import ABC , abstractmethod
5
- from typing import List ,Dict
5
+ from typing import List , Dict
6
6
import warnings
7
7
8
8
@@ -121,7 +121,7 @@ def __init__(self,
121
121
122
122
def plot (self ,
123
123
var : List [str ] = None ,
124
- value_order : Dict [str , Dict [int ,str ]] = None ,
124
+ value_order : Dict [str , Dict [int , str ]] = None ,
125
125
missing : bool = False ,
126
126
hi_missing : bool = False ,
127
127
missing_label_space : float = 1. ,
@@ -130,7 +130,7 @@ def plot(self,
130
130
hi_var : str = None ,
131
131
hi_value : List [str ] = None ,
132
132
color : List [str ] = None ,
133
- default_color = "blue " ,
133
+ default_color = "lightskyblue " ,
134
134
# Manipulating Spacing and Layout
135
135
bar_width : float = 1. ,
136
136
min_bar_width : float = .05 ,
@@ -152,7 +152,7 @@ def plot(self,
152
152
if self .data_df [col ].dtype .name == "category" :
153
153
self .data_df [col ] = self .data_df [col ].cat .add_categories (self .missing_data_placeholder )
154
154
elif "float" in self .data_df [col ].dtype .name :
155
- self .data_df [col ] = self .data_df [col ].apply (lambda x : np .round (x ,2 ))
155
+ self .data_df [col ] = self .data_df [col ].apply (lambda x : np .round (x , 2 ))
156
156
self .data_df_columns = self .data_df .columns .tolist ()
157
157
158
158
if not var_lst :
@@ -161,7 +161,6 @@ def plot(self,
161
161
)
162
162
163
163
if color and type (color ) != type ([]):
164
-
165
164
raise ValueError (
166
165
f'Argument "color" must be a list os str.'
167
166
)
@@ -179,9 +178,9 @@ def plot(self,
179
178
)
180
179
181
180
if value_order :
182
- for k ,v_ori in value_order .items ():
181
+ for k , v_ori in value_order .items ():
183
182
uni_val_set = set (self .data_df [k ].dropna ().unique ())
184
- v = [value_name for order ,value_name in v_ori .items ()]
183
+ v = [value_name for order , value_name in v_ori .items ()]
185
184
if not set (v ) >= uni_val_set :
186
185
error_values = (set (v ) ^ uni_val_set ) & set (v )
187
186
raise ValueError (
@@ -219,17 +218,18 @@ def plot(self,
219
218
self .hi_value .append (self .missing_data_placeholder )
220
219
else :
221
220
self .hi_value = [self .missing_data_placeholder ]
222
- colors = ["red" , "green" , "yellow" , "lightblue" , "orange" , "gray" , "brown" , "olive" , "pink" , "cyan" , "magenta" ]
221
+ colors = ["red" , "green" , "yellow" , "purple" , "orange" , "gray" , "brown" , "olive" , "pink" , "cyan" , "magenta" ]
223
222
self .color_lst = [color for color in color_lst ] if color_lst else (
224
223
colors [:len (self .hi_value )] if hi_var else None )
225
224
if hi_var :
226
225
if hi_value and len (self .color_lst ) < len (hi_value ):
227
- for i in range (len (hi_value )- len (self .color_lst )):
226
+ for i in range (len (hi_value ) - len (self .color_lst )):
228
227
for c in colors :
229
228
if c not in self .color_lst :
230
229
self .color_lst .append (c )
231
230
break
232
- warnings .warn (f"Warning: The length of color is less than the total number of (high values and missing), color was automatically extended to { self .color_lst } " )
231
+ warnings .warn (
232
+ f"Warning: The length of color is less than the total number of (high values and missing), color was automatically extended to { self .color_lst } " )
233
233
if hi_var and default_color in self .color_lst :
234
234
raise ValueError (
235
235
f'The current highlight colors { self .color_lst } conflict with the default color { default_color } . Please choose another default color or other highlight colors'
@@ -259,7 +259,7 @@ def plot(self,
259
259
raise ValueError (
260
260
f'the values: { error_values } in highlight value is not in data.'
261
261
)
262
-
262
+
263
263
value_color_dict = dict (zip (self .hi_value , self .color_lst ))
264
264
265
265
self .data_df [self .color_coloumn_placeholder ] = self .data_df [hi_var ].apply (
@@ -303,7 +303,7 @@ def plot(self,
303
303
ax , coordinates_dict = self ._list_labels (ax , self .height , self .width , self .label )
304
304
305
305
space = self .space * 10 if label else 0
306
- bar = self .bar_width * 3.5 / max (data_point_numbers )
306
+ bar = self .bar_width * 3.5 / max (data_point_numbers )
307
307
308
308
if self .shape == "parallelogram" :
309
309
figure_type = Parallelogram ()
@@ -325,10 +325,54 @@ def plot(self,
325
325
left_center_pts .append (left_coordinate )
326
326
right_center_pts .append (right_coordinate )
327
327
328
+ label_rectangle = True if self .label else False
329
+ label_rectangle_default_color = default_color
330
+ label_rectangle_widths = []
331
+ label_rectangle_total_obvs = {}
332
+ if label_rectangle :
333
+ label_rectangle_painter = Rectangle ()
334
+ label_rectangle_left_center_pts , label_rectangle_right_center_pts = [],[]
335
+ for k ,v in coordinates_dict .items ():
336
+
337
+ # get width for label rectangles by counting the number of observations for each value
338
+
339
+ col_name = k [0 ].split (self .same_var_placeholder )[0 ]
340
+ num_obv = self .data_df [col_name ].value_counts ().get (k [1 ], 0 )
341
+ label_rectangle_total_obvs [k ] = num_obv
342
+ label_rectangle_width = bar * num_obv
343
+ if self .min_bar_width and label_rectangle_width <= self .min_bar_width :
344
+ label_rectangle_width = self .min_bar_width
345
+
346
+ # get left and right coordinates for label rectangles
347
+ # add space for very thick label rectangles
348
+ half_label_rectangle_width = label_rectangle_width / 2
349
+ edge_adjust = self .max_y_range * 0.01
350
+ if v [1 ] - half_label_rectangle_width < 0 :
351
+ adjust_value = half_label_rectangle_width - v [1 ] + edge_adjust
352
+ label_rectangle_left_coordinate = (v [0 ]- space * 0.8 , v [1 ]+ adjust_value )
353
+ label_rectangle_right_coordinate = (v [0 ] + space * 0.8 , v [1 ]+ adjust_value )
354
+ elif v [1 ] + half_label_rectangle_width > self .max_y_range :
355
+ adjust_value = half_label_rectangle_width + v [1 ] - self .max_y_range + edge_adjust
356
+ label_rectangle_left_coordinate = (v [0 ]- space * 0.8 , v [1 ]- adjust_value )
357
+ label_rectangle_right_coordinate = (v [0 ] + space * 0.8 , v [1 ]- adjust_value )
358
+ else :
359
+ label_rectangle_left_coordinate = (v [0 ]- space * 0.8 , v [1 ])
360
+ label_rectangle_right_coordinate = (v [0 ] + space * 0.8 , v [1 ])
361
+
362
+ label_rectangle_left_center_pts .append (label_rectangle_left_coordinate )
363
+ label_rectangle_right_center_pts .append (label_rectangle_right_coordinate )
364
+
365
+
366
+
367
+ label_rectangle_widths .append (label_rectangle_width )
368
+
328
369
if not hi_var :
329
370
ax = figure_type .plot (ax , left_center_pts , right_center_pts , widths , default_color )
371
+ if label_rectangle :
372
+ ax = label_rectangle_painter .plot (ax , label_rectangle_left_center_pts , label_rectangle_right_center_pts , label_rectangle_widths ,label_rectangle_default_color )
330
373
else :
331
374
width_color_total = [0 ] * len (widths )
375
+ label_rectangle_width_color_total = [0 ] * len (coordinates_dict )
332
376
xs , ys = figure_type .get_coordinates (left_center_pts , right_center_pts , widths )
333
377
for color in self .color_lst :
334
378
widths_color , ratio_color_centers = [], []
@@ -353,6 +397,39 @@ def plot(self,
353
397
ratio_color_centers )
354
398
ax = figure_type .plot (ax , color_left_center_pts , color_right_center_pts , widths_color , color )
355
399
400
+ # always remember that color list was reversed, so the first color is the default color
401
+ if label_rectangle :
402
+ label_rectangle_total_obvs_color = label_rectangle_total_obvs .copy ()
403
+ for i ,color in enumerate (reversed (self .color_lst )):
404
+ the_hi_value = self .hi_value [i ] if i != len (self .color_lst )- 1 else None
405
+ label_rectangle_widths_color , label_rectangle_ratio_color_centers = [], []
406
+ idx = 0
407
+ for k ,v in coordinates_dict .items ():
408
+ col_name = k [0 ].split (self .same_var_placeholder )[0 ]
409
+ if k [1 ] == the_hi_value and self .hi_var == col_name :
410
+ label_rectangle_width_temp = label_rectangle_widths [idx ]
411
+ label_rectangle_total_obvs_color [k ] = 0
412
+ elif the_hi_value :
413
+ num_obv = self .data_df .groupby ([self .hi_var , col_name ]).size ().get ((the_hi_value , k [1 ]), 0 )
414
+ label_rectangle_width_temp = bar * num_obv
415
+ if self .min_bar_width and label_rectangle_width_temp <= self .min_bar_width and label_rectangle_width_temp != 0 :
416
+ label_rectangle_width_temp = self .min_bar_width
417
+ label_rectangle_total_obvs_color [k ] -= num_obv
418
+ else :
419
+ label_rectangle_width_temp = bar * label_rectangle_total_obvs_color [k ]
420
+ if self .min_bar_width and label_rectangle_width_temp <= self .min_bar_width and label_rectangle_width_temp != 0 :
421
+ label_rectangle_width_temp = self .min_bar_width
422
+ label_rectangle_widths_color .append (label_rectangle_width_temp )
423
+ label_rectangle_ratio_color_centers .append ((label_rectangle_width_color_total [idx ] + label_rectangle_width_temp / 2 ) / label_rectangle_widths [idx ])
424
+ label_rectangle_width_color_total [idx ] += label_rectangle_width_temp
425
+ idx += 1
426
+
427
+ xs , ys = label_rectangle_painter .get_coordinates (label_rectangle_left_center_pts , label_rectangle_right_center_pts , label_rectangle_widths )
428
+ color_left_center_pts , color_right_center_pts = label_rectangle_painter .get_center_highlight (xs , ys ,
429
+ label_rectangle_ratio_color_centers )
430
+ ax = label_rectangle_painter .plot (ax , color_left_center_pts , color_right_center_pts , label_rectangle_widths_color , color )
431
+ # ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
432
+
356
433
if display_figure :
357
434
ax .get_figure ()
358
435
else :
@@ -365,7 +442,7 @@ def plot(self,
365
442
366
443
def _get_varname (self , x ):
367
444
return x .split (self .same_var_placeholder )[:- 1 ][0 ]
368
-
445
+
369
446
def is_float (self , element : any ) -> bool :
370
447
if element is None :
371
448
return False
@@ -396,10 +473,10 @@ def _get_two_var(self, var_lst: List[str]):
396
473
397
474
return var_pair_lst
398
475
399
- def _gen_coordinate (self , start , n , edge , spacing , total_range ,val_type = "str" ):
476
+ def _gen_coordinate (self , start , n , edge , spacing , total_range , val_type = "str" ):
400
477
coor_lst = []
401
-
402
- if val_type == "str" :
478
+
479
+ if val_type == "str" :
403
480
for i in range (n ):
404
481
coor_lst .append (start + i * spacing )
405
482
@@ -411,17 +488,17 @@ def _gen_coordinate(self, start, n, edge, spacing, total_range,val_type="str"):
411
488
coor_lst .append (total_range + (start - edge ) - edge )
412
489
return coor_lst
413
490
414
- def _get_same_scale_minmax (self ,original_unique_value ):
415
- min ,max = 0 ,0
416
- for i ,varname in enumerate (self .same_scale ):
491
+ def _get_same_scale_minmax (self , original_unique_value ):
492
+ min , max = 0 , 0
493
+ for i , varname in enumerate (self .same_scale ):
417
494
var_type = str (self .data_df_origin [varname ].dtype .name )
418
495
if "int" in var_type or "float" in var_type :
419
496
min_val , max_val = original_unique_value [varname ][0 ], original_unique_value [varname ][- 1 ]
420
497
if i == 0 :
421
- min ,max = min_val , max_val
498
+ min , max = min_val , max_val
422
499
else :
423
- min = min_val if min_val < min else min
424
- max = max_val if max_val > max else max
500
+ min = min_val if min_val < min else min
501
+ max = max_val if max_val > max else max
425
502
426
503
else :
427
504
min_val , max_val = 1 , len (original_unique_value [varname ])
@@ -430,40 +507,44 @@ def _get_same_scale_minmax(self,original_unique_value):
430
507
else :
431
508
min = min_val if min_val < min else min
432
509
max = max_val if max_val > max else max
433
- return (min ,max )
510
+ return (min , max )
434
511
435
512
def _list_labels (self , ax , figsize_y , figsize_x , label ):
436
513
437
514
scale = 10
438
515
edge_scale = 10
439
516
y_range = scale * figsize_y - self .missing_label_space * scale if self .missing else scale * figsize_y
440
517
x_range = scale * figsize_x
518
+ self .max_y_range , self .max_x_range = scale * figsize_y , scale * figsize_x
441
519
edge_x_range = x_range / edge_scale
442
520
edge_y_range = y_range / edge_scale
521
+ # self.edge_y_range, self.edge_x_range = edge_y_range, edge_x_range
443
522
y_start = edge_y_range + self .missing_label_space * scale if self .missing else edge_y_range
444
523
coordinates_dict = {}
445
524
446
525
unique_value = []
447
526
original_unique_value = {}
448
527
varname_lst = [self ._get_varname (var ) for var in self .var_lst ]
449
-
528
+
450
529
for var , varname in zip (self .var_lst , varname_lst ):
451
530
unique_valnames = self .data_df [varname ].dropna ().unique ().tolist ()
452
531
sorted_unique_valnames = []
453
532
if self .value_order and varname in self .value_order :
454
533
varname_value_order_dict = self .value_order [varname ]
455
534
sorted_unique_valnames_temp = [v for k , v in
456
- sorted (varname_value_order_dict .items (), key = lambda item : item [0 ])]
535
+ sorted (varname_value_order_dict .items (), key = lambda item : item [0 ])]
457
536
for v in sorted_unique_valnames_temp :
458
537
if v in unique_valnames :
459
538
sorted_unique_valnames .append (v )
460
539
if self .missing_data_placeholder in unique_valnames :
461
540
unique_valnames .remove (self .missing_data_placeholder )
462
- sorted_unique_valnames = sorted (unique_valnames ) if not sorted_unique_valnames else sorted_unique_valnames
541
+ sorted_unique_valnames = sorted (
542
+ unique_valnames ) if not sorted_unique_valnames else sorted_unique_valnames
463
543
original_unique_value [varname ] = sorted_unique_valnames .copy ()
464
544
sorted_unique_valnames .append (self .missing_data_placeholder )
465
545
else :
466
- sorted_unique_valnames = sorted (unique_valnames ) if not sorted_unique_valnames else sorted_unique_valnames
546
+ sorted_unique_valnames = sorted (
547
+ unique_valnames ) if not sorted_unique_valnames else sorted_unique_valnames
467
548
original_unique_value [varname ] = sorted_unique_valnames .copy ()
468
549
unique_value .append ([(var , x ) for x in sorted_unique_valnames ])
469
550
@@ -478,11 +559,11 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
478
559
479
560
# prepare for same_scale variables
480
561
if self .same_scale :
481
- same_scale_min ,same_scale_max = self ._get_same_scale_minmax (original_unique_value )
482
- same_scale_range = same_scale_max - same_scale_min
562
+ same_scale_min , same_scale_max = self ._get_same_scale_minmax (original_unique_value )
563
+ same_scale_range = same_scale_max - same_scale_min
483
564
484
565
# plot labels for each variable
485
- for var_i ,(x , uni_val ) in enumerate (zip (label_coordinates , unique_value )):
566
+ for var_i , (x , uni_val ) in enumerate (zip (label_coordinates , unique_value )):
486
567
label_num = len (uni_val ) - 2 if (uni_val [0 ][0 ], self .missing_data_placeholder ) in uni_val else len (
487
568
uni_val ) - 1
488
569
varname = varname_lst [var_i ]
@@ -493,21 +574,22 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
493
574
temp_value_range = (y_range - 2 * edge_y_range )
494
575
# handle the variables in same_scale
495
576
if self .same_scale and varname in self .same_scale :
496
- min_val , max_val = same_scale_min ,same_scale_max
577
+ min_val , max_val = same_scale_min , same_scale_max
497
578
else :
498
- min_val ,max_val = original_unique_value [varname ][0 ],original_unique_value [varname ][- 1 ]
499
- value_interval = [temp_value_range * (x_val - min_val )/ (max_val - min_val ) for x_val in original_unique_value [varname ]]
579
+ min_val , max_val = original_unique_value [varname ][0 ], original_unique_value [varname ][- 1 ]
580
+ value_interval = [temp_value_range * (x_val - min_val ) / (max_val - min_val ) for x_val in
581
+ original_unique_value [varname ]]
500
582
uni_val_coordinates = self ._gen_coordinate (y_start , label_num , edge_y_range ,
501
- value_interval , y_range ,val_type = "number" )
583
+ value_interval , y_range , val_type = "number" )
502
584
else :
503
585
# handle the variables in same_scale
504
586
if self .same_scale and varname in self .same_scale :
505
587
temp_value_range = (y_range - 2 * edge_y_range )
506
- quant_val = list (range (1 ,len (original_unique_value [varname ])+ 1 ))
588
+ quant_val = list (range (1 , len (original_unique_value [varname ]) + 1 ))
507
589
min_val , max_val = same_scale_min , same_scale_max
508
590
value_interval = [temp_value_range * (x_val - min_val ) / (max_val - min_val ) for x_val in quant_val ]
509
591
uni_val_coordinates = self ._gen_coordinate (y_start , label_num , edge_y_range ,
510
- value_interval , y_range , val_type = "number" )
592
+ value_interval , y_range , val_type = "number" )
511
593
else :
512
594
value_interval = (y_range - 2 * edge_y_range ) / (label_num )
513
595
uni_val_coordinates = self ._gen_coordinate (y_start , label_num , edge_y_range ,
@@ -530,9 +612,6 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
530
612
else :
531
613
ax .text (x , y , val [1 ], ha = 'center' , va = 'center' )
532
614
coordinates_dict [val ] = (x , y )
533
-
534
-
535
615
return ax , coordinates_dict
536
616
537
617
538
-
0 commit comments