Skip to content

Commit 22151e2

Browse files
vertex-sdk-bot authored and copybara-github committed
fix: Fix the sync option for Model Monitor job creation
PiperOrigin-RevId: 653408498
1 parent 217faf8 commit 22151e2

File tree

2 files changed

+296
-116
lines changed

2 files changed

+296
-116
lines changed

tests/unit/vertexai/test_model_monitors.py

+82-22
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
model_monitoring_stats_v1beta1 as gca_model_monitoring_stats,
3737
schedule_service_v1beta1 as gca_schedule_service,
3838
schedule_v1beta1 as gca_schedule,
39+
job_state_v1beta1 as gca_job_state,
3940
explanation_v1beta1 as explanation,
4041
)
4142
from vertexai.resources.preview import (
@@ -51,7 +52,10 @@
5152

5253
# -*- coding: utf-8 -*-
5354

54-
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
55+
_TEST_CREDENTIALS = mock.Mock(
56+
spec=auth_credentials.AnonymousCredentials(),
57+
universe_domain="googleapis.com",
58+
)
5559
_TEST_DESCRIPTION = "test description"
5660
_TEST_JSON_CONTENT_TYPE = "application/json"
5761
_TEST_LOCATION = "us-central1"
@@ -178,6 +182,9 @@
178182
user_emails=[_TEST_NOTIFICATION_EMAIL]
179183
),
180184
),
185+
explanation_spec=explanation.ExplanationSpec(
186+
parameters=explanation.ExplanationParameters(top_k=10)
187+
),
181188
)
182189
_TEST_UPDATED_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor(
183190
name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -222,6 +229,9 @@
222229
user_emails=[_TEST_NOTIFICATION_EMAIL, "[email protected]"]
223230
),
224231
),
232+
explanation_spec=explanation.ExplanationSpec(
233+
parameters=explanation.ExplanationParameters(top_k=10)
234+
),
225235
)
226236
_TEST_CREATE_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob(
227237
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
@@ -249,7 +259,9 @@
249259
vertex_dataset=_TEST_TARGET_RESOURCE
250260
)
251261
),
252-
explanation_spec=explanation.ExplanationSpec(),
262+
explanation_spec=explanation.ExplanationSpec(
263+
parameters=explanation.ExplanationParameters(top_k=10)
264+
),
253265
),
254266
output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
255267
gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -288,7 +300,9 @@
288300
vertex_dataset=_TEST_TARGET_RESOURCE
289301
)
290302
),
291-
explanation_spec=explanation.ExplanationSpec(),
303+
explanation_spec=explanation.ExplanationSpec(
304+
parameters=explanation.ExplanationParameters(top_k=10)
305+
),
292306
),
293307
output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
294308
gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -299,6 +313,7 @@
299313
)
300314
),
301315
),
316+
state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
302317
)
303318
_TEST_CRON = r"America/New_York 1 \* \* \* \*"
304319
_TEST_SCHEDULE_OBJ = gca_schedule.Schedule(
@@ -336,7 +351,9 @@
336351
vertex_dataset=_TEST_TARGET_RESOURCE
337352
)
338353
),
339-
explanation_spec=explanation.ExplanationSpec(),
354+
explanation_spec=explanation.ExplanationSpec(
355+
parameters=explanation.ExplanationParameters(top_k=10)
356+
),
340357
),
341358
output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
342359
gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -564,7 +581,12 @@ def get_model_monitoring_job_mock():
564581
model_monitoring_service_client.ModelMonitoringServiceClient,
565582
"get_model_monitoring_job",
566583
) as get_model_monitoring_job_mock:
567-
get_model_monitor_mock.return_value = _TEST_MODEL_MONITORING_JOB_OBJ
584+
model_monitoring_job_mock = mock.Mock(
585+
spec=gca_model_monitoring_job.ModelMonitoringJob
586+
)
587+
model_monitoring_job_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
588+
model_monitoring_job_mock.name = _TEST_MODEL_MONITORING_JOB_RESOURCE_NAME
589+
get_model_monitoring_job_mock.return_value = model_monitoring_job_mock
568590
yield get_model_monitoring_job_mock
569591

570592

@@ -762,6 +784,9 @@ def test_create_schedule(self, create_schedule_mock):
762784
notification_spec=ml_monitoring.spec.NotificationSpec(
763785
user_emails=[_TEST_NOTIFICATION_EMAIL]
764786
),
787+
explanation_spec=explanation.ExplanationSpec(
788+
parameters=explanation.ExplanationParameters(top_k=10)
789+
),
765790
)
766791
test_model_monitor.create_schedule(
767792
display_name=_TEST_SCHEDULE_NAME,
@@ -851,9 +876,12 @@ def test_update_schedule(self, update_schedule_mock, get_schedule_mock):
851876
assert get_schedule_mock.call_count == 1
852877

853878
@pytest.mark.usefixtures(
854-
"create_model_monitoring_job_mock", "create_model_monitor_mock"
879+
"create_model_monitoring_job_mock",
880+
"create_model_monitor_mock",
881+
"get_model_monitoring_job_mock",
855882
)
856-
def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
883+
@pytest.mark.parametrize("sync", [True, False])
884+
def test_run_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
857885
aiplatform.init(
858886
project=_TEST_PROJECT,
859887
location=_TEST_LOCATION,
@@ -866,6 +894,15 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
866894
model_name=_TEST_MODEL_NAME,
867895
display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
868896
model_version_id=_TEST_MODEL_VERSION_ID,
897+
)
898+
test_model_monitoring_job = test_model_monitor.run(
899+
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
900+
baseline_dataset=ml_monitoring.spec.MonitoringInput(
901+
vertex_dataset=_TEST_BASELINE_RESOURCE
902+
),
903+
target_dataset=ml_monitoring.spec.MonitoringInput(
904+
vertex_dataset=_TEST_TARGET_RESOURCE
905+
),
869906
tabular_objective_spec=ml_monitoring.spec.TabularObjective(
870907
feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
871908
default_categorical_alert_threshold=0.1,
@@ -876,13 +913,15 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
876913
notification_spec=ml_monitoring.spec.NotificationSpec(
877914
user_emails=[_TEST_NOTIFICATION_EMAIL]
878915
),
879-
)
880-
test_model_monitor.run(
881-
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
882-
target_dataset=ml_monitoring.spec.MonitoringInput(
883-
vertex_dataset=_TEST_TARGET_RESOURCE
916+
explanation_spec=explanation.ExplanationSpec(
917+
parameters=explanation.ExplanationParameters(top_k=10)
884918
),
919+
sync=sync,
885920
)
921+
922+
if not sync:
923+
test_model_monitoring_job.wait()
924+
886925
create_model_monitoring_job_mock.assert_called_once_with(
887926
request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
888927
parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -891,7 +930,9 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
891930
)
892931

893932
@pytest.mark.usefixtures(
894-
"create_model_monitoring_job_mock", "create_model_monitor_mock"
933+
"create_model_monitoring_job_mock",
934+
"create_model_monitor_mock",
935+
"get_model_monitoring_job_mock",
895936
)
896937
def test_run_model_monitoring_job_with_user_id(
897938
self, create_model_monitoring_job_mock
@@ -908,6 +949,15 @@ def test_run_model_monitoring_job_with_user_id(
908949
model_name=_TEST_MODEL_NAME,
909950
display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
910951
model_version_id=_TEST_MODEL_VERSION_ID,
952+
)
953+
test_model_monitor.run(
954+
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
955+
baseline_dataset=ml_monitoring.spec.MonitoringInput(
956+
vertex_dataset=_TEST_BASELINE_RESOURCE
957+
),
958+
target_dataset=ml_monitoring.spec.MonitoringInput(
959+
vertex_dataset=_TEST_TARGET_RESOURCE
960+
),
911961
tabular_objective_spec=ml_monitoring.spec.TabularObjective(
912962
feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
913963
default_categorical_alert_threshold=0.1,
@@ -918,11 +968,8 @@ def test_run_model_monitoring_job_with_user_id(
918968
notification_spec=ml_monitoring.spec.NotificationSpec(
919969
user_emails=[_TEST_NOTIFICATION_EMAIL]
920970
),
921-
)
922-
test_model_monitor.run(
923-
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
924-
target_dataset=ml_monitoring.spec.MonitoringInput(
925-
vertex_dataset=_TEST_TARGET_RESOURCE
971+
explanation_spec=explanation.ExplanationSpec(
972+
parameters=explanation.ExplanationParameters(top_k=10)
926973
),
927974
model_monitoring_job_id=_TEST_MODEL_MONITORING_JOB_USER_ID,
928975
)
@@ -938,6 +985,7 @@ def test_run_model_monitoring_job_with_user_id(
938985
"create_model_monitoring_job_mock",
939986
"create_model_monitor_mock",
940987
"search_metrics_mock",
988+
"get_model_monitoring_job_mock",
941989
)
942990
def test_search_metrics(self, search_metrics_mock):
943991
aiplatform.init(
@@ -978,6 +1026,7 @@ def test_search_metrics(self, search_metrics_mock):
9781026
"create_model_monitoring_job_mock",
9791027
"create_model_monitor_mock",
9801028
"search_alerts_mock",
1029+
"get_model_monitoring_job_mock",
9811030
)
9821031
def test_search_alerts(self, search_alerts_mock):
9831032
aiplatform.init(
@@ -1047,14 +1096,17 @@ def test_delete_model_monitor(self, delete_model_monitor_mock, force):
10471096
)
10481097
)
10491098

1050-
@pytest.mark.usefixtures("create_model_monitoring_job_mock")
1051-
def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
1099+
@pytest.mark.usefixtures(
1100+
"create_model_monitoring_job_mock", "get_model_monitoring_job_mock"
1101+
)
1102+
@pytest.mark.parametrize("sync", [True, False])
1103+
def test_create_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
10521104
aiplatform.init(
10531105
project=_TEST_PROJECT,
10541106
location=_TEST_LOCATION,
10551107
credentials=_TEST_CREDENTIALS,
10561108
)
1057-
ModelMonitoringJob.create(
1109+
test_model_monitoring_job = ModelMonitoringJob.create(
10581110
display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
10591111
model_monitor_name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
10601112
tabular_objective_spec=ml_monitoring.spec.TabularObjective(
@@ -1073,8 +1125,15 @@ def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
10731125
notification_spec=ml_monitoring.spec.NotificationSpec(
10741126
user_emails=[_TEST_NOTIFICATION_EMAIL]
10751127
),
1076-
explanation_spec=explanation.ExplanationSpec(),
1128+
explanation_spec=explanation.ExplanationSpec(
1129+
parameters=explanation.ExplanationParameters(top_k=10)
1130+
),
1131+
sync=sync,
10771132
)
1133+
1134+
if not sync:
1135+
test_model_monitoring_job.wait()
1136+
10781137
create_model_monitoring_job_mock.assert_called_once_with(
10791138
request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
10801139
parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -1086,6 +1145,7 @@ def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
10861145
"create_model_monitor_mock",
10871146
"create_model_monitoring_job_mock",
10881147
"delete_model_monitoring_job_mock",
1148+
"get_model_monitoring_job_mock",
10891149
)
10901150
def test_delete_model_monitoring_job(self, delete_model_monitoring_job_mock):
10911151
aiplatform.init(

0 commit comments

Comments (0)