16
16
#
17
17
18
18
import pytest
19
+ import copy
19
20
20
21
from unittest import mock
21
22
from importlib import reload
24
25
from google .cloud import storage
25
26
from google .cloud import bigquery
26
27
28
+ from google .api_core import operation
27
29
from google .auth import credentials as auth_credentials
28
30
29
31
from google .cloud import aiplatform
46
48
job_service_client ,
47
49
)
48
50
from google .protobuf import field_mask_pb2 # type: ignore
51
+ from google .protobuf import duration_pb2 # type: ignore
49
52
53
+ import test_endpoints # noqa: F401
50
54
from test_endpoints import get_endpoint_with_models_mock # noqa: F401
51
55
52
56
_TEST_API_CLIENT = job_service_client .JobServiceClient
175
179
_TEST_JOB_RESOURCE_NAME = f"{ _TEST_PARENT } /customJobs/{ _TEST_ID } "
176
180
177
181
_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY" : 0.01 }
182
+ _TEST_MDM_USER_EMAIL = "TEST_EMAIL"
183
+ _TEST_MDM_SAMPLE_RATE = 0.5
184
+ _TEST_MDM_LABEL = {"TEST KEY" : "TEST VAL" }
185
+ _TEST_LOG_TTL_IN_DAYS = 1
186
+ _TEST_MDM_NEW_NAME = "NEW_NAME"
187
+
188
+ _TEST_MDM_OLD_JOB = (
189
+ gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringJob (
190
+ name = _TEST_MDM_JOB_NAME ,
191
+ display_name = _TEST_DISPLAY_NAME ,
192
+ endpoint = _TEST_ENDPOINT ,
193
+ state = _TEST_JOB_STATE_RUNNING ,
194
+ )
195
+ )
196
+
197
+ _TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringJob (
198
+ name = _TEST_MDM_JOB_NAME ,
199
+ display_name = _TEST_MDM_NEW_NAME ,
200
+ endpoint = _TEST_ENDPOINT ,
201
+ state = _TEST_JOB_STATE_RUNNING ,
202
+ model_deployment_monitoring_objective_configs = [
203
+ gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringObjectiveConfig (
204
+ deployed_model_id = model_id ,
205
+ objective_config = gca_model_monitoring_compat .ModelMonitoringObjectiveConfig (
206
+ prediction_drift_detection_config = gca_model_monitoring_compat .ModelMonitoringObjectiveConfig .PredictionDriftDetectionConfig (
207
+ drift_thresholds = {
208
+ "TEST_KEY" : gca_model_monitoring_compat .ThresholdConfig (
209
+ value = 0.01
210
+ )
211
+ }
212
+ )
213
+ ),
214
+ )
215
+ for model_id in [model .id for model in test_endpoints ._TEST_DEPLOYED_MODELS ]
216
+ ],
217
+ logging_sampling_strategy = gca_model_monitoring_compat .SamplingStrategy (
218
+ random_sample_config = gca_model_monitoring_compat .SamplingStrategy .RandomSampleConfig (
219
+ sample_rate = _TEST_MDM_SAMPLE_RATE
220
+ )
221
+ ),
222
+ labels = _TEST_MDM_LABEL ,
223
+ model_monitoring_alert_config = gca_model_monitoring_compat .ModelMonitoringAlertConfig (
224
+ email_alert_config = gca_model_monitoring_compat .ModelMonitoringAlertConfig .EmailAlertConfig (
225
+ user_emails = [_TEST_MDM_USER_EMAIL ]
226
+ )
227
+ ),
228
+ model_deployment_monitoring_schedule_config = gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringScheduleConfig (
229
+ monitor_interval = duration_pb2 .Duration (seconds = 3600 )
230
+ ),
231
+ log_ttl = duration_pb2 .Duration (seconds = _TEST_LOG_TTL_IN_DAYS * 86400 ),
232
+ enable_monitoring_pipeline_logs = True ,
233
+ )
178
234
179
235
# TODO(b/171333554): Move reusable test fixtures to conftest.py file
180
236
@@ -988,48 +1044,23 @@ def get_mdm_job_mock():
988
1044
with mock .patch .object (
989
1045
_TEST_API_CLIENT , "get_model_deployment_monitoring_job"
990
1046
) as get_mdm_job_mock :
991
- get_mdm_job_mock .return_value = (
992
- gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringJob (
993
- name = _TEST_MDM_JOB_NAME ,
994
- display_name = _TEST_DISPLAY_NAME ,
995
- state = _TEST_JOB_STATE_RUNNING ,
996
- endpoint = _TEST_ENDPOINT ,
997
- )
998
- )
1047
+ get_mdm_job_mock .side_effect = [
1048
+ _TEST_MDM_OLD_JOB ,
1049
+ _TEST_MDM_OLD_JOB ,
1050
+ _TEST_MDM_OLD_JOB ,
1051
+ _TEST_MDM_EXPECTED_NEW_JOB ,
1052
+ ]
999
1053
yield get_mdm_job_mock
1000
1054
1001
1055
1002
1056
@pytest .fixture
1003
- @pytest .mark .usefixtures ("get_mdm_job_mock" )
1004
1057
def update_mdm_job_mock (get_endpoint_with_models_mock ): # noqa: F811
1005
1058
with mock .patch .object (
1006
1059
_TEST_API_CLIENT , "update_model_deployment_monitoring_job"
1007
1060
) as update_mdm_job_mock :
1008
- expected_objective_config = gca_model_monitoring_compat .ModelMonitoringObjectiveConfig (
1009
- prediction_drift_detection_config = gca_model_monitoring_compat .ModelMonitoringObjectiveConfig .PredictionDriftDetectionConfig (
1010
- drift_thresholds = {
1011
- "TEST_KEY" : gca_model_monitoring_compat .ThresholdConfig (value = 0.01 )
1012
- }
1013
- )
1014
- )
1015
- all_configs = []
1016
- for model in get_endpoint_with_models_mock .return_value .deployed_models :
1017
- all_configs .append (
1018
- gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringObjectiveConfig (
1019
- deployed_model_id = model .id ,
1020
- objective_config = expected_objective_config ,
1021
- )
1022
- )
1023
-
1024
- update_mdm_job_mock .return_vaue .result_type = (
1025
- gca_model_deployment_monitoring_job_compat .ModelDeploymentMonitoringJob (
1026
- name = _TEST_MDM_JOB_NAME ,
1027
- display_name = _TEST_DISPLAY_NAME ,
1028
- state = _TEST_JOB_STATE_RUNNING ,
1029
- endpoint = _TEST_ENDPOINT ,
1030
- model_deployment_monitoring_objective_configs = all_configs ,
1031
- )
1032
- )
1061
+ update_mdm_job_lro_mock = mock .Mock (operation .Operation )
1062
+ update_mdm_job_lro_mock .result .return_value = _TEST_MDM_EXPECTED_NEW_JOB
1063
+ update_mdm_job_mock .return_value = update_mdm_job_lro_mock
1033
1064
yield update_mdm_job_mock
1034
1065
1035
1066
@@ -1046,25 +1077,66 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
1046
1077
job = jobs .ModelDeploymentMonitoringJob (
1047
1078
model_deployment_monitoring_job_name = _TEST_MDM_JOB_NAME
1048
1079
)
1080
+ old_job = copy .deepcopy (job ._gca_resource )
1049
1081
drift_detection_config = aiplatform .model_monitoring .DriftDetectionConfig (
1050
1082
drift_thresholds = _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
1051
1083
)
1084
+ schedule_config = aiplatform .model_monitoring .ScheduleConfig (monitor_interval = 1 )
1085
+ alert_config = aiplatform .model_monitoring .EmailAlertConfig (
1086
+ user_emails = [_TEST_MDM_USER_EMAIL ]
1087
+ )
1088
+ sampling_strategy = aiplatform .model_monitoring .RandomSampleConfig (
1089
+ sample_rate = _TEST_MDM_SAMPLE_RATE
1090
+ )
1091
+ labels = _TEST_MDM_LABEL
1092
+ log_ttl = _TEST_LOG_TTL_IN_DAYS
1093
+ display_name = _TEST_MDM_NEW_NAME
1052
1094
new_config = aiplatform .model_monitoring .ObjectiveConfig (
1053
1095
drift_detection_config = drift_detection_config
1054
1096
)
1055
- job .update (objective_configs = new_config )
1097
+ job .update (
1098
+ display_name = display_name ,
1099
+ schedule_config = schedule_config ,
1100
+ alert_config = alert_config ,
1101
+ logging_sampling_strategy = sampling_strategy ,
1102
+ labels = labels ,
1103
+ bigquery_tables_log_ttl = log_ttl ,
1104
+ enable_monitoring_pipeline_logs = True ,
1105
+ objective_configs = new_config ,
1106
+ )
1107
+ new_job = job ._gca_resource
1108
+ assert old_job != new_job
1109
+ assert new_job .display_name == display_name
1110
+ assert new_job .logging_sampling_strategy == sampling_strategy .as_proto ()
1111
+ assert (
1112
+ new_job .model_deployment_monitoring_schedule_config
1113
+ == schedule_config .as_proto ()
1114
+ )
1115
+ assert new_job .labels == labels
1116
+ assert new_job .model_monitoring_alert_config == alert_config .as_proto ()
1117
+ assert new_job .log_ttl .days == _TEST_LOG_TTL_IN_DAYS
1118
+ assert new_job .enable_monitoring_pipeline_logs
1056
1119
assert (
1057
- job . _gca_resource .model_deployment_monitoring_objective_configs [
1120
+ new_job .model_deployment_monitoring_objective_configs [
1058
1121
0
1059
1122
].objective_config .prediction_drift_detection_config
1060
1123
== drift_detection_config .as_proto ()
1061
1124
)
1062
1125
get_mdm_job_mock .assert_called_with (
1063
- name = _TEST_MDM_JOB_NAME ,
1126
+ name = _TEST_MDM_JOB_NAME , retry = base . _DEFAULT_RETRY
1064
1127
)
1065
1128
update_mdm_job_mock .assert_called_once_with (
1066
- model_deployment_monitoring_job = get_mdm_job_mock . return_value ,
1129
+ model_deployment_monitoring_job = new_job ,
1067
1130
update_mask = field_mask_pb2 .FieldMask (
1068
- paths = ["model_deployment_monitoring_objective_configs" ]
1131
+ paths = [
1132
+ "display_name" ,
1133
+ "model_deployment_monitoring_schedule_config" ,
1134
+ "model_monitoring_alert_config" ,
1135
+ "logging_sampling_strategy" ,
1136
+ "labels" ,
1137
+ "log_ttl" ,
1138
+ "enable_monitoring_pipeline_logs" ,
1139
+ "model_deployment_monitoring_objective_configs" ,
1140
+ ]
1069
1141
),
1070
1142
)
0 commit comments