@@ -21,7 +21,11 @@ def count_elements(formulas: list) -> pd.Series:
21
21
Returns:
22
22
pd.Series: Total number of appearances of each element in `formulas`.
23
23
"""
24
- srs = pd .Series (formulas ).apply (lambda x : pd .Series (Composition (x ).as_dict ())).sum ()
24
+ formula2dict = lambda str : pd .Series (
25
+ Composition (str ).fractional_composition .as_dict ()
26
+ )
27
+
28
+ srs = pd .Series (formulas ).apply (formula2dict ).sum ()
25
29
26
30
# ensure all elements are present in returned Series (with count zero if they
27
31
# weren't in formulas)
@@ -32,7 +36,10 @@ def count_elements(formulas: list) -> pd.Series:
32
36
33
37
34
38
def ptable_elemental_prevalence (
35
- formulas : List [str ] = None , elem_counts : pd .Series = None , log_scale : bool = False
39
+ formulas : List [str ] = None ,
40
+ elem_counts : pd .Series = None ,
41
+ log_scale : bool = False ,
42
+ cbar_title : str = None ,
36
43
) -> None :
37
44
"""Display the prevalence of each element in a materials dataset plotted as a
38
45
heatmap over the periodic table. `formulas` xor `elem_counts` must be passed.
@@ -54,34 +61,43 @@ def ptable_elemental_prevalence(
54
61
55
62
ptable = pd .read_csv (ROOT + "/data/periodic_table.csv" )
56
63
57
- n_row = ptable .row .max ()
58
- n_column = ptable .column .max ()
64
+ n_rows = ptable .row .max ()
65
+ n_columns = ptable .column .max ()
59
66
60
- plt .figure (figsize = (n_column , n_row ))
67
+ plt .figure (figsize = (n_columns , n_rows ))
61
68
62
69
rw = rh = 0.9 # rectangle width/height
63
- count_min = elem_counts .min ()
64
- count_max = elem_counts .max ()
70
+ min_count = elem_counts .min ()
71
+ max_count = elem_counts . replace ([ np . inf , - np . inf ], np . nan ). dropna () .max ()
65
72
66
73
norm = Normalize (
67
- vmin = 0 if log_scale else count_min ,
68
- vmax = np .log (count_max ) if log_scale else count_max ,
74
+ vmin = 0 if log_scale else min_count ,
75
+ vmax = np .log (max_count ) if log_scale else max_count ,
69
76
)
70
77
71
78
text_style = dict (
72
79
horizontalalignment = "center" ,
73
80
verticalalignment = "center" ,
74
81
fontsize = 20 ,
75
82
fontweight = "semibold" ,
76
- color = "black" ,
77
83
)
78
84
79
85
for symbol , row , column , _ in ptable .values :
80
- row = n_row - row
86
+ row = n_rows - row
81
87
count = elem_counts [symbol ]
88
+
82
89
if log_scale and count != 0 :
83
90
count = np .log (count )
84
- color = YlGn (norm (count )) if count != 0 else "silver"
91
+
92
+ # inf or NaN are expected when passing in elem_counts from ptable_elemental_ratio
93
+ if count == 0 : # not in formulas_a
94
+ color = "yellow"
95
+ elif count == np .inf :
96
+ color = "orange" # not in formulas_b
97
+ elif pd .isna (count ):
98
+ color = "gray" # not in either formulas_a nor formulas_b
99
+ else :
100
+ color = YlGn (norm (count )) if count != 0 else "silver"
85
101
86
102
if row < 3 :
87
103
row += 0.5
@@ -95,43 +111,68 @@ def ptable_elemental_prevalence(
95
111
x_offset = 3.5
96
112
y_offset = 7.8
97
113
length = 9
98
- for i in range (granularity ):
99
- value = int (round (( i ) * count_max / (granularity - 1 )))
114
+ for idx in range (granularity ):
115
+ value = int (round (idx * max_count / (granularity - 1 )))
100
116
if log_scale and value != 0 :
101
117
value = np .log (value )
102
118
color = YlGn (norm (value )) if value != 0 else "silver"
103
- x_loc = i / (granularity ) * length + x_offset
119
+ x_loc = idx / (granularity ) * length + x_offset
104
120
width = length / granularity
105
121
height = 0.35
106
122
rect = Rectangle (
107
123
(x_loc , y_offset ), width , height , edgecolor = "gray" , facecolor = color
108
124
)
109
125
110
- if i in [0 , 4 , 9 , 14 , 19 ]:
126
+ if idx in [0 , 4 , 9 , 14 , 19 ]:
111
127
text = f"{ value :g} "
112
128
if log_scale :
113
129
text = f"{ np .exp (value ):g} " .replace ("e+0" , "e" )
114
130
plt .text (x_loc + width / 2 , y_offset - 0.4 , text , ** text_style )
115
131
116
132
plt .gca ().add_patch (rect )
117
133
118
- plt .text (
119
- x_offset + length / 2 ,
120
- y_offset + 0.7 ,
121
- "log(Element Count)" if log_scale else "Element Count" ,
122
- horizontalalignment = "center" ,
123
- verticalalignment = "center" ,
124
- fontweight = "semibold" ,
125
- fontsize = 20 ,
126
- color = "k" ,
127
- )
134
+ if cbar_title is None :
135
+ cbar_title = "log(Element Count)" if log_scale else "Element Count"
136
+
137
+ plt .text (x_offset + length / 2 , y_offset + 0.7 , cbar_title , ** text_style )
128
138
129
- plt .ylim (- 0.15 , n_row + 0.1 )
130
- plt .xlim (0.85 , n_column + 1.1 )
139
+ plt .ylim (- 0.15 , n_rows + 0.1 )
140
+ plt .xlim (0.85 , n_columns + 1.1 )
131
141
132
142
plt .axis ("off" )
133
143
134
144
145
def ptable_elemental_ratio(
    formulas_a: List[str], formulas_b: List[str], log_scale: bool = False
) -> None:
    """Display the ratio of elemental prevalence between two sets of formulas
    plotted as a heatmap over the periodic table.

    Elements are counted separately in each list with `count_elements`, the
    counts of `formulas_a` are divided element-wise by those of `formulas_b`
    and the result is drawn by `ptable_elemental_prevalence`. Three legend
    lines are then added to the current figure naming the colors used for
    elements missing from the first list, the second list or both.

    Adapted from https://github.com/kaaiian/ML_figures.

    Args:
        formulas_a (list[str]): numerator compositional strings, e.g.
            ["Fe2O3", "Bi2Te3"]
        formulas_b (list[str]): denominator compositional strings, same format
            as `formulas_a`
        log_scale (bool, optional): Whether color map scale is log or linear.
    """
    # Element-wise quotient; numerator counts are computed first, then the
    # denominator counts.
    ratios = count_elements(formulas_a) / count_elements(formulas_b)

    if log_scale:
        title = "log(Element Ratio)"
    else:
        title = "Element Ratio"

    ptable_elemental_prevalence(
        elem_counts=ratios, log_scale=log_scale, cbar_title=title
    )

    # Legend for the special-case colors, drawn top to bottom.
    legend_style = {"fontsize": 14, "fontweight": "semibold"}
    legend_entries = (
        (2, "yellow: not in first list"),
        (1.5, "orange: not in second list"),
        (1, "gray: not in either"),
    )
    for y_pos, label in legend_entries:
        plt.text(0.2, y_pos, label, **legend_style)
174
+
175
+
135
176
def hist_elemental_prevalence (
136
177
formulas : list ,
137
178
log_scale : bool = False ,
0 commit comments