@@ -36,6 +36,7 @@
     model_monitoring_stats_v1beta1 as gca_model_monitoring_stats,
     schedule_service_v1beta1 as gca_schedule_service,
     schedule_v1beta1 as gca_schedule,
+    job_state_v1beta1 as gca_job_state,
     explanation_v1beta1 as explanation,
 )
 from vertexai.resources.preview import (
@@ -51,7 +52,10 @@
 
 # -*- coding: utf-8 -*-
 
-_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
+_TEST_CREDENTIALS = mock.Mock(
+    spec=auth_credentials.AnonymousCredentials(),
+    universe_domain="googleapis.com",
+)
 _TEST_DESCRIPTION = "test description"
 _TEST_JSON_CONTENT_TYPE = "application/json"
 _TEST_LOCATION = "us-central1"
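A quick aside on the credentials change: below is a hedged sketch (not the repo's code) of why the mock now pins `universe_domain`. The assumption is that newer generated clients read `credentials.universe_domain` as a string during initialization, and a bare spec'd Mock would hand back a child Mock instead.

```python
# Hedged sketch: compare a bare credentials mock with one that pins
# universe_domain, as the hunk above now does.
from unittest import mock

from google.auth import credentials as auth_credentials

bare = mock.Mock(spec=auth_credentials.AnonymousCredentials())
pinned = mock.Mock(
    spec=auth_credentials.AnonymousCredentials(),
    universe_domain="googleapis.com",  # extra kwargs become plain attributes
)

# bare.universe_domain is itself a Mock, so any string comparison a client
# performs against it would misbehave; the pinned value is a real str.
assert isinstance(bare.universe_domain, mock.Mock)
assert pinned.universe_domain == "googleapis.com"
```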
@@ -178,6 +182,9 @@
             user_emails=[_TEST_NOTIFICATION_EMAIL]
         ),
     ),
+    explanation_spec=explanation.ExplanationSpec(
+        parameters=explanation.ExplanationParameters(top_k=10)
+    ),
 )
 _TEST_UPDATED_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor(
     name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
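The `explanation_spec` edit above repeats throughout this diff: every empty `ExplanationSpec()` gains concrete `ExplanationParameters(top_k=10)`. Here is a small sketch of the message being built; the import path is an assumption (the test file itself aliases `explanation_v1beta1 as explanation`).

```python
# Hedged sketch of the spec the updated tests expect; top_k limits
# attributions to the top K output indices.
from google.cloud.aiplatform_v1beta1.types import explanation

spec = explanation.ExplanationSpec(
    parameters=explanation.ExplanationParameters(top_k=10)
)

# Proto-plus messages compare field-by-field, which is what lets the tests
# assert a freshly built expected object against the captured request.
assert spec.parameters.top_k == 10
```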
@@ -222,6 +229,9 @@
             user_emails=[_TEST_NOTIFICATION_EMAIL, "[email protected]"]
         ),
     ),
+    explanation_spec=explanation.ExplanationSpec(
+        parameters=explanation.ExplanationParameters(top_k=10)
+    ),
 )
 _TEST_CREATE_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob(
     display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
@@ -249,7 +259,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -288,7 +300,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -299,6 +313,7 @@
             )
         ),
     ),
+    state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
 )
 _TEST_CRON = r"America/New_York 1 \* \* \* \*"
 _TEST_SCHEDULE_OBJ = gca_schedule.Schedule(
@@ -336,7 +351,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -564,7 +581,12 @@ def get_model_monitoring_job_mock():
         model_monitoring_service_client.ModelMonitoringServiceClient,
         "get_model_monitoring_job",
     ) as get_model_monitoring_job_mock:
-        get_model_monitor_mock.return_value = _TEST_MODEL_MONITORING_JOB_OBJ
+        model_monitoring_job_mock = mock.Mock(
+            spec=gca_model_monitoring_job.ModelMonitoringJob
+        )
+        model_monitoring_job_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
+        model_monitoring_job_mock.name = _TEST_MODEL_MONITORING_JOB_RESOURCE_NAME
+        get_model_monitoring_job_mock.return_value = model_monitoring_job_mock
         yield get_model_monitoring_job_mock
 
 
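The fixture above swaps the patched method's return value for a hand-built mock that exposes only what polling code reads: `state` and `name`. Below is a self-contained sketch of that pattern; `FakeClient` and the resource name are hypothetical stand-ins, not the repo's classes.

```python
# Illustrative fixture: patch a client method and return a mock whose
# terminal state lets any wait()/polling loop exit immediately.
from unittest import mock

import pytest


class FakeClient:
    def get_model_monitoring_job(self, name):
        raise NotImplementedError  # would be a network call in production


@pytest.fixture
def get_job_mock():
    with mock.patch.object(FakeClient, "get_model_monitoring_job") as patched:
        job = mock.Mock()
        job.state = "JOB_STATE_SUCCEEDED"  # terminal state ends polling
        job.name = "projects/p/locations/l/modelMonitoringJobs/123"
        patched.return_value = job
        yield patched
```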
@@ -762,6 +784,9 @@ def test_create_schedule(self, create_schedule_mock):
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
+            ),
         )
         test_model_monitor.create_schedule(
             display_name=_TEST_SCHEDULE_NAME,
@@ -851,9 +876,12 @@ def test_update_schedule(self, update_schedule_mock, get_schedule_mock):
         assert get_schedule_mock.call_count == 1
 
     @pytest.mark.usefixtures(
-        "create_model_monitoring_job_mock", "create_model_monitor_mock"
+        "create_model_monitoring_job_mock",
+        "create_model_monitor_mock",
+        "get_model_monitoring_job_mock",
     )
-    def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
@@ -866,6 +894,15 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
             model_name=_TEST_MODEL_NAME,
             display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
             model_version_id=_TEST_MODEL_VERSION_ID,
+        )
+        test_model_monitoring_job = test_model_monitor.run(
+            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
+            baseline_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_BASELINE_RESOURCE
+            ),
+            target_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_TARGET_RESOURCE
+            ),
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
                 feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
                     default_categorical_alert_threshold=0.1,
@@ -876,13 +913,15 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-        )
-        test_model_monitor.run(
-            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
-            target_dataset=ml_monitoring.spec.MonitoringInput(
-                vertex_dataset=_TEST_TARGET_RESOURCE
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
             ),
+            sync=sync,
         )
+
+        if not sync:
+            test_model_monitoring_job.wait()
+
         create_model_monitoring_job_mock.assert_called_once_with(
             request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
                 parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
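The shape of the parametrized test above is the standard sync/async pattern: with `sync=False` the call returns before the work is finished, so the test must `wait()` before asserting on the mock. Here is a stripped-down sketch with hypothetical stand-ins for the SDK's job object and `run()` method.

```python
# Minimal sketch of the sync-parametrized test pattern; Job and run_job are
# hypothetical, not the SDK's API.
import pytest


class Job:
    def __init__(self, sync):
        self._complete = sync  # the sync path finishes before returning

    def wait(self):
        self._complete = True  # stand-in for blocking on the background LRO


def run_job(sync=True):
    return Job(sync=sync)


@pytest.mark.parametrize("sync", [True, False])
def test_run_job(sync):
    job = run_job(sync=sync)
    if not sync:
        job.wait()  # mirrors the diff: only the async branch must wait
    assert job._complete
```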
@@ -891,7 +930,9 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
         )
 
     @pytest.mark.usefixtures(
-        "create_model_monitoring_job_mock", "create_model_monitor_mock"
+        "create_model_monitoring_job_mock",
+        "create_model_monitor_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_run_model_monitoring_job_with_user_id(
         self, create_model_monitoring_job_mock
@@ -908,6 +949,15 @@ def test_run_model_monitoring_job_with_user_id(
             model_name=_TEST_MODEL_NAME,
             display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
             model_version_id=_TEST_MODEL_VERSION_ID,
+        )
+        test_model_monitor.run(
+            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
+            baseline_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_BASELINE_RESOURCE
+            ),
+            target_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_TARGET_RESOURCE
+            ),
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
                 feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
                     default_categorical_alert_threshold=0.1,
@@ -918,11 +968,8 @@ def test_run_model_monitoring_job_with_user_id(
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-        )
-        test_model_monitor.run(
-            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
-            target_dataset=ml_monitoring.spec.MonitoringInput(
-                vertex_dataset=_TEST_TARGET_RESOURCE
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
             ),
             model_monitoring_job_id=_TEST_MODEL_MONITORING_JOB_USER_ID,
         )
@@ -938,6 +985,7 @@ def test_run_model_monitoring_job_with_user_id(
         "create_model_monitoring_job_mock",
         "create_model_monitor_mock",
         "search_metrics_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_search_metrics(self, search_metrics_mock):
         aiplatform.init(
@@ -978,6 +1026,7 @@ def test_search_metrics(self, search_metrics_mock):
         "create_model_monitoring_job_mock",
         "create_model_monitor_mock",
         "search_alerts_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_search_alerts(self, search_alerts_mock):
         aiplatform.init(
@@ -1047,14 +1096,17 @@ def test_delete_model_monitor(self, delete_model_monitor_mock, force):
             )
         )
 
-    @pytest.mark.usefixtures("create_model_monitoring_job_mock")
-    def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
+    @pytest.mark.usefixtures(
+        "create_model_monitoring_job_mock", "get_model_monitoring_job_mock"
+    )
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_create_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
             credentials=_TEST_CREDENTIALS,
         )
-        ModelMonitoringJob.create(
+        test_model_monitoring_job = ModelMonitoringJob.create(
             display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
             model_monitor_name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
@@ -1073,8 +1125,15 @@ def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-            explanation_spec=explanation.ExplanationSpec(),
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
+            ),
+            sync=sync,
         )
+
+        if not sync:
+            test_model_monitoring_job.wait()
+
         create_model_monitoring_job_mock.assert_called_once_with(
             request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
                 parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -1086,6 +1145,7 @@ def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
         "create_model_monitor_mock",
         "create_model_monitoring_job_mock",
         "delete_model_monitoring_job_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_delete_model_monitoring_job(self, delete_model_monitoring_job_mock):
         aiplatform.init(