Skip to content

Commit bb54dfa

Browse files
authored
Merge pull request #1 from TianchengY/feature_label_rectangle
merge feature label rectangle
2 parents 95dc760 + 4b885de commit bb54dfa

File tree

1 file changed

+118
-39
lines changed

1 file changed

+118
-39
lines changed

hammock/hammock.py

Lines changed: 118 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import matplotlib.pyplot as plt
44
from abc import ABC, abstractmethod
5-
from typing import List,Dict
5+
from typing import List, Dict
66
import warnings
77

88

@@ -121,7 +121,7 @@ def __init__(self,
121121

122122
def plot(self,
123123
var: List[str] = None,
124-
value_order: Dict[str, Dict[int,str]] = None,
124+
value_order: Dict[str, Dict[int, str]] = None,
125125
missing: bool = False,
126126
hi_missing: bool = False,
127127
missing_label_space: float = 1.,
@@ -130,7 +130,7 @@ def plot(self,
130130
hi_var: str = None,
131131
hi_value: List[str] = None,
132132
color: List[str] = None,
133-
default_color="blue",
133+
default_color="lightskyblue",
134134
# Manipulating Spacing and Layout
135135
bar_width: float = 1.,
136136
min_bar_width: float = .05,
@@ -152,7 +152,7 @@ def plot(self,
152152
if self.data_df[col].dtype.name == "category":
153153
self.data_df[col] = self.data_df[col].cat.add_categories(self.missing_data_placeholder)
154154
elif "float" in self.data_df[col].dtype.name:
155-
self.data_df[col] = self.data_df[col].apply(lambda x: np.round(x,2))
155+
self.data_df[col] = self.data_df[col].apply(lambda x: np.round(x, 2))
156156
self.data_df_columns = self.data_df.columns.tolist()
157157

158158
if not var_lst:
@@ -161,7 +161,6 @@ def plot(self,
161161
)
162162

163163
if color and type(color) != type([]):
164-
165164
raise ValueError(
166165
f'Argument "color" must be a list os str.'
167166
)
@@ -179,9 +178,9 @@ def plot(self,
179178
)
180179

181180
if value_order:
182-
for k,v_ori in value_order.items():
181+
for k, v_ori in value_order.items():
183182
uni_val_set = set(self.data_df[k].dropna().unique())
184-
v = [value_name for order,value_name in v_ori.items()]
183+
v = [value_name for order, value_name in v_ori.items()]
185184
if not set(v) >= uni_val_set:
186185
error_values = (set(v) ^ uni_val_set) & set(v)
187186
raise ValueError(
@@ -219,17 +218,18 @@ def plot(self,
219218
self.hi_value.append(self.missing_data_placeholder)
220219
else:
221220
self.hi_value = [self.missing_data_placeholder]
222-
colors = ["red", "green", "yellow", "lightblue","orange", "gray", "brown", "olive", "pink", "cyan", "magenta"]
221+
colors = ["red", "green", "yellow", "purple", "orange", "gray", "brown", "olive", "pink", "cyan", "magenta"]
223222
self.color_lst = [color for color in color_lst] if color_lst else (
224223
colors[:len(self.hi_value)] if hi_var else None)
225224
if hi_var:
226225
if hi_value and len(self.color_lst) < len(hi_value):
227-
for i in range(len(hi_value)-len(self.color_lst)):
226+
for i in range(len(hi_value) - len(self.color_lst)):
228227
for c in colors:
229228
if c not in self.color_lst:
230229
self.color_lst.append(c)
231230
break
232-
warnings.warn(f"Warning: The length of color is less than the total number of (high values and missing), color was automatically extended to {self.color_lst}")
231+
warnings.warn(
232+
f"Warning: The length of color is less than the total number of (high values and missing), color was automatically extended to {self.color_lst}")
233233
if hi_var and default_color in self.color_lst:
234234
raise ValueError(
235235
f'The current highlight colors {self.color_lst} conflict with the default color {default_color}. Please choose another default color or other highlight colors'
@@ -259,7 +259,7 @@ def plot(self,
259259
raise ValueError(
260260
f'the values: {error_values} in highlight value is not in data.'
261261
)
262-
262+
263263
value_color_dict = dict(zip(self.hi_value, self.color_lst))
264264

265265
self.data_df[self.color_coloumn_placeholder] = self.data_df[hi_var].apply(
@@ -303,7 +303,7 @@ def plot(self,
303303
ax, coordinates_dict = self._list_labels(ax, self.height, self.width, self.label)
304304

305305
space = self.space * 10 if label else 0
306-
bar = self.bar_width*3.5/max(data_point_numbers)
306+
bar = self.bar_width * 3.5 / max(data_point_numbers)
307307

308308
if self.shape == "parallelogram":
309309
figure_type = Parallelogram()
@@ -325,10 +325,54 @@ def plot(self,
325325
left_center_pts.append(left_coordinate)
326326
right_center_pts.append(right_coordinate)
327327

328+
label_rectangle = True if self.label else False
329+
label_rectangle_default_color = default_color
330+
label_rectangle_widths = []
331+
label_rectangle_total_obvs = {}
332+
if label_rectangle:
333+
label_rectangle_painter = Rectangle()
334+
label_rectangle_left_center_pts, label_rectangle_right_center_pts = [],[]
335+
for k,v in coordinates_dict.items():
336+
337+
# get width for label rectangles by counting the number of observations for each value
338+
339+
col_name = k[0].split(self.same_var_placeholder)[0]
340+
num_obv = self.data_df[col_name].value_counts().get(k[1], 0)
341+
label_rectangle_total_obvs[k] = num_obv
342+
label_rectangle_width = bar * num_obv
343+
if self.min_bar_width and label_rectangle_width <= self.min_bar_width:
344+
label_rectangle_width = self.min_bar_width
345+
346+
# get left and right coordinates for label rectangles
347+
# add space for very thick label rectangles
348+
half_label_rectangle_width = label_rectangle_width/2
349+
edge_adjust = self.max_y_range * 0.01
350+
if v[1] - half_label_rectangle_width < 0:
351+
adjust_value = half_label_rectangle_width - v[1] + edge_adjust
352+
label_rectangle_left_coordinate= (v[0]-space*0.8, v[1]+adjust_value)
353+
label_rectangle_right_coordinate = (v[0] + space * 0.8, v[1]+adjust_value)
354+
elif v[1] + half_label_rectangle_width > self.max_y_range:
355+
adjust_value = half_label_rectangle_width + v[1] - self.max_y_range + edge_adjust
356+
label_rectangle_left_coordinate= (v[0]-space*0.8, v[1]-adjust_value)
357+
label_rectangle_right_coordinate = (v[0] + space * 0.8, v[1]-adjust_value)
358+
else:
359+
label_rectangle_left_coordinate = (v[0]-space*0.8, v[1])
360+
label_rectangle_right_coordinate = (v[0] + space * 0.8, v[1])
361+
362+
label_rectangle_left_center_pts.append(label_rectangle_left_coordinate)
363+
label_rectangle_right_center_pts.append(label_rectangle_right_coordinate)
364+
365+
366+
367+
label_rectangle_widths.append(label_rectangle_width)
368+
328369
if not hi_var:
329370
ax = figure_type.plot(ax, left_center_pts, right_center_pts, widths, default_color)
371+
if label_rectangle:
372+
ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,label_rectangle_default_color)
330373
else:
331374
width_color_total = [0] * len(widths)
375+
label_rectangle_width_color_total = [0] * len(coordinates_dict)
332376
xs, ys = figure_type.get_coordinates(left_center_pts, right_center_pts, widths)
333377
for color in self.color_lst:
334378
widths_color, ratio_color_centers = [], []
@@ -353,6 +397,39 @@ def plot(self,
353397
ratio_color_centers)
354398
ax = figure_type.plot(ax, color_left_center_pts, color_right_center_pts, widths_color, color)
355399

400+
# always remember that color list was reversed, so the first color is the default color
401+
if label_rectangle:
402+
label_rectangle_total_obvs_color = label_rectangle_total_obvs.copy()
403+
for i,color in enumerate(reversed(self.color_lst)):
404+
the_hi_value = self.hi_value[i] if i != len(self.color_lst)-1 else None
405+
label_rectangle_widths_color, label_rectangle_ratio_color_centers = [], []
406+
idx=0
407+
for k,v in coordinates_dict.items():
408+
col_name = k[0].split(self.same_var_placeholder)[0]
409+
if k[1] == the_hi_value and self.hi_var == col_name:
410+
label_rectangle_width_temp = label_rectangle_widths[idx]
411+
label_rectangle_total_obvs_color[k] = 0
412+
elif the_hi_value:
413+
num_obv = self.data_df.groupby([self.hi_var, col_name]).size().get((the_hi_value, k[1]), 0)
414+
label_rectangle_width_temp = bar * num_obv
415+
if self.min_bar_width and label_rectangle_width_temp <= self.min_bar_width and label_rectangle_width_temp != 0:
416+
label_rectangle_width_temp = self.min_bar_width
417+
label_rectangle_total_obvs_color[k] -= num_obv
418+
else:
419+
label_rectangle_width_temp = bar * label_rectangle_total_obvs_color[k]
420+
if self.min_bar_width and label_rectangle_width_temp <= self.min_bar_width and label_rectangle_width_temp != 0:
421+
label_rectangle_width_temp = self.min_bar_width
422+
label_rectangle_widths_color.append(label_rectangle_width_temp)
423+
label_rectangle_ratio_color_centers.append((label_rectangle_width_color_total[idx] + label_rectangle_width_temp / 2) / label_rectangle_widths[idx])
424+
label_rectangle_width_color_total[idx] += label_rectangle_width_temp
425+
idx+=1
426+
427+
xs, ys = label_rectangle_painter.get_coordinates(label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths)
428+
color_left_center_pts, color_right_center_pts = label_rectangle_painter.get_center_highlight(xs, ys,
429+
label_rectangle_ratio_color_centers)
430+
ax = label_rectangle_painter.plot(ax, color_left_center_pts, color_right_center_pts, label_rectangle_widths_color, color)
431+
# ax = label_rectangle_painter.plot(ax, label_rectangle_left_center_pts, label_rectangle_right_center_pts, label_rectangle_widths,'green')
432+
356433
if display_figure:
357434
ax.get_figure()
358435
else:
@@ -365,7 +442,7 @@ def plot(self,
365442

366443
def _get_varname(self, x):
367444
return x.split(self.same_var_placeholder)[:-1][0]
368-
445+
369446
def is_float(self, element: any) -> bool:
370447
if element is None:
371448
return False
@@ -396,10 +473,10 @@ def _get_two_var(self, var_lst: List[str]):
396473

397474
return var_pair_lst
398475

399-
def _gen_coordinate(self, start, n, edge, spacing, total_range,val_type="str"):
476+
def _gen_coordinate(self, start, n, edge, spacing, total_range, val_type="str"):
400477
coor_lst = []
401-
402-
if val_type=="str":
478+
479+
if val_type == "str":
403480
for i in range(n):
404481
coor_lst.append(start + i * spacing)
405482

@@ -411,17 +488,17 @@ def _gen_coordinate(self, start, n, edge, spacing, total_range,val_type="str"):
411488
coor_lst.append(total_range + (start - edge) - edge)
412489
return coor_lst
413490

414-
def _get_same_scale_minmax(self,original_unique_value):
415-
min,max = 0,0
416-
for i,varname in enumerate(self.same_scale):
491+
def _get_same_scale_minmax(self, original_unique_value):
492+
min, max = 0, 0
493+
for i, varname in enumerate(self.same_scale):
417494
var_type = str(self.data_df_origin[varname].dtype.name)
418495
if "int" in var_type or "float" in var_type:
419496
min_val, max_val = original_unique_value[varname][0], original_unique_value[varname][-1]
420497
if i == 0:
421-
min,max = min_val, max_val
498+
min, max = min_val, max_val
422499
else:
423-
min = min_val if min_val<min else min
424-
max = max_val if max_val>max else max
500+
min = min_val if min_val < min else min
501+
max = max_val if max_val > max else max
425502

426503
else:
427504
min_val, max_val = 1, len(original_unique_value[varname])
@@ -430,40 +507,44 @@ def _get_same_scale_minmax(self,original_unique_value):
430507
else:
431508
min = min_val if min_val < min else min
432509
max = max_val if max_val > max else max
433-
return (min,max)
510+
return (min, max)
434511

435512
def _list_labels(self, ax, figsize_y, figsize_x, label):
436513

437514
scale = 10
438515
edge_scale = 10
439516
y_range = scale * figsize_y - self.missing_label_space * scale if self.missing else scale * figsize_y
440517
x_range = scale * figsize_x
518+
self.max_y_range, self.max_x_range = scale * figsize_y, scale * figsize_x
441519
edge_x_range = x_range / edge_scale
442520
edge_y_range = y_range / edge_scale
521+
# self.edge_y_range, self.edge_x_range = edge_y_range, edge_x_range
443522
y_start = edge_y_range + self.missing_label_space * scale if self.missing else edge_y_range
444523
coordinates_dict = {}
445524

446525
unique_value = []
447526
original_unique_value = {}
448527
varname_lst = [self._get_varname(var) for var in self.var_lst]
449-
528+
450529
for var, varname in zip(self.var_lst, varname_lst):
451530
unique_valnames = self.data_df[varname].dropna().unique().tolist()
452531
sorted_unique_valnames = []
453532
if self.value_order and varname in self.value_order:
454533
varname_value_order_dict = self.value_order[varname]
455534
sorted_unique_valnames_temp = [v for k, v in
456-
sorted(varname_value_order_dict.items(), key=lambda item: item[0])]
535+
sorted(varname_value_order_dict.items(), key=lambda item: item[0])]
457536
for v in sorted_unique_valnames_temp:
458537
if v in unique_valnames:
459538
sorted_unique_valnames.append(v)
460539
if self.missing_data_placeholder in unique_valnames:
461540
unique_valnames.remove(self.missing_data_placeholder)
462-
sorted_unique_valnames = sorted(unique_valnames) if not sorted_unique_valnames else sorted_unique_valnames
541+
sorted_unique_valnames = sorted(
542+
unique_valnames) if not sorted_unique_valnames else sorted_unique_valnames
463543
original_unique_value[varname] = sorted_unique_valnames.copy()
464544
sorted_unique_valnames.append(self.missing_data_placeholder)
465545
else:
466-
sorted_unique_valnames = sorted(unique_valnames) if not sorted_unique_valnames else sorted_unique_valnames
546+
sorted_unique_valnames = sorted(
547+
unique_valnames) if not sorted_unique_valnames else sorted_unique_valnames
467548
original_unique_value[varname] = sorted_unique_valnames.copy()
468549
unique_value.append([(var, x) for x in sorted_unique_valnames])
469550

@@ -478,11 +559,11 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
478559

479560
# prepare for same_scale variables
480561
if self.same_scale:
481-
same_scale_min,same_scale_max = self._get_same_scale_minmax(original_unique_value)
482-
same_scale_range = same_scale_max-same_scale_min
562+
same_scale_min, same_scale_max = self._get_same_scale_minmax(original_unique_value)
563+
same_scale_range = same_scale_max - same_scale_min
483564

484565
# plot labels for each variable
485-
for var_i,(x, uni_val) in enumerate(zip(label_coordinates, unique_value)):
566+
for var_i, (x, uni_val) in enumerate(zip(label_coordinates, unique_value)):
486567
label_num = len(uni_val) - 2 if (uni_val[0][0], self.missing_data_placeholder) in uni_val else len(
487568
uni_val) - 1
488569
varname = varname_lst[var_i]
@@ -493,21 +574,22 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
493574
temp_value_range = (y_range - 2 * edge_y_range)
494575
# handle the variables in same_scale
495576
if self.same_scale and varname in self.same_scale:
496-
min_val, max_val = same_scale_min,same_scale_max
577+
min_val, max_val = same_scale_min, same_scale_max
497578
else:
498-
min_val,max_val = original_unique_value[varname][0],original_unique_value[varname][-1]
499-
value_interval = [temp_value_range*(x_val-min_val)/(max_val-min_val) for x_val in original_unique_value[varname]]
579+
min_val, max_val = original_unique_value[varname][0], original_unique_value[varname][-1]
580+
value_interval = [temp_value_range * (x_val - min_val) / (max_val - min_val) for x_val in
581+
original_unique_value[varname]]
500582
uni_val_coordinates = self._gen_coordinate(y_start, label_num, edge_y_range,
501-
value_interval, y_range,val_type = "number")
583+
value_interval, y_range, val_type="number")
502584
else:
503585
# handle the variables in same_scale
504586
if self.same_scale and varname in self.same_scale:
505587
temp_value_range = (y_range - 2 * edge_y_range)
506-
quant_val = list(range(1,len(original_unique_value[varname])+1))
588+
quant_val = list(range(1, len(original_unique_value[varname]) + 1))
507589
min_val, max_val = same_scale_min, same_scale_max
508590
value_interval = [temp_value_range * (x_val - min_val) / (max_val - min_val) for x_val in quant_val]
509591
uni_val_coordinates = self._gen_coordinate(y_start, label_num, edge_y_range,
510-
value_interval, y_range, val_type = "number")
592+
value_interval, y_range, val_type="number")
511593
else:
512594
value_interval = (y_range - 2 * edge_y_range) / (label_num)
513595
uni_val_coordinates = self._gen_coordinate(y_start, label_num, edge_y_range,
@@ -530,9 +612,6 @@ def _list_labels(self, ax, figsize_y, figsize_x, label):
530612
else:
531613
ax.text(x, y, val[1], ha='center', va='center')
532614
coordinates_dict[val] = (x, y)
533-
534-
535615
return ax, coordinates_dict
536616

537617

538-

0 commit comments

Comments
 (0)