Skip to content

Commit 186872d

Browse files
fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update (#1671)
* fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update() function * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed PR feedback * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed PR comments * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed more PR feedback * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed more PR comments * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * removed unused code * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * addressed more PR feedback * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fixing linter issues * addressed more PR comments * fixing pylint errors * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * silencing unused import warning * fixed unused import error Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 876fb2a commit 186872d

File tree

3 files changed

+197
-37
lines changed

3 files changed

+197
-37
lines changed

google/cloud/aiplatform/jobs.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -2484,32 +2484,31 @@ def update(
24842484
update_mask.append("model_deployment_monitoring_objective_configs")
24852485
current_job.model_deployment_monitoring_objective_configs = (
24862486
ModelDeploymentMonitoringJob._parse_configs(
2487-
objective_configs,
2488-
current_job.endpoint,
2489-
deployed_model_ids,
2487+
objective_configs=objective_configs,
2488+
endpoint=aiplatform.Endpoint(
2489+
current_job.endpoint, credentials=self.credentials
2490+
),
2491+
deployed_model_ids=deployed_model_ids,
24902492
)
24912493
)
2492-
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
2493-
self.api_client.update_model_deployment_monitoring_job(
2494-
model_deployment_monitoring_job=current_job,
2495-
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
2496-
)
2494+
self.api_client.update_model_deployment_monitoring_job(
2495+
model_deployment_monitoring_job=current_job,
2496+
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
2497+
)
24972498
return self
24982499

24992500
def pause(self) -> "ModelDeploymentMonitoringJob":
25002501
"""Pause a running MDM job."""
2501-
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
2502-
self.api_client.pause_model_deployment_monitoring_job(
2503-
name=self._gca_resource.name
2504-
)
2502+
self.api_client.pause_model_deployment_monitoring_job(
2503+
name=self._gca_resource.name
2504+
)
25052505
return self
25062506

25072507
def resume(self) -> "ModelDeploymentMonitoringJob":
25082508
"""Resumes a paused MDM job."""
2509-
if self.state == gca_job_state.JobState.JOB_STATE_PAUSED:
2510-
self.api_client.resume_model_deployment_monitoring_job(
2511-
name=self._gca_resource.name
2512-
)
2509+
self.api_client.resume_model_deployment_monitoring_job(
2510+
name=self._gca_resource.name
2511+
)
25132512
return self
25142513

25152514
def delete(self) -> None:

tests/system/aiplatform/test_model_monitoring.py

+83-21
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,28 @@
3131

3232
# constants used for testing
3333
USER_EMAIL = ""
34-
PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320"
34+
PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280"
3535
CHURN_MODEL_PATH = "gs://mco-mm/churn"
36+
DEFAULT_INPUT = {
37+
"cnt_ad_reward": 0,
38+
"cnt_challenge_a_friend": 0,
39+
"cnt_completed_5_levels": 1,
40+
"cnt_level_complete_quickplay": 3,
41+
"cnt_level_end_quickplay": 5,
42+
"cnt_level_reset_quickplay": 2,
43+
"cnt_level_start_quickplay": 6,
44+
"cnt_post_score": 34,
45+
"cnt_spend_virtual_currency": 0,
46+
"cnt_use_extra_steps": 0,
47+
"cnt_user_engagement": 120,
48+
"country": "Denmark",
49+
"dayofweek": 3,
50+
"julianday": 254,
51+
"language": "da-dk",
52+
"month": 9,
53+
"operating_system": "IOS",
54+
"user_pseudo_id": "104B0770BAE16E8B53DF330C95881893",
55+
}
3656

3757
JOB_NAME = "churn"
3858

@@ -117,10 +137,7 @@ def test_mdm_two_models_one_valid_config(self):
117137
project=e2e_base._PROJECT,
118138
location=e2e_base._LOCATION,
119139
endpoint=self.endpoint,
120-
predict_instance_schema_uri="",
121-
analysis_instance_schema_uri="",
122140
)
123-
assert job is not None
124141

125142
gapic_job = job._gca_resource
126143
assert (
@@ -156,22 +173,77 @@ def test_mdm_two_models_one_valid_config(self):
156173
gca_obj_config.prediction_drift_detection_config == drift_config.as_proto()
157174
)
158175

176+
# delete this job and re-configure it to only enable drift detection for faster testing
177+
job.delete()
159178
job_resource = job._gca_resource.name
160179

161-
# test job update and delete()
162-
timeout = time.time() + 3600
163-
new_obj_config = model_monitoring.ObjectiveConfig(skew_config)
180+
# test job delete
181+
with pytest.raises(core_exceptions.NotFound):
182+
job.api_client.get_model_deployment_monitoring_job(name=job_resource)
183+
184+
def test_mdm_pause_and_update_config(self):
185+
"""Test objective config updates for existing MDM job"""
186+
job = aiplatform.ModelDeploymentMonitoringJob.create(
187+
display_name=self._make_display_name(key=JOB_NAME),
188+
logging_sampling_strategy=sampling_strategy,
189+
schedule_config=schedule_config,
190+
alert_config=alert_config,
191+
objective_configs=model_monitoring.ObjectiveConfig(
192+
drift_detection_config=drift_config
193+
),
194+
create_request_timeout=3600,
195+
project=e2e_base._PROJECT,
196+
location=e2e_base._LOCATION,
197+
endpoint=self.endpoint,
198+
)
199+
# test unsuccessful job update when it's pending
200+
DRIFT_THRESHOLDS["cnt_user_engagement"] += 0.01
201+
new_obj_config = model_monitoring.ObjectiveConfig(
202+
drift_detection_config=model_monitoring.DriftDetectionConfig(
203+
drift_thresholds=DRIFT_THRESHOLDS,
204+
attribute_drift_thresholds=ATTRIB_DRIFT_THRESHOLDS,
205+
)
206+
)
207+
if job.state == gca_job_state.JobState.JOB_STATE_PENDING:
208+
with pytest.raises(core_exceptions.FailedPrecondition):
209+
job.update(objective_configs=new_obj_config)
210+
211+
# generate traffic to force MDM job to come online
212+
for i in range(2000):
213+
DEFAULT_INPUT["cnt_user_engagement"] += i
214+
self.endpoint.predict([DEFAULT_INPUT], use_raw_predict=True)
164215

165-
while time.time() < timeout:
216+
# test job update
217+
while True:
218+
time.sleep(1)
166219
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
167220
job.update(objective_configs=new_obj_config)
168-
assert str(job._gca_resource.prediction_drift_detection_config) == ""
169221
break
170-
time.sleep(5)
171222

223+
# verify job update
224+
while True:
225+
time.sleep(1)
226+
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
227+
gca_obj_config = (
228+
job._gca_resource.model_deployment_monitoring_objective_configs[
229+
0
230+
].objective_config
231+
)
232+
assert (
233+
gca_obj_config.prediction_drift_detection_config
234+
== new_obj_config.drift_detection_config.as_proto()
235+
)
236+
break
237+
238+
# test pause
239+
job.pause()
240+
while job.state != gca_job_state.JobState.JOB_STATE_PAUSED:
241+
time.sleep(1)
172242
job.delete()
243+
244+
# confirm deletion
173245
with pytest.raises(core_exceptions.NotFound):
174-
job.api_client.get_model_deployment_monitoring_job(name=job_resource)
246+
job.state
175247

176248
def test_mdm_two_models_two_valid_configs(self):
177249
[deployed_model1, deployed_model2] = list(
@@ -181,7 +253,6 @@ def test_mdm_two_models_two_valid_configs(self):
181253
deployed_model1: objective_config,
182254
deployed_model2: objective_config2,
183255
}
184-
job = None
185256
job = aiplatform.ModelDeploymentMonitoringJob.create(
186257
display_name=self._make_display_name(key=JOB_NAME),
187258
logging_sampling_strategy=sampling_strategy,
@@ -192,10 +263,7 @@ def test_mdm_two_models_two_valid_configs(self):
192263
project=e2e_base._PROJECT,
193264
location=e2e_base._LOCATION,
194265
endpoint=self.endpoint,
195-
predict_instance_schema_uri="",
196-
analysis_instance_schema_uri="",
197266
)
198-
assert job is not None
199267

200268
gapic_job = job._gca_resource
201269
assert (
@@ -246,8 +314,6 @@ def test_mdm_invalid_config_incorrect_model_id(self):
246314
project=e2e_base._PROJECT,
247315
location=e2e_base._LOCATION,
248316
endpoint=self.endpoint,
249-
predict_instance_schema_uri="",
250-
analysis_instance_schema_uri="",
251317
deployed_model_ids=[""],
252318
)
253319
assert "Invalid model ID" in str(e.value)
@@ -265,8 +331,6 @@ def test_mdm_invalid_config_xai(self):
265331
project=e2e_base._PROJECT,
266332
location=e2e_base._LOCATION,
267333
endpoint=self.endpoint,
268-
predict_instance_schema_uri="",
269-
analysis_instance_schema_uri="",
270334
)
271335
assert (
272336
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
@@ -294,8 +358,6 @@ def test_mdm_two_models_invalid_configs_xai(self):
294358
project=e2e_base._PROJECT,
295359
location=e2e_base._LOCATION,
296360
endpoint=self.endpoint,
297-
predict_instance_schema_uri="",
298-
analysis_instance_schema_uri="",
299361
)
300362
assert (
301363
"`explanation_config` should only be enabled if the model has `explanation_spec populated"

tests/unit/aiplatform/test_jobs.py

+99
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,16 @@
3838
job_state as gca_job_state_compat,
3939
machine_resources as gca_machine_resources_compat,
4040
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
41+
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
42+
model_monitoring as gca_model_monitoring_compat,
4143
)
4244

4345
from google.cloud.aiplatform.compat.services import (
4446
job_service_client,
4547
)
48+
from google.protobuf import field_mask_pb2 # type: ignore
49+
50+
from test_endpoints import get_endpoint_with_models_mock # noqa: F401
4651

4752
_TEST_API_CLIENT = job_service_client.JobServiceClient
4853

@@ -84,6 +89,11 @@
8489
f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}"
8590
)
8691

92+
_TEST_MDM_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/modelDeploymentMonitoringJobs/{_TEST_ID}"
93+
_TEST_ENDPOINT = (
94+
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
95+
)
96+
8797
_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)
8898
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
8999
_TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2)
@@ -164,6 +174,8 @@
164174
_TEST_JOB_DELETE_METHOD_NAME = "delete_custom_job"
165175
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
166176

177+
_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}
178+
167179
# TODO(b/171333554): Move reusable test fixtures to conftest.py file
168180

169181

@@ -969,3 +981,90 @@ def test_batch_predict_job_with_versioned_model(
969981
].model
970982
== _TEST_VERSIONED_MODEL_NAME
971983
)
984+
985+
986+
@pytest.fixture
987+
def get_mdm_job_mock():
988+
with mock.patch.object(
989+
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
990+
) as get_mdm_job_mock:
991+
get_mdm_job_mock.return_value = (
992+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
993+
name=_TEST_MDM_JOB_NAME,
994+
display_name=_TEST_DISPLAY_NAME,
995+
state=_TEST_JOB_STATE_RUNNING,
996+
endpoint=_TEST_ENDPOINT,
997+
)
998+
)
999+
yield get_mdm_job_mock
1000+
1001+
1002+
@pytest.fixture
1003+
@pytest.mark.usefixtures("get_mdm_job_mock")
1004+
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
1005+
with mock.patch.object(
1006+
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
1007+
) as update_mdm_job_mock:
1008+
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
1009+
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
1010+
drift_thresholds={
1011+
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
1012+
}
1013+
)
1014+
)
1015+
all_configs = []
1016+
for model in get_endpoint_with_models_mock.return_value.deployed_models:
1017+
all_configs.append(
1018+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
1019+
deployed_model_id=model.id,
1020+
objective_config=expected_objective_config,
1021+
)
1022+
)
1023+
1024+
update_mdm_job_mock.return_vaue.result_type = (
1025+
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
1026+
name=_TEST_MDM_JOB_NAME,
1027+
display_name=_TEST_DISPLAY_NAME,
1028+
state=_TEST_JOB_STATE_RUNNING,
1029+
endpoint=_TEST_ENDPOINT,
1030+
model_deployment_monitoring_objective_configs=all_configs,
1031+
)
1032+
)
1033+
yield update_mdm_job_mock
1034+
1035+
1036+
@pytest.mark.usefixtures("google_auth_mock")
1037+
class TestModelDeploymentMonitoringJob:
1038+
def setup_method(self):
1039+
reload(initializer)
1040+
reload(aiplatform)
1041+
1042+
def teardown_method(self):
1043+
initializer.global_pool.shutdown(wait=True)
1044+
1045+
def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
1046+
job = jobs.ModelDeploymentMonitoringJob(
1047+
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
1048+
)
1049+
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
1050+
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
1051+
)
1052+
new_config = aiplatform.model_monitoring.ObjectiveConfig(
1053+
drift_detection_config=drift_detection_config
1054+
)
1055+
job.update(objective_configs=new_config)
1056+
assert (
1057+
job._gca_resource.model_deployment_monitoring_objective_configs[
1058+
0
1059+
].objective_config.prediction_drift_detection_config
1060+
== drift_detection_config.as_proto()
1061+
)
1062+
get_mdm_job_mock.assert_called_with(
1063+
name=_TEST_MDM_JOB_NAME,
1064+
)
1065+
update_mdm_job_mock.assert_called_once_with(
1066+
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
1067+
update_mask=field_mask_pb2.FieldMask(
1068+
paths=["model_deployment_monitoring_objective_configs"]
1069+
),
1070+
)

0 commit comments

Comments
 (0)