31
31
32
32
# constants used for testing
33
33
USER_EMAIL = ""
34
- PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320 "
34
+ PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280 "
35
35
CHURN_MODEL_PATH = "gs://mco-mm/churn"
36
+ DEFAULT_INPUT = {
37
+ "cnt_ad_reward" : 0 ,
38
+ "cnt_challenge_a_friend" : 0 ,
39
+ "cnt_completed_5_levels" : 1 ,
40
+ "cnt_level_complete_quickplay" : 3 ,
41
+ "cnt_level_end_quickplay" : 5 ,
42
+ "cnt_level_reset_quickplay" : 2 ,
43
+ "cnt_level_start_quickplay" : 6 ,
44
+ "cnt_post_score" : 34 ,
45
+ "cnt_spend_virtual_currency" : 0 ,
46
+ "cnt_use_extra_steps" : 0 ,
47
+ "cnt_user_engagement" : 120 ,
48
+ "country" : "Denmark" ,
49
+ "dayofweek" : 3 ,
50
+ "julianday" : 254 ,
51
+ "language" : "da-dk" ,
52
+ "month" : 9 ,
53
+ "operating_system" : "IOS" ,
54
+ "user_pseudo_id" : "104B0770BAE16E8B53DF330C95881893" ,
55
+ }
36
56
37
57
JOB_NAME = "churn"
38
58
@@ -117,10 +137,7 @@ def test_mdm_two_models_one_valid_config(self):
117
137
project = e2e_base ._PROJECT ,
118
138
location = e2e_base ._LOCATION ,
119
139
endpoint = self .endpoint ,
120
- predict_instance_schema_uri = "" ,
121
- analysis_instance_schema_uri = "" ,
122
140
)
123
- assert job is not None
124
141
125
142
gapic_job = job ._gca_resource
126
143
assert (
@@ -156,22 +173,77 @@ def test_mdm_two_models_one_valid_config(self):
156
173
gca_obj_config .prediction_drift_detection_config == drift_config .as_proto ()
157
174
)
158
175
176
+ # delete this job and re-configure it to only enable drift detection for faster testing
177
+ job .delete ()
159
178
job_resource = job ._gca_resource .name
160
179
161
- # test job update and delete()
162
- timeout = time .time () + 3600
163
- new_obj_config = model_monitoring .ObjectiveConfig (skew_config )
180
+ # test job delete
181
+ with pytest .raises (core_exceptions .NotFound ):
182
+ job .api_client .get_model_deployment_monitoring_job (name = job_resource )
183
+
184
+ def test_mdm_pause_and_update_config (self ):
185
+ """Test objective config updates for existing MDM job"""
186
+ job = aiplatform .ModelDeploymentMonitoringJob .create (
187
+ display_name = self ._make_display_name (key = JOB_NAME ),
188
+ logging_sampling_strategy = sampling_strategy ,
189
+ schedule_config = schedule_config ,
190
+ alert_config = alert_config ,
191
+ objective_configs = model_monitoring .ObjectiveConfig (
192
+ drift_detection_config = drift_config
193
+ ),
194
+ create_request_timeout = 3600 ,
195
+ project = e2e_base ._PROJECT ,
196
+ location = e2e_base ._LOCATION ,
197
+ endpoint = self .endpoint ,
198
+ )
199
+ # test unsuccessful job update when it's pending
200
+ DRIFT_THRESHOLDS ["cnt_user_engagement" ] += 0.01
201
+ new_obj_config = model_monitoring .ObjectiveConfig (
202
+ drift_detection_config = model_monitoring .DriftDetectionConfig (
203
+ drift_thresholds = DRIFT_THRESHOLDS ,
204
+ attribute_drift_thresholds = ATTRIB_DRIFT_THRESHOLDS ,
205
+ )
206
+ )
207
+ if job .state == gca_job_state .JobState .JOB_STATE_PENDING :
208
+ with pytest .raises (core_exceptions .FailedPrecondition ):
209
+ job .update (objective_configs = new_obj_config )
210
+
211
+ # generate traffic to force MDM job to come online
212
+ for i in range (2000 ):
213
+ DEFAULT_INPUT ["cnt_user_engagement" ] += i
214
+ self .endpoint .predict ([DEFAULT_INPUT ], use_raw_predict = True )
164
215
165
- while time .time () < timeout :
216
+ # test job update
217
+ while True :
218
+ time .sleep (1 )
166
219
if job .state == gca_job_state .JobState .JOB_STATE_RUNNING :
167
220
job .update (objective_configs = new_obj_config )
168
- assert str (job ._gca_resource .prediction_drift_detection_config ) == ""
169
221
break
170
- time .sleep (5 )
171
222
223
+ # verify job update
224
+ while True :
225
+ time .sleep (1 )
226
+ if job .state == gca_job_state .JobState .JOB_STATE_RUNNING :
227
+ gca_obj_config = (
228
+ job ._gca_resource .model_deployment_monitoring_objective_configs [
229
+ 0
230
+ ].objective_config
231
+ )
232
+ assert (
233
+ gca_obj_config .prediction_drift_detection_config
234
+ == new_obj_config .drift_detection_config .as_proto ()
235
+ )
236
+ break
237
+
238
+ # test pause
239
+ job .pause ()
240
+ while job .state != gca_job_state .JobState .JOB_STATE_PAUSED :
241
+ time .sleep (1 )
172
242
job .delete ()
243
+
244
+ # confirm deletion
173
245
with pytest .raises (core_exceptions .NotFound ):
174
- job .api_client . get_model_deployment_monitoring_job ( name = job_resource )
246
+ job .state
175
247
176
248
def test_mdm_two_models_two_valid_configs (self ):
177
249
[deployed_model1 , deployed_model2 ] = list (
@@ -181,7 +253,6 @@ def test_mdm_two_models_two_valid_configs(self):
181
253
deployed_model1 : objective_config ,
182
254
deployed_model2 : objective_config2 ,
183
255
}
184
- job = None
185
256
job = aiplatform .ModelDeploymentMonitoringJob .create (
186
257
display_name = self ._make_display_name (key = JOB_NAME ),
187
258
logging_sampling_strategy = sampling_strategy ,
@@ -192,10 +263,7 @@ def test_mdm_two_models_two_valid_configs(self):
192
263
project = e2e_base ._PROJECT ,
193
264
location = e2e_base ._LOCATION ,
194
265
endpoint = self .endpoint ,
195
- predict_instance_schema_uri = "" ,
196
- analysis_instance_schema_uri = "" ,
197
266
)
198
- assert job is not None
199
267
200
268
gapic_job = job ._gca_resource
201
269
assert (
@@ -246,8 +314,6 @@ def test_mdm_invalid_config_incorrect_model_id(self):
246
314
project = e2e_base ._PROJECT ,
247
315
location = e2e_base ._LOCATION ,
248
316
endpoint = self .endpoint ,
249
- predict_instance_schema_uri = "" ,
250
- analysis_instance_schema_uri = "" ,
251
317
deployed_model_ids = ["" ],
252
318
)
253
319
assert "Invalid model ID" in str (e .value )
@@ -265,8 +331,6 @@ def test_mdm_invalid_config_xai(self):
265
331
project = e2e_base ._PROJECT ,
266
332
location = e2e_base ._LOCATION ,
267
333
endpoint = self .endpoint ,
268
- predict_instance_schema_uri = "" ,
269
- analysis_instance_schema_uri = "" ,
270
334
)
271
335
assert (
272
336
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
@@ -294,8 +358,6 @@ def test_mdm_two_models_invalid_configs_xai(self):
294
358
project = e2e_base ._PROJECT ,
295
359
location = e2e_base ._LOCATION ,
296
360
endpoint = self .endpoint ,
297
- predict_instance_schema_uri = "" ,
298
- analysis_instance_schema_uri = "" ,
299
361
)
300
362
assert (
301
363
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
0 commit comments