|
39 | 39 | from google.cloud.aiplatform.metadata import metadata
|
40 | 40 | from google.cloud.aiplatform.metadata import resource
|
41 | 41 | from google.cloud.aiplatform.metadata import utils as metadata_utils
|
| 42 | +from google.cloud.aiplatform.metadata.schema import utils as schema_utils |
| 43 | +from google.cloud.aiplatform.metadata.schema.google import ( |
| 44 | + artifact_schema as google_artifact_schema, |
| 45 | +) |
42 | 46 | from google.cloud.aiplatform.tensorboard import tensorboard_resource
|
43 | 47 | from google.cloud.aiplatform.utils import rest_utils
|
44 | 48 |
|
@@ -990,6 +994,108 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
|
990 | 994 | # TODO: query the latest metrics artifact resource before logging.
|
991 | 995 | self._metadata_node.update(metadata={constants._METRIC_KEY: metrics})
|
992 | 996 |
|
| 997 | + @_v1_not_supported |
| 998 | + def log_classification_metrics( |
| 999 | + self, |
| 1000 | + *, |
| 1001 | + labels: Optional[List[str]] = None, |
| 1002 | + matrix: Optional[List[List[int]]] = None, |
| 1003 | + fpr: Optional[List[float]] = None, |
| 1004 | + tpr: Optional[List[float]] = None, |
| 1005 | + threshold: Optional[List[float]] = None, |
| 1006 | + display_name: Optional[str] = None, |
| 1007 | + ): |
| 1008 | + """Create an artifact for classification metrics and log to ExperimentRun. Currently supports confusion matrix and ROC curve. |
| 1009 | +
|
| 1010 | + ``` |
| 1011 | + my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment') |
| 1012 | + my_run.log_classification_metrics( |
| 1013 | + display_name='my-classification-metrics', |
| 1014 | + labels=['cat', 'dog'], |
| 1015 | + matrix=[[9, 1], [1, 9]], |
| 1016 | + fpr=[0.1, 0.5, 0.9], |
| 1017 | + tpr=[0.1, 0.7, 0.9], |
| 1018 | + threshold=[0.9, 0.5, 0.1], |
| 1019 | + ) |
| 1020 | + ``` |
| 1021 | +
|
| 1022 | + Args: |
| 1023 | + labels (List[str]): |
| 1024 | + Optional. List of label names for the confusion matrix. Must be set if 'matrix' is set. |
| 1025 | + matrix (List[List[int]): |
| 1026 | + Optional. Values for the confusion matrix. Must be set if 'labels' is set. |
| 1027 | + fpr (List[float]): |
| 1028 | + Optional. List of false positive rates for the ROC curve. Must be set if 'tpr' or 'thresholds' is set. |
| 1029 | + tpr (List[float]): |
| 1030 | + Optional. List of true positive rates for the ROC curve. Must be set if 'fpr' or 'thresholds' is set. |
| 1031 | + threshold (List[float]): |
| 1032 | + Optional. List of thresholds for the ROC curve. Must be set if 'fpr' or 'tpr' is set. |
| 1033 | + display_name (str): |
| 1034 | + Optional. The user-defined name for the classification metric artifact. |
| 1035 | +
|
| 1036 | + Raises: |
| 1037 | + ValueError: if 'labels' and 'matrix' are not set together |
| 1038 | + or if 'labels' and 'matrix' are not in the same length |
| 1039 | + or if 'fpr' and 'tpr' and 'threshold' are not set together |
| 1040 | + or if 'fpr' and 'tpr' and 'threshold' are not in the same length |
| 1041 | + """ |
| 1042 | + if (labels or matrix) and not (labels and matrix): |
| 1043 | + raise ValueError("labels and matrix must be set together.") |
| 1044 | + |
| 1045 | + if (fpr or tpr or threshold) and not (fpr and tpr and threshold): |
| 1046 | + raise ValueError("fpr, tpr, and thresholds must be set together.") |
| 1047 | + |
| 1048 | + if labels and matrix: |
| 1049 | + if len(matrix) != len(labels): |
| 1050 | + raise ValueError( |
| 1051 | + "Length of labels and matrix must be the same. " |
| 1052 | + "Got lengths {} and {} respectively.".format( |
| 1053 | + len(labels), len(matrix) |
| 1054 | + ) |
| 1055 | + ) |
| 1056 | + annotation_specs = [ |
| 1057 | + schema_utils.AnnotationSpec(display_name=label) for label in labels |
| 1058 | + ] |
| 1059 | + confusion_matrix = schema_utils.ConfusionMatrix( |
| 1060 | + annotation_specs=annotation_specs, |
| 1061 | + matrix=matrix, |
| 1062 | + ) |
| 1063 | + |
| 1064 | + if fpr and tpr and threshold: |
| 1065 | + if ( |
| 1066 | + len(fpr) != len(tpr) |
| 1067 | + or len(fpr) != len(threshold) |
| 1068 | + or len(tpr) != len(threshold) |
| 1069 | + ): |
| 1070 | + raise ValueError( |
| 1071 | + "Length of fpr, tpr and threshold must be the same. " |
| 1072 | + "Got lengths {}, {} and {} respectively.".format( |
| 1073 | + len(fpr), len(tpr), len(threshold) |
| 1074 | + ) |
| 1075 | + ) |
| 1076 | + |
| 1077 | + confidence_metrics = [ |
| 1078 | + schema_utils.ConfidenceMetric( |
| 1079 | + confidence_threshold=confidence_threshold, |
| 1080 | + false_positive_rate=false_positive_rate, |
| 1081 | + recall=recall, |
| 1082 | + ) |
| 1083 | + for confidence_threshold, false_positive_rate, recall in zip( |
| 1084 | + threshold, fpr, tpr |
| 1085 | + ) |
| 1086 | + ] |
| 1087 | + |
| 1088 | + classification_metrics = google_artifact_schema.ClassificationMetrics( |
| 1089 | + display_name=display_name, |
| 1090 | + confusion_matrix=confusion_matrix, |
| 1091 | + confidence_metrics=confidence_metrics, |
| 1092 | + ) |
| 1093 | + |
| 1094 | + classfication_metrics = classification_metrics.create() |
| 1095 | + self._metadata_node.add_artifacts_and_executions( |
| 1096 | + artifact_resource_names=[classfication_metrics.resource_name] |
| 1097 | + ) |
| 1098 | + |
993 | 1099 | @_v1_not_supported
|
994 | 1100 | def get_time_series_data_frame(self) -> "pd.DataFrame": # noqa: F821
|
995 | 1101 | """Returns all time series in this Run as a DataFrame.
|
@@ -1149,6 +1255,65 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
|
1149 | 1255 | else:
|
1150 | 1256 | return self._metadata_node.metadata[constants._METRIC_KEY]
|
1151 | 1257 |
|
    @_v1_not_supported
    def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
        """Get all the classification metrics logged to this run.

        ```
        my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
        metric = my_run.get_classification_metrics()[0]
        print(metric)
        ## print result:
            {
            "id": "e6c893a4-222e-4c60-a028-6a3b95dfc109",
            "display_name": "my-classification-metrics",
            "labels": ["cat", "dog"],
            "matrix": [[9,1], [1,9]],
            "fpr": [0.1, 0.5, 0.9],
            "tpr": [0.1, 0.7, 0.9],
            "threshold": [0.9, 0.5, 0.1]
            }
        ```

        Returns:
            List of classification metrics logged to this experiment run.
        """

        # List every ClassificationMetrics artifact attached to this run's
        # metadata context, scoped to the run's project/location/credentials.
        artifact_list = artifact.Artifact.list(
            filter=metadata_utils._make_filter_string(
                in_context=[self.resource_name],
                schema_title=google_artifact_schema.ClassificationMetrics.schema_title,
            ),
            project=self.project,
            location=self.location,
            credentials=self.credentials,
        )

        # Flatten each artifact's metadata into a plain dict. Keys inside
        # `metric_artifact.metadata` are camelCase (proto JSON field names);
        # either metric group may be absent, so each is copied conditionally.
        metrics = []
        for metric_artifact in artifact_list:
            metric = {}
            metric["id"] = metric_artifact.name
            metric["display_name"] = metric_artifact.display_name
            metadata = metric_artifact.metadata
            if "confusionMatrix" in metadata:
                metric["labels"] = [
                    d["displayName"]
                    for d in metadata["confusionMatrix"]["annotationSpecs"]
                ]
                metric["matrix"] = metadata["confusionMatrix"]["rows"]

            if "confidenceMetrics" in metadata:
                metric["fpr"] = [
                    d["falsePositiveRate"] for d in metadata["confidenceMetrics"]
                ]
                metric["tpr"] = [d["recall"] for d in metadata["confidenceMetrics"]]
                metric["threshold"] = [
                    d["confidenceThreshold"] for d in metadata["confidenceMetrics"]
                ]
            metrics.append(metric)

        return metrics
| 1316 | + |
1152 | 1317 | @_v1_not_supported
|
1153 | 1318 | def associate_execution(self, execution: execution.Execution):
|
1154 | 1319 | """Associate an execution to this experiment run.
|
|
0 commit comments