Skip to content

Commit 76d31e0

Browse files
authored
Merge pull request #8 from TianchengY/hi_box
add hi_box feature
2 parents edcabe0 + 94798a1 commit 76d31e0

File tree

2 files changed

+107
-9
lines changed

2 files changed

+107
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ satisfied respondents simply choose the highest value.
111111
| | `label` | `bool` | Whether or not to display labels between the plotting segments |
112112
| Highlighting | `hi_var` | `str` | Variable to be highlighted. Default is none. |
113113
| | `hi_value` | `List[str or int]` | List of values of `hi_var` to be highlighted. You can highlighted one or multiple values. |
114+
| | `hi_box` | `str` | Controls how highlighted values are displayed within category labels. Options are "vertical" for vertically stacked color segments or "horizontal" for horizontally split color segments. Default is "vertical".|
114115
| | `hi_missing` | `bool` | Whether or not missing values for `hi_var` should be highlighted. |
115116
| | `color` | `List[str]` | List of colors corresponding to the list of values to be highlighted. Each color can be specified as a plain color name (e.g., `"red"`, `"yellow"`) or in the format `"color=alpha"` (e.g., `"red=0.5"`) to control transparency/intensity, where `alpha` is a decimal between 0 and 1. The default highlight color list is `["red", "green", "yellow", "lightblue", "orange", "gray", "brown", "olive", "pink", "cyan", "magenta"]`. |
116117
| | `default_color` | `str` | Default color of plotting elements for boxes that are not highlighted. Default is "blue" |

hammock_plot/hammock_plot.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pandas as pd
22
import numpy as np
33
import matplotlib.pyplot as plt
4+
import matplotlib.colors as mcolors
45
from abc import ABC, abstractmethod
56
from typing import List, Dict
7+
import math
68
import warnings
79

810

@@ -52,7 +54,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
5254
xs, ys = [], []
5355
for l, r, w in zip(left_center_pts, right_center_pts, widths):
5456
x, y = np.zeros(4), np.zeros(4)
55-
alpha = np.arctan(abs(l[1] - r[1]) / abs(l[0] - r[0]))
57+
alpha = np.arctan(abs(l[1] - r[1]) / abs(l[0] - r[0])) if l[0] != r[0] else np.arctan(np.inf)
5658
vertical_w = w / np.cos(alpha)
5759

5860
x[0:2], x[2:4] = l[0], r[0]
@@ -73,7 +75,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
7375
xs, ys = [], []
7476
for l, r, w in zip(left_center_pts, right_center_pts, widths):
7577
x, y = np.zeros(4), np.zeros(4)
76-
alpha = np.arctan(abs(l[1] - r[1]) / abs(l[0] - r[0]))
78+
alpha = np.arctan(abs(l[1] - r[1]) / abs(l[0] - r[0])) if l[0] != r[0] else np.arctan(np.inf)
7779
vertical_w = w / np.cos(alpha)
7880

7981
ax, ay = l[0], l[1] + vertical_w / 2
@@ -91,7 +93,7 @@ def get_coordinates(self, left_center_pts, right_center_pts, widths):
9193
new_cy = cy + l[1] - ey
9294

9395
x[0], x[1] = new_ax, new_cx
94-
y[0], y[1] = (new_ay, new_cy) if l[1] < r[1] else (new_cy, new_ay)
96+
y[0], y[1] = (new_ay, new_cy) if l[1] <= r[1] else (new_cy, new_ay)
9597

9698
x[2], x[3] = x[0] + r[0] - l[0], x[1] + r[0] - l[0]
9799
y[2], y[3] = y[0] + r[1] - l[1], y[1] + r[1] - l[1]
@@ -129,6 +131,7 @@ def plot(self,
129131
# Highlighting
130132
hi_var: str = None,
131133
hi_value: List[str] = None,
134+
hi_box: str = "vertical",
132135
color: List[str] = None,
133136
default_color="lightskyblue",
134137
# Manipulating Spacing and Layout
@@ -147,7 +150,7 @@ def plot(self,
147150

148151
var_lst = var
149152
# parse the color input to enable intensity feature
150-
color_lst = color_lst = self._parse_colors(color)
153+
color_lst = self._parse_colors(color)
151154
self.data_df = self.data_df_origin.copy()
152155
for col in self.data_df:
153156
if self.data_df[col].dtype.name == "category":
@@ -337,6 +340,9 @@ def plot(self,
337340
label_rectangle_default_color = default_color
338341
label_rectangle_widths = []
339342
label_rectangle_total_obvs = {}
343+
344+
label_rectangle_vertical = True if hi_box == "vertical" else False
345+
340346
if label_rectangle:
341347
label_rectangle_painter = Rectangle()
342348
label_rectangle_left_center_pts, label_rectangle_right_center_pts = [],[]
@@ -384,6 +390,7 @@ def plot(self,
384390
label_rectangle_width_color_total = [0] * len(coordinates_dict)
385391
xs, ys = figure_type.get_coordinates(left_center_pts, right_center_pts, widths)
386392
if not space_univar:
393+
387394
for color in self.color_lst:
388395
widths_color, ratio_color_centers = [], []
389396
index = 0
@@ -409,22 +416,26 @@ def plot(self,
409416

410417
# always remember that color list was reversed, so the first color is the default color
411418
if label_rectangle:
419+
# if not label_rectangle_vertical:
412420
label_rectangle_total_obvs_color = label_rectangle_total_obvs.copy()
413421
for i,color in enumerate(reversed(self.color_lst)):
414422
the_hi_value = self.hi_value[i] if i != len(self.color_lst)-1 else None
415423
label_rectangle_widths_color, label_rectangle_ratio_color_centers = [], []
416424
idx=0
417425
for k,v in coordinates_dict.items():
418426
col_name = k[0].split(self.same_var_placeholder)[0]
427+
# case1: hi_value and hi_var
419428
if k[1] == the_hi_value and self.hi_var == col_name:
420429
label_rectangle_width_temp = label_rectangle_widths[idx]
421430
label_rectangle_total_obvs_color[k] = 0
431+
# case2: for other colors
422432
elif the_hi_value:
423433
num_obv = self.data_df.groupby([self.hi_var, col_name]).size().get((the_hi_value, k[1]), 0)
424434
label_rectangle_width_temp = bar * num_obv
425435
if self.min_bar_width and label_rectangle_width_temp <= self.min_bar_width and label_rectangle_width_temp != 0:
426436
label_rectangle_width_temp = self.min_bar_width
427437
label_rectangle_total_obvs_color[k] -= num_obv
438+
# case3: for the last default color
428439
else:
429440
label_rectangle_width_temp = bar * label_rectangle_total_obvs_color[k]
430441
if self.min_bar_width and label_rectangle_width_temp <= self.min_bar_width and label_rectangle_width_temp != 0:
@@ -434,11 +445,37 @@ def plot(self,
434445
label_rectangle_width_color_total[idx] += label_rectangle_width_temp
435446
idx+=1
436447

437-
xs, ys = label_rectangle_painter.get_coordinates(label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths)
438-
color_left_center_pts, color_right_center_pts = label_rectangle_painter.get_center_highlight(xs, ys,
439-
label_rectangle_ratio_color_centers)
440-
ax = label_rectangle_painter.plot(ax, color_left_center_pts, color_right_center_pts, label_rectangle_widths_color, color)
441-
# ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
448+
if not label_rectangle_vertical:
449+
xs, ys = label_rectangle_painter.get_coordinates(label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths)
450+
color_left_center_pts, color_right_center_pts = label_rectangle_painter.get_center_highlight(xs, ys,
451+
label_rectangle_ratio_color_centers)
452+
ax = label_rectangle_painter.plot(ax, color_left_center_pts, color_right_center_pts, label_rectangle_widths_color, color)
453+
# ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
454+
else:
455+
xs, ys = label_rectangle_painter.get_coordinates(label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths)
456+
457+
# switch the order of xs and ys temporarily for calculating the vertical coordinates
458+
vertical_xs = [np.array([x[1],x[3],x[0],x[2]]) for x in xs]
459+
vertical_ys = [np.array([y[1],y[3],y[0],y[2]]) for y in ys]
460+
461+
vertical_color_left_center_pts, vertical_color_right_center_pts = label_rectangle_painter.get_center_highlight(vertical_xs, vertical_ys,
462+
label_rectangle_ratio_color_centers)
463+
# switch back to normal coordinates
464+
vertical_label_rectangle_widths_color = [space*0.8*2*(w_color/w) for w_color,w in zip (label_rectangle_widths_color, label_rectangle_widths)]
465+
vertical_color_left_center_pts, vertical_color_right_center_pts, vertical_label_rectangle_widths_color = self._compute_left_right_centers(
466+
vertical_color_left_center_pts, vertical_color_right_center_pts, vertical_label_rectangle_widths_color
467+
)
468+
469+
470+
ax = label_rectangle_painter.plot(ax, vertical_color_left_center_pts, vertical_color_right_center_pts,vertical_label_rectangle_widths_color, color)
471+
472+
473+
# else:
474+
# vertical_label_rectangle_left_center_pts = [(pt2[0],pt1[0]) for pt1, pt2 in zip(label_rectangle_left_center_pts,label_rectangle_right_center_pts)]
475+
# vertical_label_rectangle_right_center_pts = [(pt1[1],pt2[0]) for pt1, pt2 in zip(label_rectangle_left_center_pts,label_rectangle_right_center_pts)]
476+
477+
# label_rectangle_total_obvs_color = label_rectangle_total_obvs.copy()
478+
# for i,color in enumerate(reversed(self.color_lst)):
442479

443480
if display_figure:
444481
ax.get_figure()
@@ -647,4 +684,64 @@ def _parse_colors(self, color_list):
647684

648685
return parsed_colors
649686

687+
def _compute_left_right_centers(self, midpoints_top, midpoints_bottom, lengths):
688+
"""
689+
Given the top and bottom midpoints of one or more aligned rectangles, along with their width,
690+
compute the midpoints and lengths of the left and right centers.
691+
692+
Parameters:
693+
--------
694+
midpoints_top : list of (float, float)
695+
List of top edge midpoints (Tx, Ty)
696+
midpoints_bottom : list of (float, float)
697+
List of bottom edge midpoints (Bx, By)
698+
length : float
699+
The width of the rectangles (assumed to be the same for all)
700+
701+
Returns:
702+
--------
703+
left_midpoints : list of (float, float)
704+
Midpoints of the left edges
705+
right_midpoints : list of (float, float)
706+
Midpoints of the right edges
707+
edge_length : float
708+
The height of the rectangles (same for both left and right edges)
709+
"""
710+
n = len(midpoints_top)
711+
712+
left_midpoints = []
713+
right_midpoints = []
714+
edge_lengths = []
715+
716+
for i in range(n):
717+
Tx, Ty = midpoints_top[i]
718+
Bx, By = midpoints_bottom[i]
719+
720+
Ty, By = (By, Ty) if Ty < By else (Ty, By)
721+
722+
length = lengths[i]
723+
724+
# Compute the vertical edge length
725+
edge_length = abs(By - Ty)
726+
727+
if length != 0:
728+
# Compute left and right edge midpoints
729+
Lmx = Tx - length / 2
730+
Lmy = Ty - edge_length / 2
731+
Rmx = Tx + length / 2
732+
Rmy = By + edge_length / 2
733+
else:
734+
Lmx = Tx - length / 2
735+
Lmy = Ty - edge_length / 2
736+
Rmx = Lmx
737+
Rmy = By + edge_length / 2
738+
edge_length = 0
739+
740+
left_midpoints.append((Lmx, Lmy))
741+
right_midpoints.append((Rmx, Rmy))
742+
edge_lengths.append(edge_length)
743+
744+
return left_midpoints, right_midpoints, edge_lengths
745+
746+
650747

0 commit comments

Comments
 (0)