Commit bbec998

feat: support model monitoring for batch prediction in Vertex SDK (#1570)
* feat: support model monitoring for batch prediction in Vertex SDK
* fixed broken tests
* fixing syntax error
* addressed comments
* updated test variable name
1 parent 3d3e0aa commit bbec998

6 files changed: +224 -58 lines


google/cloud/aiplatform/jobs.py (+63 -1)
@@ -385,6 +385,13 @@ def create(
        sync: bool = True,
        create_request_timeout: Optional[float] = None,
        batch_size: Optional[int] = None,
+        model_monitoring_objective_config: Optional[
+            "aiplatform.model_monitoring.ObjectiveConfig"
+        ] = None,
+        model_monitoring_alert_config: Optional[
+            "aiplatform.model_monitoring.AlertConfig"
+        ] = None,
+        analysis_instance_schema_uri: Optional[str] = None,
    ) -> "BatchPredictionJob":
        """Create a batch prediction job.

@@ -551,6 +558,23 @@ def create(
                but too high value will result in a whole batch not fitting in a machine's memory,
                and the whole operation will fail.
                The default value is 64.
+            model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
+                Optional. The objective config for model monitoring. Passing this parameter enables
+                monitoring on the model associated with this batch prediction job.
+            model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
+                Optional. Configures how model monitoring alerts are sent to the user. Right now
+                only email alert is supported.
+            analysis_instance_schema_uri (str):
+                Optional. Only applicable if model_monitoring_objective_config is also passed.
+                This parameter specifies the YAML schema file uri describing the format of a single
+                instance that you want Tensorflow Data Validation (TFDV) to
+                analyze. If this field is empty, all the feature data types are
+                inferred from predict_instance_schema_uri, meaning that TFDV
+                will use the data in the exact format as prediction request/response.
+                If there are any data type differences between predict instance
+                and TFDV instance, this field can be used to override the schema.
+                For models trained with Vertex AI, this field must be set as all the
+                fields in predict instance formatted as string.
        Returns:
            (jobs.BatchPredictionJob):
                Instantiated representation of the created batch prediction job.
@@ -601,7 +625,18 @@ def create(
                f"{predictions_format} is not an accepted prediction format "
                f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
            )
-
+        # TODO: remove temporary import statements once model monitoring for batch prediction is GA
+        if model_monitoring_objective_config:
+            from google.cloud.aiplatform.compat.types import (
+                io_v1beta1 as gca_io_compat,
+                batch_prediction_job_v1beta1 as gca_bp_job_compat,
+                model_monitoring_v1beta1 as gca_model_monitoring_compat,
+            )
+        else:
+            from google.cloud.aiplatform.compat.types import (
+                io as gca_io_compat,
+                batch_prediction_job as gca_bp_job_compat,
+            )
        gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

        # Required Fields
@@ -688,6 +723,28 @@ def create(
                )
            )

+        # Model Monitoring
+        if model_monitoring_objective_config:
+            if model_monitoring_objective_config.drift_detection_config:
+                _LOGGER.info(
+                    "Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
+                )
+            if model_monitoring_objective_config.explanation_config:
+                _LOGGER.info(
+                    "XAI config is currently not supported for monitoring models associated with batch prediction jobs."
+                )
+            gapic_batch_prediction_job.model_monitoring_config = (
+                gca_model_monitoring_compat.ModelMonitoringConfig(
+                    objective_configs=[
+                        model_monitoring_objective_config.as_proto(config_for_bp=True)
+                    ],
+                    alert_config=model_monitoring_alert_config.as_proto(
+                        config_for_bp=True
+                    ),
+                    analysis_instance_schema_uri=analysis_instance_schema_uri,
+                )
+            )
+
        empty_batch_prediction_job = cls._empty_constructor(
            project=project,
            location=location,
@@ -702,6 +759,11 @@ def create(
            sync=sync,
            create_request_timeout=create_request_timeout,
        )
+        # TODO: b/242108750
+        from google.cloud.aiplatform.compat.types import (
+            io as gca_io_compat,
+            batch_prediction_job as gca_bp_job_compat,
+        )

    @classmethod
    @base.optional_sync(return_input_arg="empty_batch_prediction_job")
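
The jobs.py change wires three new optional parameters into BatchPredictionJob.create and, when an objective config is passed, swaps the compat imports to the v1beta1 types because batch prediction monitoring is still pre-GA. A minimal usage sketch of the new surface follows; the project, model resource name, bucket paths, table, and threshold value are hypothetical placeholders, not part of this commit:

# Hedged sketch of the new create() parameters; all resource names, URIs, and
# thresholds below are placeholders.
from google.cloud import aiplatform
from google.cloud.aiplatform import model_monitoring

aiplatform.init(project="my-project", location="us-central1")

# Skew detection against the (hypothetical) BigQuery table the model was trained on.
skew_config = model_monitoring.SkewDetectionConfig(
    data_source="bq://my-project.my_dataset.churn_training_table",
    target_field="churned",
    skew_thresholds={"cnt_user_engagement": 0.5},
)
objective_config = model_monitoring.ObjectiveConfig(skew_detection_config=skew_config)
alert_config = model_monitoring.EmailAlertConfig(user_emails=["user@example.com"])

batch_prediction_job = aiplatform.BatchPredictionJob.create(
    job_display_name="churn-batch-prediction-with-monitoring",
    model_name="projects/my-project/locations/us-central1/models/1234567890",
    gcs_source="gs://my-bucket/batch_inputs/instances.jsonl",
    gcs_destination_prefix="gs://my-bucket/batch_outputs",
    machine_type="n1-standard-2",
    # New in this commit:
    model_monitoring_objective_config=objective_config,
    model_monitoring_alert_config=alert_config,
)

Note that drift detection and explanation (XAI) configs are only logged and ignored on this path, so the skew detection portion of the objective config is what takes effect for a batch prediction job.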

google/cloud/aiplatform/model_monitoring/alert.py (+21 -3)
@@ -17,9 +17,16 @@

from typing import Optional, List
from google.cloud.aiplatform_v1.types import (
-    model_monitoring as gca_model_monitoring,
+    model_monitoring as gca_model_monitoring_v1,
)

+# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
+from google.cloud.aiplatform_v1beta1.types import (
+    model_monitoring as gca_model_monitoring_v1beta1,
+)
+
+gca_model_monitoring = gca_model_monitoring_v1
+

class EmailAlertConfig:
    def __init__(
@@ -40,8 +47,19 @@ def __init__(
        self.enable_logging = enable_logging
        self.user_emails = user_emails

-    def as_proto(self):
-        """Returns EmailAlertConfig as a proto message."""
+    # TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
+    def as_proto(self, config_for_bp: bool = False):
+        """Returns EmailAlertConfig as a proto message.
+
+        Args:
+            config_for_bp (bool):
+                Optional. Set this parameter to True if the config object
+                is used for model monitoring on a batch prediction job.
+        """
+        if config_for_bp:
+            gca_model_monitoring = gca_model_monitoring_v1beta1
+        else:
+            gca_model_monitoring = gca_model_monitoring_v1
        user_email_alert_config = (
            gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
                user_emails=self.user_emails
google/cloud/aiplatform/model_monitoring/objective.py (+63 -33)
@@ -18,11 +18,19 @@
from typing import Optional, Dict

from google.cloud.aiplatform_v1.types import (
-    io as gca_io,
-    ThresholdConfig as gca_threshold_config,
-    model_monitoring as gca_model_monitoring,
+    io as gca_io_v1,
+    model_monitoring as gca_model_monitoring_v1,
)

+# TODO: b/242108750
+from google.cloud.aiplatform_v1beta1.types import (
+    io as gca_io_v1beta1,
+    model_monitoring as gca_model_monitoring_v1beta1,
+)
+
+gca_model_monitoring = gca_model_monitoring_v1
+gca_io = gca_io_v1
+
TF_RECORD = "tf-record"
CSV = "csv"
JSONL = "jsonl"
@@ -80,19 +88,20 @@ def __init__(
        self.attribute_skew_thresholds = attribute_skew_thresholds
        self.data_format = data_format
        self.target_field = target_field
-        self.training_dataset = None

    def as_proto(self):
        """Returns _SkewDetectionConfig as a proto message."""
        skew_thresholds_mapping = {}
        attribution_score_skew_thresholds_mapping = {}
        if self.skew_thresholds is not None:
            for key in self.skew_thresholds.keys():
-                skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
+                skew_threshold = gca_model_monitoring.ThresholdConfig(
+                    value=self.skew_thresholds[key]
+                )
                skew_thresholds_mapping[key] = skew_threshold
        if self.attribute_skew_thresholds is not None:
            for key in self.attribute_skew_thresholds.keys():
-                attribution_score_skew_threshold = gca_threshold_config(
+                attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
                    value=self.attribute_skew_thresholds[key]
                )
                attribution_score_skew_thresholds_mapping[
@@ -134,12 +143,16 @@ def as_proto(self):
        attribution_score_drift_thresholds_mapping = {}
        if self.drift_thresholds is not None:
            for key in self.drift_thresholds.keys():
-                drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
+                drift_threshold = gca_model_monitoring.ThresholdConfig(
+                    value=self.drift_thresholds[key]
+                )
                drift_thresholds_mapping[key] = drift_threshold
        if self.attribute_drift_thresholds is not None:
            for key in self.attribute_drift_thresholds.keys():
-                attribution_score_drift_threshold = gca_threshold_config(
-                    value=self.attribute_drift_thresholds[key]
+                attribution_score_drift_threshold = (
+                    gca_model_monitoring.ThresholdConfig(
+                        value=self.attribute_drift_thresholds[key]
+                    )
                )
                attribution_score_drift_thresholds_mapping[
                    key
@@ -186,11 +199,49 @@ def __init__(
        self.drift_detection_config = drift_detection_config
        self.explanation_config = explanation_config

-    def as_proto(self):
-        """Returns _ObjectiveConfig as a proto message."""
+    # TODO: b/242108750
+    def as_proto(self, config_for_bp: bool = False):
+        """Returns _SkewDetectionConfig as a proto message.
+
+        Args:
+            config_for_bp (bool):
+                Optional. Set this parameter to True if the config object
+                is used for model monitoring on a batch prediction job.
+        """
+        if config_for_bp:
+            gca_io = gca_io_v1beta1
+            gca_model_monitoring = gca_model_monitoring_v1beta1
+        else:
+            gca_io = gca_io_v1
+            gca_model_monitoring = gca_model_monitoring_v1
        training_dataset = None
        if self.skew_detection_config is not None:
-            training_dataset = self.skew_detection_config.training_dataset
+            training_dataset = (
+                gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
+                    target_field=self.skew_detection_config.target_field
+                )
+            )
+            if self.skew_detection_config.data_source.startswith("bq:/"):
+                training_dataset.bigquery_source = gca_io.BigQuerySource(
+                    input_uri=self.skew_detection_config.data_source
+                )
+            elif self.skew_detection_config.data_source.startswith("gs:/"):
+                training_dataset.gcs_source = gca_io.GcsSource(
+                    uris=[self.skew_detection_config.data_source]
+                )
+                if (
+                    self.skew_detection_config.data_format is not None
+                    and self.skew_detection_config.data_format
+                    not in [TF_RECORD, CSV, JSONL]
+                ):
+                    raise ValueError(
+                        "Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
+                        % (TF_RECORD, CSV, JSONL)
+                    )
+                training_dataset.data_format = self.skew_detection_config.data_format
+            else:
+                training_dataset.dataset = self.skew_detection_config.data_source
+
        return gca_model_monitoring.ModelMonitoringObjectiveConfig(
            training_dataset=training_dataset,
            training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
@@ -271,27 +322,6 @@ def __init__(
            data_format,
        )

-        training_dataset = (
-            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
-                target_field=target_field
-            )
-        )
-        if data_source.startswith("bq:/"):
-            training_dataset.bigquery_source = gca_io.BigQuerySource(
-                input_uri=data_source
-            )
-        elif data_source.startswith("gs:/"):
-            training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
-            if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
-                raise ValueError(
-                    "Unsupported value. `data_format` must be one of %s, %s, or %s"
-                    % (TF_RECORD, CSV, JSONL)
-                )
-            training_dataset.data_format = data_format
-        else:
-            training_dataset.dataset = data_source
-        self.training_dataset = training_dataset
-

class DriftDetectionConfig(_DriftDetectionConfig):
    """A class that configures prediction drift detection for models deployed to an endpoint.

tests/system/aiplatform/test_model_monitoring.py (+27 -9)
@@ -24,10 +24,15 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

+from google.cloud.aiplatform_v1.types import (
+    io as gca_io,
+    model_monitoring as gca_model_monitoring,
+)
+
# constants used for testing
USER_EMAIL = ""
-MODEL_NAME = "churn"
-MODEL_NAME2 = "churn2"
+MODEL_DISPLAYNAME_KEY = "churn"
+MODEL_DISPLAYNAME_KEY2 = "churn2"
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
ENDPOINT = "us-central1-aiplatform.googleapis.com"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
@@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state):
        )

        model = aiplatform.Model.upload(
-            display_name=self._make_display_name(key=MODEL_NAME),
+            display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
            artifact_uri=CHURN_MODEL_PATH,
            serving_container_image_uri=IMAGE,
        )
@@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state):
        )

        model1 = aiplatform.Model.upload(
-            display_name=self._make_display_name(key=MODEL_NAME),
+            display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
            artifact_uri=CHURN_MODEL_PATH,
            serving_container_image_uri=IMAGE,
        )

        model2 = aiplatform.Model.upload(
-            display_name=self._make_display_name(key=MODEL_NAME),
+            display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
            artifact_uri=CHURN_MODEL_PATH,
            serving_container_image_uri=IMAGE,
        )
        shared_state["resources"] = [model1, model2]
        endpoint = aiplatform.Endpoint.create(
-            display_name=self._make_display_name(key=MODEL_NAME)
+            display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
        )
        endpoint.deploy(
            model=model1, machine_type="n1-standard-2", traffic_percentage=100
@@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
        gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
            0
        ].objective_config
-        assert gca_obj_config.training_dataset == skew_config.training_dataset
+
+        expected_training_dataset = (
+            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
+                bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
+                target_field=TARGET,
+            )
+        )
+        assert gca_obj_config.training_dataset == expected_training_dataset
        assert (
            gca_obj_config.training_prediction_skew_detection_config
            == skew_config.as_proto()
@@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
        )
        assert gapic_job.model_monitoring_alert_config.enable_logging

+        expected_training_dataset = (
+            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
+                bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
+                target_field=TARGET,
+            )
+        )
+
        for config in gapic_job.model_deployment_monitoring_objective_configs:
            gca_obj_config = config.objective_config
            deployed_model_id = config.deployed_model_id
            assert (
-                gca_obj_config.training_dataset
-                == all_configs[deployed_model_id].skew_detection_config.training_dataset
+                gca_obj_config.as_proto().training_dataset == expected_training_dataset
            )
            assert (
                gca_obj_config.training_prediction_skew_detection_config
