
Commit ed0492e

jaycee-li and dandhlee authored
feat: Support complex metrics in Vertex Experiments (#1698)
* Experiments complex metrics (#8)

* feat: new class and API for metrics
* update system test
* update high level log method
* fix system test
* update example
* change from system schema to google schema
* fix: import error
* Update log_classification_metrics_sample.py
* Update samples/model-builder/experiment_tracking/log_classification_metrics_sample.py

Co-authored-by: Dan Lee <[email protected]>

* Update log_classification_metrics_sample_test.py
* Update samples/model-builder/conftest.py

Co-authored-by: Dan Lee <[email protected]>

* fix: unit test
* fix comments
* fix comments and update google.ClassificationMetrics
* fix comments and update ClassificationMetrics class
* fix: ClassificationMetrics doesn't catch params with value=0
* add sample for get_classification_metrics
* fix linting
* add todos

Co-authored-by: Dan Lee <[email protected]>
1 parent 5fe515c commit ed0492e

14 files changed: +810 -15 lines
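
At a glance, the commit wires one new logging call through three layers: a module-level alias in __init__.py, a tracker method in metadata.py, and the implementation plus a getter on ExperimentRun. A minimal usage sketch of the new surface, assuming an initialized SDK and that start_run is usable as a context manager as in the SDK's experiment-tracking samples (project, location, and experiment/run names below are illustrative, not from this diff):

```
from google.cloud import aiplatform

aiplatform.init(
    project="my-project",        # hypothetical project
    location="us-central1",      # hypothetical location
    experiment="my-experiment",  # hypothetical experiment
)

with aiplatform.start_run("my-run"):
    # New module-level alias added by this commit.
    aiplatform.log_classification_metrics(
        display_name="my-classification-metrics",
        labels=["cat", "dog"],
        matrix=[[9, 1], [1, 9]],
        fpr=[0.1, 0.5, 0.9],
        tpr=[0.1, 0.7, 0.9],
        threshold=[0.9, 0.5, 0.1],
    )
```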

google/cloud/aiplatform/__init__.py (+4)

@@ -86,6 +86,9 @@
 log_params = metadata.metadata._experiment_tracker.log_params
 log_metrics = metadata.metadata._experiment_tracker.log_metrics
+log_classification_metrics = (
+    metadata.metadata._experiment_tracker.log_classification_metrics
+)
 get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
 start_run = metadata.metadata._experiment_tracker.start_run
 start_execution = metadata.metadata._experiment_tracker.start_execution
@@ -110,6 +113,7 @@
     "log",
     "log_params",
     "log_metrics",
+    "log_classification_metrics",
     "log_time_series_metrics",
     "get_experiment_df",
     "get_pipeline_df",

google/cloud/aiplatform/metadata/experiment_run_resource.py (+165)

@@ -39,6 +39,10 @@
 from google.cloud.aiplatform.metadata import metadata
 from google.cloud.aiplatform.metadata import resource
 from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.metadata.schema import utils as schema_utils
+from google.cloud.aiplatform.metadata.schema.google import (
+    artifact_schema as google_artifact_schema,
+)
 from google.cloud.aiplatform.tensorboard import tensorboard_resource
 from google.cloud.aiplatform.utils import rest_utils

@@ -990,6 +994,108 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
         # TODO: query the latest metrics artifact resource before logging.
         self._metadata_node.update(metadata={constants._METRIC_KEY: metrics})

+    @_v1_not_supported
+    def log_classification_metrics(
+        self,
+        *,
+        labels: Optional[List[str]] = None,
+        matrix: Optional[List[List[int]]] = None,
+        fpr: Optional[List[float]] = None,
+        tpr: Optional[List[float]] = None,
+        threshold: Optional[List[float]] = None,
+        display_name: Optional[str] = None,
+    ):
+        """Create an artifact for classification metrics and log to ExperimentRun. Currently supports confusion matrix and ROC curve.
+
+        ```
+        my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+        my_run.log_classification_metrics(
+            display_name='my-classification-metrics',
+            labels=['cat', 'dog'],
+            matrix=[[9, 1], [1, 9]],
+            fpr=[0.1, 0.5, 0.9],
+            tpr=[0.1, 0.7, 0.9],
+            threshold=[0.9, 0.5, 0.1],
+        )
+        ```
+
+        Args:
+            labels (List[str]):
+                Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
+            matrix (List[List[int]]):
+                Optional. Values for the confusion matrix. Must be set if 'labels' is set.
+            fpr (List[float]):
+                Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
+            tpr (List[float]):
+                Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
+            threshold (List[float]):
+                Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
+            display_name (str):
+                Optional. The user-defined name for the classification metric artifact.
+
+        Raises:
+            ValueError: if 'labels' and 'matrix' are not set together,
+                if 'labels' and 'matrix' do not have the same length,
+                if 'fpr', 'tpr', and 'threshold' are not set together,
+                or if 'fpr', 'tpr', and 'threshold' do not have the same length.
+        """
+        if (labels or matrix) and not (labels and matrix):
+            raise ValueError("labels and matrix must be set together.")
+
+        if (fpr or tpr or threshold) and not (fpr and tpr and threshold):
+            raise ValueError("fpr, tpr, and threshold must be set together.")
+
+        # Default both to None so the artifact can still be created when only
+        # one metric family (confusion matrix or ROC curve) is provided.
+        confusion_matrix = None
+        confidence_metrics = None
+
+        if labels and matrix:
+            if len(matrix) != len(labels):
+                raise ValueError(
+                    "Length of labels and matrix must be the same. "
+                    "Got lengths {} and {} respectively.".format(
+                        len(labels), len(matrix)
+                    )
+                )
+            annotation_specs = [
+                schema_utils.AnnotationSpec(display_name=label) for label in labels
+            ]
+            confusion_matrix = schema_utils.ConfusionMatrix(
+                annotation_specs=annotation_specs,
+                matrix=matrix,
+            )
+
+        if fpr and tpr and threshold:
+            if (
+                len(fpr) != len(tpr)
+                or len(fpr) != len(threshold)
+                or len(tpr) != len(threshold)
+            ):
+                raise ValueError(
+                    "Length of fpr, tpr and threshold must be the same. "
+                    "Got lengths {}, {} and {} respectively.".format(
+                        len(fpr), len(tpr), len(threshold)
+                    )
+                )
+
+            confidence_metrics = [
+                schema_utils.ConfidenceMetric(
+                    confidence_threshold=confidence_threshold,
+                    false_positive_rate=false_positive_rate,
+                    recall=recall,
+                )
+                for confidence_threshold, false_positive_rate, recall in zip(
+                    threshold, fpr, tpr
+                )
+            ]
+
+        classification_metrics = google_artifact_schema.ClassificationMetrics(
+            display_name=display_name,
+            confusion_matrix=confusion_matrix,
+            confidence_metrics=confidence_metrics,
+        )
+
+        classification_metrics_artifact = classification_metrics.create()
+        self._metadata_node.add_artifacts_and_executions(
+            artifact_resource_names=[classification_metrics_artifact.resource_name]
+        )
+
     @_v1_not_supported
     def get_time_series_data_frame(self) -> "pd.DataFrame":  # noqa: F821
         """Returns all time series in this Run as a DataFrame.
@@ -1149,6 +1255,65 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
         else:
             return self._metadata_node.metadata[constants._METRIC_KEY]

+    @_v1_not_supported
+    def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
+        """Get all the classification metrics logged to this run.
+
+        ```
+        my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+        metric = my_run.get_classification_metrics()[0]
+        print(metric)
+        ## print result:
+        {
+            "id": "e6c893a4-222e-4c60-a028-6a3b95dfc109",
+            "display_name": "my-classification-metrics",
+            "labels": ["cat", "dog"],
+            "matrix": [[9, 1], [1, 9]],
+            "fpr": [0.1, 0.5, 0.9],
+            "tpr": [0.1, 0.7, 0.9],
+            "threshold": [0.9, 0.5, 0.1]
+        }
+        ```
+
+        Returns:
+            List of classification metrics logged to this experiment run.
+        """
+
+        artifact_list = artifact.Artifact.list(
+            filter=metadata_utils._make_filter_string(
+                in_context=[self.resource_name],
+                schema_title=google_artifact_schema.ClassificationMetrics.schema_title,
+            ),
+            project=self.project,
+            location=self.location,
+            credentials=self.credentials,
+        )
+
+        metrics = []
+        for metric_artifact in artifact_list:
+            metric = {}
+            metric["id"] = metric_artifact.name
+            metric["display_name"] = metric_artifact.display_name
+            metadata = metric_artifact.metadata
+            if "confusionMatrix" in metadata:
+                metric["labels"] = [
+                    d["displayName"]
+                    for d in metadata["confusionMatrix"]["annotationSpecs"]
+                ]
+                metric["matrix"] = metadata["confusionMatrix"]["rows"]
+
+            if "confidenceMetrics" in metadata:
+                metric["fpr"] = [
+                    d["falsePositiveRate"] for d in metadata["confidenceMetrics"]
+                ]
+                metric["tpr"] = [d["recall"] for d in metadata["confidenceMetrics"]]
+                metric["threshold"] = [
+                    d["confidenceThreshold"] for d in metadata["confidenceMetrics"]
+                ]
+            metrics.append(metric)
+
+        return metrics
+
     @_v1_not_supported
     def associate_execution(self, execution: execution.Execution):
         """Associate an execution to this experiment run.

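get_classification_metrics returns plain dicts, so results can be consumed without any SDK types. A short consumption sketch (pandas here is an assumption for display, not something this diff requires):

```
import pandas as pd

my_run = aiplatform.ExperimentRun("my-run", experiment="my-experiment")  # hypothetical names

for metric in my_run.get_classification_metrics():
    if "matrix" in metric:
        # Rebuild the confusion matrix with labeled rows and columns.
        print(pd.DataFrame(metric["matrix"], index=metric["labels"], columns=metric["labels"]))
    if "fpr" in metric:
        # ROC points come back as parallel lists, one entry per threshold.
        print(pd.DataFrame({"threshold": metric["threshold"], "fpr": metric["fpr"], "tpr": metric["tpr"]}))
```
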
google/cloud/aiplatform/metadata/metadata.py (+57, -2)

@@ -15,8 +15,7 @@
 # limitations under the License.
 #

-
-from typing import Dict, Union, Optional, Any
+from typing import Dict, Union, Optional, Any, List

 from google.api_core import exceptions
 from google.auth import credentials as auth_credentials
@@ -371,6 +370,62 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
         # query the latest metrics artifact resource before logging.
         self._experiment_run.log_metrics(metrics=metrics)

+    def log_classification_metrics(
+        self,
+        *,
+        labels: Optional[List[str]] = None,
+        matrix: Optional[List[List[int]]] = None,
+        fpr: Optional[List[float]] = None,
+        tpr: Optional[List[float]] = None,
+        threshold: Optional[List[float]] = None,
+        display_name: Optional[str] = None,
+    ):
+        """Create an artifact for classification metrics and log to ExperimentRun. Currently supports confusion matrix and ROC curve.
+
+        ```
+        my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+        my_run.log_classification_metrics(
+            display_name='my-classification-metrics',
+            labels=['cat', 'dog'],
+            matrix=[[9, 1], [1, 9]],
+            fpr=[0.1, 0.5, 0.9],
+            tpr=[0.1, 0.7, 0.9],
+            threshold=[0.9, 0.5, 0.1],
+        )
+        ```
+
+        Args:
+            labels (List[str]):
+                Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set.
+            matrix (List[List[int]]):
+                Optional. Values for the confusion matrix. Must be set if 'labels' is set.
+            fpr (List[float]):
+                Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'threshold' is set.
+            tpr (List[float]):
+                Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'threshold' is set.
+            threshold (List[float]):
+                Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set.
+            display_name (str):
+                Optional. The user-defined name for the classification metric artifact.
+
+        Raises:
+            ValueError: if 'labels' and 'matrix' are not set together,
+                if 'labels' and 'matrix' do not have the same length,
+                if 'fpr', 'tpr', and 'threshold' are not set together,
+                or if 'fpr', 'tpr', and 'threshold' do not have the same length.
+        """
+
+        self._validate_experiment_and_run(method_name="log_classification_metrics")
+        # query the latest metrics artifact resource before logging.
+        self._experiment_run.log_classification_metrics(
+            display_name=display_name,
+            labels=labels,
+            matrix=matrix,
+            fpr=fpr,
+            tpr=tpr,
+            threshold=threshold,
+        )
+
     def _validate_experiment_and_run(self, method_name: str):
         """Validates Experiment and Run are set and raises informative error message.

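The tracker-level method is a thin delegate: it validates that an experiment and run are active, then forwards every argument to ExperimentRun.log_classification_metrics. A sketch of the guard it adds (the exact error type and text come from _validate_experiment_and_run, which is not shown in this diff; ValueError is an assumption):

```
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # hypothetical; no experiment set

# Without an active experiment run, validation raises before any artifact work happens.
try:
    aiplatform.log_classification_metrics(labels=["cat", "dog"], matrix=[[9, 1], [1, 9]])
except ValueError as err:
    print(err)  # informative message from _validate_experiment_and_run (assumed ValueError)
```
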
google/cloud/aiplatform/metadata/schema/google/artifact_schema.py (+61, -4)

@@ -15,7 +15,7 @@
 # limitations under the License.

 import copy
-from typing import Optional, Dict
+from typing import Optional, Dict, List

 from google.cloud.aiplatform.compat.types import artifact as gca_artifact
 from google.cloud.aiplatform.metadata.schema import base_artifact
@@ -24,6 +24,12 @@
 # The artifact property key for the resource_name
 _ARTIFACT_PROPERTY_KEY_RESOURCE_NAME = "resourceName"

+_CLASSIFICATION_METRICS_AGGREGATION_TYPE = [
+    "AGGREGATION_TYPE_UNSPECIFIED",
+    "MACRO_AVERAGE",
+    "MICRO_AVERAGE",
+]
+

 class VertexDataset(base_artifact.BaseArtifactSchema):
     """An artifact representing a Vertex Dataset."""
@@ -278,9 +284,17 @@ class ClassificationMetrics(base_artifact.BaseArtifactSchema):
     def __init__(
         self,
         *,
+        aggregation_type: Optional[str] = None,
+        aggregation_threshold: Optional[float] = None,
+        recall: Optional[float] = None,
+        precision: Optional[float] = None,
+        f1_score: Optional[float] = None,
+        accuracy: Optional[float] = None,
         au_prc: Optional[float] = None,
         au_roc: Optional[float] = None,
         log_loss: Optional[float] = None,
+        confusion_matrix: Optional[utils.ConfusionMatrix] = None,
+        confidence_metrics: Optional[List[utils.ConfidenceMetric]] = None,
         artifact_id: Optional[str] = None,
         uri: Optional[str] = None,
         display_name: Optional[str] = None,
@@ -290,6 +304,22 @@ def __init__(
         state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
     ):
         """Args:
+            aggregation_type (str):
+                Optional. The way to generate the aggregated metrics. Choose from the following options:
+                "AGGREGATION_TYPE_UNSPECIFIED": Indicates unset, used for per-class sliced metrics
+                "MACRO_AVERAGE": The unweighted average, default behavior
+                "MICRO_AVERAGE": The weighted average
+            aggregation_threshold (float):
+                Optional. The threshold used to generate aggregated metrics, default 0 for multi-class classification, 0.5 for binary classification.
+            recall (float):
+                Optional. Recall (True Positive Rate) for the given confidence threshold.
+            precision (float):
+                Optional. Precision for the given confidence threshold.
+            f1_score (float):
+                Optional. The harmonic mean of recall and precision.
+            accuracy (float):
+                Optional. Accuracy is the fraction of predictions given the correct label.
+                For multiclass this is a micro-average metric.
             au_prc (float):
                 Optional. The Area Under Precision-Recall Curve metric.
                 Micro-averaged for the overall evaluation.
@@ -298,6 +328,10 @@
                 Micro-averaged for the overall evaluation.
             log_loss (float):
                 Optional. The Log Loss metric.
+            confusion_matrix (utils.ConfusionMatrix):
+                Optional. Aggregated confusion matrix.
+            confidence_metrics (List[utils.ConfidenceMetric]):
+                Optional. List of metrics for different confidence thresholds.
             artifact_id (str):
                 Optional. The <resource_id> portion of the Artifact name with
                 the format. This is globally unique in a metadataStore:
@@ -323,12 +357,35 @@
             check the validity of state transitions.
         """
         extended_metadata = copy.deepcopy(metadata) if metadata else {}
-        if au_prc:
+        if aggregation_type:
+            if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
+                # TODO: add negative test case for this
+                raise ValueError(
+                    "aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
+                )
+            extended_metadata["aggregationType"] = aggregation_type
+        if aggregation_threshold is not None:
+            extended_metadata["aggregationThreshold"] = aggregation_threshold
+        if recall is not None:
+            extended_metadata["recall"] = recall
+        if precision is not None:
+            extended_metadata["precision"] = precision
+        if f1_score is not None:
+            extended_metadata["f1Score"] = f1_score
+        if accuracy is not None:
+            extended_metadata["accuracy"] = accuracy
+        if au_prc is not None:
             extended_metadata["auPrc"] = au_prc
-        if au_roc:
+        if au_roc is not None:
             extended_metadata["auRoc"] = au_roc
-        if log_loss:
+        if log_loss is not None:
             extended_metadata["logLoss"] = log_loss
+        if confusion_matrix:
+            extended_metadata["confusionMatrix"] = confusion_matrix.to_dict()
+        if confidence_metrics:
+            extended_metadata["confidenceMetrics"] = [
+                confidence_metric.to_dict() for confidence_metric in confidence_metrics
+            ]

         super(ClassificationMetrics, self).__init__(
             uri=uri,