@@ -62,7 +62,7 @@ def process_dataset(
62
62
target_col : str ,
63
63
target_label : str ,
64
64
embed_method : EmbeddingMethod ,
65
- projection_method : ProjectionMethod ,
65
+ projection : ProjectionMethod ,
66
66
n_components : int ,
67
67
) -> go .Figure :
68
68
"""Process a single dataset and create clustering visualizations.
@@ -72,7 +72,7 @@ def process_dataset(
72
72
target_col (str): Name of the target property column
73
73
target_label (str): Display label for the property
74
74
embed_method (EmbeddingMethod): Method to convert compositions to vectors
75
- projection_method (ProjectionMethod): Method to reduce dimensionality
75
+ projection (ProjectionMethod): Method to reduce dimensionality
76
76
n_components (int): Number of dimensions for projection (2 or 3)
77
77
78
78
Returns:
@@ -121,12 +121,18 @@ def process_dataset(
121
121
default_handler = lambda x : x .tolist () if hasattr (x , "tolist" ) else x
122
122
json .dump (embeddings_dict , file , default = default_handler )
123
123
124
- # Create plot with pre-computed embeddings
124
+ df_plot = pd .DataFrame ({"composition" : compositions })
125
+ df_plot [target_label ] = properties
126
+
127
+ if "embeddings" not in df_plot :
128
+ df_plot ["embeddings" ] = [embeddings_dict .get (comp ) for comp in compositions ]
129
+
125
130
fig = pmv .cluster_compositions (
126
- compositions = embeddings_dict ,
127
- properties = dict ( zip ( compositions , properties , strict = True )) ,
131
+ df = df_plot ,
132
+ composition_col = "composition" ,
128
133
prop_name = target_label ,
129
- projection_method = projection_method ,
134
+ embedding_method = "embeddings" ,
135
+ projection = projection ,
130
136
n_components = n_components ,
131
137
marker_size = 8 ,
132
138
opacity = 0.8 ,
@@ -136,7 +142,7 @@ def process_dataset(
136
142
)
137
143
138
144
# Update title and margins
139
- title = f"{ dataset_name } - { embed_method } + { projection_method } ({ n_components } D)"
145
+ title = f"{ dataset_name } - { embed_method } + { projection } ({ n_components } D)"
140
146
fig .layout .update (title = dict (text = title , x = 0.5 ), margin_t = 50 )
141
147
# format compositions and coordinates in hover tooltip
142
148
custom_data = [
@@ -146,9 +152,9 @@ def process_dataset(
146
152
fig .update_traces (
147
153
hovertemplate = (
148
154
"%{customdata[0]}<br>" # Formatted composition
149
- f"{ projection_method } 1: %{{x:.2f}}<br>" # First projection coordinate
150
- f"{ projection_method } 2: %{{y:.2f}}<br>" # Second projection coordinate
151
- + (f"{ projection_method } 3: %{{z:.2f}}<br>" if n_components == 3 else "" )
155
+ f"{ projection } 1: %{{x:.2f}}<br>" # First projection coordinate
156
+ f"{ projection } 2: %{{y:.2f}}<br>" # Second projection coordinate
157
+ + (f"{ projection } 3: %{{z:.2f}}<br>" if n_components == 3 else "" )
152
158
+ f"{ target_label } : %{{marker.color:.2f}}" # Property value
153
159
),
154
160
customdata = custom_data ,
@@ -203,7 +209,7 @@ def process_dataset(
203
209
target_col = target_col ,
204
210
target_label = target_label ,
205
211
embed_method = embed_method ,
206
- projection_method = proj_method ,
212
+ projection = proj_method ,
207
213
n_components = n_components ,
208
214
)
209
215
fig .update_layout (coloraxis_colorbar = cbar_args )
0 commit comments