Skip to content

Commit 9e77c61

Browse files
fix: added proto message conversion to MDMJob.update fields (#1718)
* fix: added proto message conversion to MDMJob.update fields * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed PR comment * formatting * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * replaced string literal with constant * adding _gca_resource re-assignmnet to mdm job class * Added side effects in get_mdm_job pytest mock * fixing side effects * formatting * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * minor edits to variable names * Addressed PR feedback * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed more PR commentes * addressed PR comments * fix linter errors Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 3747ce3 commit 9e77c61

File tree

2 files changed

+126
-49
lines changed

2 files changed

+126
-49
lines changed

google/cloud/aiplatform/jobs.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -2427,7 +2427,8 @@ def update(
24272427
are allowed. See https://goo.gl/xmQnxf for more information
24282428
and examples of labels.
24292429
bigquery_tables_log_ttl (int):
2430-
Optional. The TTL(time to live) of BigQuery tables in user projects
2430+
Optional. The number of days for which the logs are stored.
2431+
The TTL(time to live) of BigQuery tables in user projects
24312432
which stores logs. A day is the basic unit of
24322433
the TTL and we take the ceil of TTL/86400(a
24332434
day). e.g. { second: 3600} indicates ttl = 1
@@ -2453,28 +2454,30 @@ def update(
24532454
will be applied to all deployed models.
24542455
"""
24552456
self._sync_gca_resource()
2456-
current_job = self.api_client.get_model_deployment_monitoring_job(
2457-
name=self._gca_resource.name
2458-
)
2457+
current_job = copy.deepcopy(self._gca_resource)
24592458
update_mask: List[str] = []
24602459
if display_name is not None:
24612460
update_mask.append("display_name")
24622461
current_job.display_name = display_name
24632462
if schedule_config is not None:
24642463
update_mask.append("model_deployment_monitoring_schedule_config")
2465-
current_job.model_deployment_monitoring_schedule_config = schedule_config
2464+
current_job.model_deployment_monitoring_schedule_config = (
2465+
schedule_config.as_proto()
2466+
)
24662467
if alert_config is not None:
24672468
update_mask.append("model_monitoring_alert_config")
2468-
current_job.model_monitoring_alert_config = alert_config
2469+
current_job.model_monitoring_alert_config = alert_config.as_proto()
24692470
if logging_sampling_strategy is not None:
24702471
update_mask.append("logging_sampling_strategy")
2471-
current_job.logging_sampling_strategy = logging_sampling_strategy
2472+
current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto()
24722473
if labels is not None:
24732474
update_mask.append("labels")
2474-
current_job.lables = labels
2475+
current_job.labels = labels
24752476
if bigquery_tables_log_ttl is not None:
24762477
update_mask.append("log_ttl")
2477-
current_job.log_ttl = bigquery_tables_log_ttl
2478+
current_job.log_ttl = duration_pb2.Duration(
2479+
seconds=bigquery_tables_log_ttl * 86400
2480+
)
24782481
if enable_monitoring_pipeline_logs is not None:
24792482
update_mask.append("enable_monitoring_pipeline_logs")
24802483
current_job.enable_monitoring_pipeline_logs = (
@@ -2491,10 +2494,12 @@ def update(
24912494
deployed_model_ids=deployed_model_ids,
24922495
)
24932496
)
2494-
self.api_client.update_model_deployment_monitoring_job(
2497+
# TODO: b/254285776 add optional_sync support to model monitoring job
2498+
lro = self.api_client.update_model_deployment_monitoring_job(
24952499
model_deployment_monitoring_job=current_job,
24962500
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
24972501
)
2502+
self._gca_resource = lro.result()
24982503
return self
24992504

25002505
def pause(self) -> "ModelDeploymentMonitoringJob":

tests/unit/aiplatform/test_jobs.py

+111-39
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import pytest
19+
import copy
1920

2021
from unittest import mock
2122
from importlib import reload
@@ -24,6 +25,7 @@
2425
from google.cloud import storage
2526
from google.cloud import bigquery
2627

28+
from google.api_core import operation
2729
from google.auth import credentials as auth_credentials
2830

2931
from google.cloud import aiplatform
@@ -46,7 +48,9 @@
4648
job_service_client,
4749
)
4850
from google.protobuf import field_mask_pb2 # type: ignore
51+
from google.protobuf import duration_pb2 # type: ignore
4952

53+
import test_endpoints # noqa: F401
5054
from test_endpoints import get_endpoint_with_models_mock # noqa: F401
5155

5256
_TEST_API_CLIENT = job_service_client.JobServiceClient
@@ -175,6 +179,58 @@
175179
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
176180

177181
_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}
182+
_TEST_MDM_USER_EMAIL = "TEST_EMAIL"
183+
_TEST_MDM_SAMPLE_RATE = 0.5
184+
_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"}
185+
_TEST_LOG_TTL_IN_DAYS = 1
186+
_TEST_MDM_NEW_NAME = "NEW_NAME"
187+
188+
_TEST_MDM_OLD_JOB = (
189+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
190+
name=_TEST_MDM_JOB_NAME,
191+
display_name=_TEST_DISPLAY_NAME,
192+
endpoint=_TEST_ENDPOINT,
193+
state=_TEST_JOB_STATE_RUNNING,
194+
)
195+
)
196+
197+
_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
198+
name=_TEST_MDM_JOB_NAME,
199+
display_name=_TEST_MDM_NEW_NAME,
200+
endpoint=_TEST_ENDPOINT,
201+
state=_TEST_JOB_STATE_RUNNING,
202+
model_deployment_monitoring_objective_configs=[
203+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
204+
deployed_model_id=model_id,
205+
objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
206+
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
207+
drift_thresholds={
208+
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(
209+
value=0.01
210+
)
211+
}
212+
)
213+
),
214+
)
215+
for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS]
216+
],
217+
logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy(
218+
random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(
219+
sample_rate=_TEST_MDM_SAMPLE_RATE
220+
)
221+
),
222+
labels=_TEST_MDM_LABEL,
223+
model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig(
224+
email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(
225+
user_emails=[_TEST_MDM_USER_EMAIL]
226+
)
227+
),
228+
model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig(
229+
monitor_interval=duration_pb2.Duration(seconds=3600)
230+
),
231+
log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400),
232+
enable_monitoring_pipeline_logs=True,
233+
)
178234

179235
# TODO(b/171333554): Move reusable test fixtures to conftest.py file
180236

@@ -988,48 +1044,23 @@ def get_mdm_job_mock():
9881044
with mock.patch.object(
9891045
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
9901046
) as get_mdm_job_mock:
991-
get_mdm_job_mock.return_value = (
992-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
993-
name=_TEST_MDM_JOB_NAME,
994-
display_name=_TEST_DISPLAY_NAME,
995-
state=_TEST_JOB_STATE_RUNNING,
996-
endpoint=_TEST_ENDPOINT,
997-
)
998-
)
1047+
get_mdm_job_mock.side_effect = [
1048+
_TEST_MDM_OLD_JOB,
1049+
_TEST_MDM_OLD_JOB,
1050+
_TEST_MDM_OLD_JOB,
1051+
_TEST_MDM_EXPECTED_NEW_JOB,
1052+
]
9991053
yield get_mdm_job_mock
10001054

10011055

10021056
@pytest.fixture
1003-
@pytest.mark.usefixtures("get_mdm_job_mock")
10041057
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
10051058
with mock.patch.object(
10061059
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
10071060
) as update_mdm_job_mock:
1008-
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
1009-
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
1010-
drift_thresholds={
1011-
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
1012-
}
1013-
)
1014-
)
1015-
all_configs = []
1016-
for model in get_endpoint_with_models_mock.return_value.deployed_models:
1017-
all_configs.append(
1018-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
1019-
deployed_model_id=model.id,
1020-
objective_config=expected_objective_config,
1021-
)
1022-
)
1023-
1024-
update_mdm_job_mock.return_vaue.result_type = (
1025-
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
1026-
name=_TEST_MDM_JOB_NAME,
1027-
display_name=_TEST_DISPLAY_NAME,
1028-
state=_TEST_JOB_STATE_RUNNING,
1029-
endpoint=_TEST_ENDPOINT,
1030-
model_deployment_monitoring_objective_configs=all_configs,
1031-
)
1032-
)
1061+
update_mdm_job_lro_mock = mock.Mock(operation.Operation)
1062+
update_mdm_job_lro_mock.result.return_value = _TEST_MDM_EXPECTED_NEW_JOB
1063+
update_mdm_job_mock.return_value = update_mdm_job_lro_mock
10331064
yield update_mdm_job_mock
10341065

10351066

@@ -1046,25 +1077,66 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
10461077
job = jobs.ModelDeploymentMonitoringJob(
10471078
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
10481079
)
1080+
old_job = copy.deepcopy(job._gca_resource)
10491081
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
10501082
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
10511083
)
1084+
schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1)
1085+
alert_config = aiplatform.model_monitoring.EmailAlertConfig(
1086+
user_emails=[_TEST_MDM_USER_EMAIL]
1087+
)
1088+
sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig(
1089+
sample_rate=_TEST_MDM_SAMPLE_RATE
1090+
)
1091+
labels = _TEST_MDM_LABEL
1092+
log_ttl = _TEST_LOG_TTL_IN_DAYS
1093+
display_name = _TEST_MDM_NEW_NAME
10521094
new_config = aiplatform.model_monitoring.ObjectiveConfig(
10531095
drift_detection_config=drift_detection_config
10541096
)
1055-
job.update(objective_configs=new_config)
1097+
job.update(
1098+
display_name=display_name,
1099+
schedule_config=schedule_config,
1100+
alert_config=alert_config,
1101+
logging_sampling_strategy=sampling_strategy,
1102+
labels=labels,
1103+
bigquery_tables_log_ttl=log_ttl,
1104+
enable_monitoring_pipeline_logs=True,
1105+
objective_configs=new_config,
1106+
)
1107+
new_job = job._gca_resource
1108+
assert old_job != new_job
1109+
assert new_job.display_name == display_name
1110+
assert new_job.logging_sampling_strategy == sampling_strategy.as_proto()
1111+
assert (
1112+
new_job.model_deployment_monitoring_schedule_config
1113+
== schedule_config.as_proto()
1114+
)
1115+
assert new_job.labels == labels
1116+
assert new_job.model_monitoring_alert_config == alert_config.as_proto()
1117+
assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS
1118+
assert new_job.enable_monitoring_pipeline_logs
10561119
assert (
1057-
job._gca_resource.model_deployment_monitoring_objective_configs[
1120+
new_job.model_deployment_monitoring_objective_configs[
10581121
0
10591122
].objective_config.prediction_drift_detection_config
10601123
== drift_detection_config.as_proto()
10611124
)
10621125
get_mdm_job_mock.assert_called_with(
1063-
name=_TEST_MDM_JOB_NAME,
1126+
name=_TEST_MDM_JOB_NAME, retry=base._DEFAULT_RETRY
10641127
)
10651128
update_mdm_job_mock.assert_called_once_with(
1066-
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
1129+
model_deployment_monitoring_job=new_job,
10671130
update_mask=field_mask_pb2.FieldMask(
1068-
paths=["model_deployment_monitoring_objective_configs"]
1131+
paths=[
1132+
"display_name",
1133+
"model_deployment_monitoring_schedule_config",
1134+
"model_monitoring_alert_config",
1135+
"logging_sampling_strategy",
1136+
"labels",
1137+
"log_ttl",
1138+
"enable_monitoring_pipeline_logs",
1139+
"model_deployment_monitoring_objective_configs",
1140+
]
10691141
),
10701142
)

0 commit comments

Comments
 (0)