Skip to content

Commit 51df86e

Browse files
sararob authored and copybara-github committed
feat: add model.evaluate() method to Model class
PiperOrigin-RevId: 553544432
1 parent 77ed9ef commit 51df86e

File tree

9 files changed

+2587
-15
lines changed

9 files changed

+2587
-15
lines changed

google/cloud/aiplatform/model_evaluation/__init__.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -15,6 +15,11 @@
1515
# limitations under the License.
1616
#
1717

18-
from google.cloud.aiplatform.model_evaluation.model_evaluation import ModelEvaluation
18+
from google.cloud.aiplatform.model_evaluation.model_evaluation import (
19+
ModelEvaluation,
20+
)
21+
from google.cloud.aiplatform.model_evaluation.model_evaluation_job import (
22+
_ModelEvaluationJob,
23+
)
1924

20-
__all__ = ("ModelEvaluation",)
25+
__all__ = ("ModelEvaluation", "_ModelEvaluationJob")

google/cloud/aiplatform/model_evaluation/model_evaluation.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515
# limitations under the License.
1616
#
1717

18+
from typing import List, Optional
19+
20+
from google.protobuf import struct_pb2
21+
1822
from google.auth import credentials as auth_credentials
1923

24+
from google.cloud import aiplatform
2025
from google.cloud.aiplatform import base
21-
from google.cloud.aiplatform import utils
2226
from google.cloud.aiplatform import models
23-
from google.protobuf import struct_pb2
24-
25-
from typing import List, Optional
27+
from google.cloud.aiplatform import pipeline_jobs
28+
from google.cloud.aiplatform import utils
2629

2730

2831
class ModelEvaluation(base.VertexAiResourceNounWithFutureManager):
@@ -36,13 +39,35 @@ class ModelEvaluation(base.VertexAiResourceNounWithFutureManager):
3639
_format_resource_name_method = "model_evaluation_path"
3740

3841
@property
39-
def metrics(self) -> Optional[struct_pb2.Value]:
42+
def metrics(self) -> struct_pb2.Value:
4043
"""Gets the evaluation metrics from the Model Evaluation.
44+
45+
Returns:
46+
A struct_pb2.Value with model metrics created from the Model Evaluation
47+
Raises:
48+
ValueError: If the Model Evaluation doesn't have metrics.
49+
"""
50+
if self._gca_resource.metrics:
51+
return self._gca_resource.metrics
52+
53+
raise ValueError(
54+
"This ModelEvaluation does not have any metrics, this could be because the Evaluation job failed. Check the logs for details."
55+
)
56+
57+
@property
58+
def _backing_pipeline_job(self) -> Optional["pipeline_jobs.PipelineJob"]:
59+
"""The managed pipeline for this model evaluation job.
4160
Returns:
42-
A dict with model metrics created from the Model Evaluation or
43-
None if the metrics for this evaluation are empty.
61+
The PipelineJob resource if this evaluation ran from a managed pipeline or None.
4462
"""
45-
return self._gca_resource.metrics
63+
if (
64+
"metadata" in self._gca_resource
65+
and "pipeline_job_resource_name" in self._gca_resource.metadata
66+
):
67+
return aiplatform.PipelineJob.get(
68+
resource_name=self._gca_resource.metadata["pipeline_job_resource_name"],
69+
credentials=self.credentials,
70+
)
4671

4772
def __init__(
4873
self,

google/cloud/aiplatform/model_evaluation/model_evaluation_job.py

+410
Large diffs are not rendered by default.

google/cloud/aiplatform/models.py

+231
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@
9393
"saved_model.pbtxt",
9494
]
9595

96+
_SUPPORTED_EVAL_PREDICTION_TYPES = [
97+
"classification",
98+
"regression",
99+
]
100+
96101

97102
class VersionInfo(NamedTuple):
98103
"""VersionInfo class envelopes returned Model version information.
@@ -4895,6 +4900,232 @@ def get_model_evaluation(
48954900
credentials=self.credentials,
48964901
)
48974902

4903+
def evaluate(
4904+
self,
4905+
prediction_type: str,
4906+
target_field_name: str,
4907+
gcs_source_uris: Optional[List[str]] = None,
4908+
bigquery_source_uri: Optional[str] = None,
4909+
bigquery_destination_output_uri: Optional[str] = None,
4910+
class_labels: Optional[List[str]] = None,
4911+
prediction_label_column: Optional[str] = None,
4912+
prediction_score_column: Optional[str] = None,
4913+
staging_bucket: Optional[str] = None,
4914+
service_account: Optional[str] = None,
4915+
generate_feature_attributions: bool = False,
4916+
evaluation_pipeline_display_name: Optional[str] = None,
4917+
evaluation_metrics_display_name: Optional[str] = None,
4918+
network: Optional[str] = None,
4919+
encryption_spec_key_name: Optional[str] = None,
4920+
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
4921+
) -> "model_evaluation._ModelEvaluationJob":
4922+
"""Creates a model evaluation job running on Vertex Pipelines and returns the resulting
4923+
ModelEvaluationJob resource.
4924+
4925+
Example usage:
4926+
4927+
```
4928+
my_model = Model(
4929+
model_name="projects/123/locations/us-central1/models/456"
4930+
)
4931+
my_evaluation_job = my_model.evaluate(
4932+
prediction_type="classification",
4933+
target_field_name="type",
4934+
data_source_uris=["gs://sdk-model-eval/my-prediction-data.csv"],
4935+
staging_bucket="gs://my-staging-bucket/eval_pipeline_root",
4936+
)
4937+
my_evaluation_job.wait()
4938+
my_evaluation = my_evaluation_job.get_model_evaluation()
4939+
my_evaluation.metrics
4940+
```
4941+
4942+
Args:
4943+
prediction_type (str):
4944+
Required. The problem type being addressed by this evaluation run. 'classification' and 'regression'
4945+
are the currently supported problem types.
4946+
target_field_name (str):
4947+
Required. The column name of the field containing the label for this prediction task.
4948+
gcs_source_uris (List[str]):
4949+
Optional. A list of Cloud Storage data files containing the ground truth data to use for this
4950+
evaluation job. These files should contain your model's prediction column. Currently only Google Cloud Storage
4951+
urls are supported, for example: "gs://path/to/your/data.csv". The provided data files must be
4952+
either CSV or JSONL. One of `gcs_source_uris` or `bigquery_source_uri` is required.
4953+
bigquery_source_uri (str):
4954+
Optional. A bigquery table URI containing the ground truth data to use for this evaluation job. This uri should
4955+
be in the format 'bq://my-project-id.dataset.table'. One of `gcs_source_uris` or `bigquery_source_uri` is
4956+
required.
4957+
bigquery_destination_output_uri (str):
4958+
Optional. A bigquery table URI where the Batch Prediction job associated with your Model Evaluation will write
4959+
prediction output. This can be a BigQuery URI to a project ('bq://my-project'), a dataset
4960+
('bq://my-project.my-dataset'), or a table ('bq://my-project.my-dataset.my-table'). Required if `bigquery_source_uri`
4961+
is provided.
4962+
class_labels (List[str]):
4963+
Optional. For custom (non-AutoML) classification models, a list of possible class names, in the
4964+
same order that predictions are generated. This argument is required when prediction_type is 'classification'.
4965+
For example, in a classification model with 3 possible classes that are outputted in the format: [0.97, 0.02, 0.01]
4966+
with the class names "cat", "dog", and "fish", the value of `class_labels` should be `["cat", "dog", "fish"]` where
4967+
the class "cat" corresponds with 0.97 in the example above.
4968+
prediction_label_column (str):
4969+
Optional. The column name of the field containing classes the model is scoring. Formatted to be able to find nested
4970+
columns, delimeted by `.`. If not set, defaulted to `prediction.classes` for classification.
4971+
prediction_score_column (str):
4972+
Optional. The column name of the field containing batch prediction scores. Formatted to be able to find nested columns,
4973+
delimeted by `.`. If not set, defaulted to `prediction.scores` for a `classification` problem_type, `prediction.value`
4974+
for a `regression` problem_type.
4975+
staging_bucket (str):
4976+
Optional. The GCS directory to use for staging files from this evaluation job. Defaults to the value set in
4977+
aiplatform.init(staging_bucket=...) if not provided. Required if staging_bucket is not set in aiplatform.init().
4978+
service_account (str):
4979+
Specifies the service account for workload run-as account for this Model Evaluation PipelineJob.
4980+
Users submitting jobs must have act-as permission on this run-as account. The service account running
4981+
this Model Evaluation job needs the following permissions: Dataflow Worker, Storage Admin,
4982+
Vertex AI Administrator, and Vertex AI Service Agent.
4983+
generate_feature_attributions (boolean):
4984+
Optional. Whether the model evaluation job should generate feature attributions. Defaults to False if not specified.
4985+
evaluation_pipeline_display_name (str):
4986+
Optional. The display name of your model evaluation job. This is the display name that will be applied to the
4987+
Vertex Pipeline run for your evaluation job. If not set, a display name will be generated automatically.
4988+
evaluation_metrics_display_name (str):
4989+
Optional. The display name of the model evaluation resource uploaded to Vertex from your Model Evaluation pipeline.
4990+
network (str):
4991+
The full name of the Compute Engine network to which the job
4992+
should be peered. For example, projects/12345/global/networks/myVPC.
4993+
Private services access must already be configured for the network.
4994+
If left unspecified, the job is not peered with any network.
4995+
encryption_spec_key_name (str):
4996+
Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the
4997+
form: ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be in the same
4998+
region as where the compute resource is created. If this is set, then all
4999+
resources created by the PipelineJob for this Model Evaluation will be encrypted with the provided encryption key.
5000+
If not specified, encryption_spec of original PipelineJob will be used.
5001+
experiment (Union[str, experiments_resource.Experiment]):
5002+
Optional. The Vertex AI experiment name or instance to associate to the PipelineJob executing
5003+
this model evaluation job. Metrics produced by the PipelineJob as system.Metric Artifacts
5004+
will be associated as metrics to the provided experiment, and parameters from this PipelineJob
5005+
will be associated as parameters to the provided experiment.
5006+
Returns:
5007+
model_evaluation.ModelEvaluationJob: Instantiated representation of the
5008+
_ModelEvaluationJob.
5009+
Raises:
5010+
ValueError:
5011+
If staging_bucket was not set in aiplatform.init() and staging_bucket was not provided.
5012+
If the provided `prediction_type` is not valid.
5013+
If the provided `data_source_uris` don't start with 'gs://'.
5014+
"""
5015+
5016+
if (gcs_source_uris is None) == (bigquery_source_uri is None):
5017+
raise ValueError(
5018+
"Exactly one of `gcs_source_uris` or `bigquery_source_uri` must be provided."
5019+
)
5020+
5021+
if isinstance(gcs_source_uris, str):
5022+
gcs_source_uris = [gcs_source_uris]
5023+
5024+
if bigquery_source_uri and not isinstance(bigquery_source_uri, str):
5025+
raise ValueError("The provided `bigquery_source_uri` must be a string.")
5026+
5027+
if bigquery_source_uri and not bigquery_destination_output_uri:
5028+
raise ValueError(
5029+
"`bigquery_destination_output_uri` must be provided if `bigquery_source_uri` is used as the data source."
5030+
)
5031+
5032+
if gcs_source_uris is not None and not all(
5033+
uri.startswith("gs://") for uri in gcs_source_uris
5034+
):
5035+
raise ValueError("`gcs_source_uris` must start with 'gs://'.")
5036+
5037+
if bigquery_source_uri is not None and not bigquery_source_uri.startswith(
5038+
"bq://"
5039+
):
5040+
raise ValueError(
5041+
"`bigquery_source_uri` and `bigquery_destination_output_uri` must start with 'bq://'"
5042+
)
5043+
5044+
if (
5045+
bigquery_destination_output_uri is not None
5046+
and not bigquery_destination_output_uri.startswith("bq://")
5047+
):
5048+
raise ValueError(
5049+
"`bigquery_source_uri` and `bigquery_destination_output_uri` must start with 'bq://'"
5050+
)
5051+
5052+
SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS = [".jsonl", ".csv"]
5053+
5054+
if not staging_bucket and initializer.global_config.staging_bucket:
5055+
staging_bucket = initializer.global_config.staging_bucket
5056+
elif not staging_bucket and not initializer.global_config.staging_bucket:
5057+
raise ValueError(
5058+
"Please provide `evaluation_staging_bucket` when calling evaluate or set one using aiplatform.init(staging_bucket=...)"
5059+
)
5060+
5061+
if prediction_type not in _SUPPORTED_EVAL_PREDICTION_TYPES:
5062+
raise ValueError(
5063+
f"Please provide a supported model prediction type, one of: {_SUPPORTED_EVAL_PREDICTION_TYPES}."
5064+
)
5065+
5066+
if generate_feature_attributions:
5067+
if not self._gca_resource.explanation_spec:
5068+
raise ValueError(
5069+
"To generate feature attributions with your evaluation, call evaluate on a model with an explanation spec. To run evaluation on the current model, call evaluate with `generate_feature_attributions=False`."
5070+
)
5071+
5072+
instances_format = None
5073+
5074+
if gcs_source_uris:
5075+
5076+
data_file_path_obj = pathlib.Path(gcs_source_uris[0])
5077+
5078+
data_file_extension = data_file_path_obj.suffix
5079+
if data_file_extension not in SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS:
5080+
_LOGGER.warning(
5081+
f"Only the following data file extensions are currently supported: '{SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS}'"
5082+
)
5083+
else:
5084+
instances_format = data_file_extension[1:]
5085+
5086+
elif bigquery_source_uri:
5087+
instances_format = "bigquery"
5088+
5089+
if (
5090+
self._gca_resource.metadata_schema_uri
5091+
== "https://storage.googleapis.com/google-cloud-aiplatform/schema/model/metadata/automl_tabular_1.0.0.yaml"
5092+
):
5093+
model_type = "automl_tabular"
5094+
else:
5095+
model_type = "other"
5096+
5097+
if (
5098+
model_type == "other"
5099+
and prediction_type == "classification"
5100+
and not class_labels
5101+
):
5102+
raise ValueError(
5103+
"Please provide `class_labels` when running evaluation on a custom classification model."
5104+
)
5105+
5106+
return model_evaluation._ModelEvaluationJob.submit(
5107+
model_name=self.versioned_resource_name,
5108+
prediction_type=prediction_type,
5109+
target_field_name=target_field_name,
5110+
gcs_source_uris=gcs_source_uris,
5111+
bigquery_source_uri=bigquery_source_uri,
5112+
batch_predict_bigquery_destination_output_uri=bigquery_destination_output_uri,
5113+
class_labels=class_labels,
5114+
prediction_label_column=prediction_label_column,
5115+
prediction_score_column=prediction_score_column,
5116+
service_account=service_account,
5117+
pipeline_root=staging_bucket,
5118+
instances_format=instances_format,
5119+
model_type=model_type,
5120+
generate_feature_attributions=generate_feature_attributions,
5121+
evaluation_pipeline_display_name=evaluation_pipeline_display_name,
5122+
evaluation_metrics_display_name=evaluation_metrics_display_name,
5123+
network=network,
5124+
encryption_spec_key_name=encryption_spec_key_name,
5125+
credentials=self.credentials,
5126+
experiment=experiment,
5127+
)
5128+
48985129

48995130
# TODO (b/232546878): Async support
49005131
class ModelRegistry:

0 commit comments

Comments
 (0)