105
105
),
106
106
]
107
107
_TEST_TRAFFIC_SPLIT = {_TEST_ID : 0 , _TEST_ID_2 : 100 , _TEST_ID_3 : 0 }
108
- _TEST_PREDICTION = [{"label" : 1.0 }]
108
+ _TEST_DICT_PREDICTION = [{"label" : 1.0 }]
109
+ _TEST_LIST_PREDICTION = [[1.0 ]]
109
110
_TEST_EXPLANATIONS = [gca_prediction_service .explanation .Explanation (attributions = [])]
110
111
_TEST_ATTRIBUTIONS = [
111
112
gca_prediction_service .explanation .Attribution (
@@ -218,26 +219,54 @@ def get_endpoint_with_models_with_explanation_mock():
218
219
219
220
220
221
@pytest .fixture
221
- def predict_client_predict_mock ():
222
+ def predict_client_predict_dict_mock ():
222
223
with mock .patch .object (
223
224
prediction_service_client .PredictionServiceClient , "predict"
224
225
) as predict_mock :
225
226
predict_mock .return_value = gca_prediction_service .PredictResponse (
226
227
deployed_model_id = _TEST_ID
227
228
)
228
- predict_mock .return_value .predictions .extend (_TEST_PREDICTION )
229
+ predict_mock .return_value .predictions .extend (_TEST_DICT_PREDICTION )
229
230
yield predict_mock
230
231
231
232
232
233
@pytest .fixture
233
- def predict_client_explain_mock ():
234
+ def predict_client_explain_dict_mock ():
234
235
with mock .patch .object (
235
236
prediction_service_client .PredictionServiceClient , "explain"
236
237
) as predict_mock :
237
238
predict_mock .return_value = gca_prediction_service .ExplainResponse (
238
239
deployed_model_id = _TEST_ID ,
239
240
)
240
- predict_mock .return_value .predictions .extend (_TEST_PREDICTION )
241
+ predict_mock .return_value .predictions .extend (_TEST_DICT_PREDICTION )
242
+ predict_mock .return_value .explanations .extend (_TEST_EXPLANATIONS )
243
+ predict_mock .return_value .explanations [0 ].attributions .extend (
244
+ _TEST_ATTRIBUTIONS
245
+ )
246
+ yield predict_mock
247
+
248
+
249
+ @pytest .fixture
250
+ def predict_client_predict_list_mock ():
251
+ with mock .patch .object (
252
+ prediction_service_client .PredictionServiceClient , "predict"
253
+ ) as predict_mock :
254
+ predict_mock .return_value = gca_prediction_service .PredictResponse (
255
+ deployed_model_id = _TEST_ID
256
+ )
257
+ predict_mock .return_value .predictions .extend (_TEST_LIST_PREDICTION )
258
+ yield predict_mock
259
+
260
+
261
+ @pytest .fixture
262
+ def predict_client_explain_list_mock ():
263
+ with mock .patch .object (
264
+ prediction_service_client .PredictionServiceClient , "explain"
265
+ ) as predict_mock :
266
+ predict_mock .return_value = gca_prediction_service .ExplainResponse (
267
+ deployed_model_id = _TEST_ID ,
268
+ )
269
+ predict_mock .return_value .predictions .extend (_TEST_LIST_PREDICTION )
241
270
predict_mock .return_value .explanations .extend (_TEST_EXPLANATIONS )
242
271
predict_mock .return_value .explanations [0 ].attributions .extend (
243
272
_TEST_ATTRIBUTIONS
@@ -312,10 +341,112 @@ def test_create_lit_model_from_tensorflow_with_xai_returns_model(
312
341
assert len (item .values ()) == 2
313
342
314
343
@pytest .mark .usefixtures (
315
- "predict_client_predict_mock" , "get_endpoint_with_models_mock"
344
+ "predict_client_predict_dict_mock" , "get_endpoint_with_models_mock"
345
+ )
346
+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
347
+ def test_create_lit_model_from_dict_endpoint_returns_model (
348
+ self , feature_types , label_types , model_id
349
+ ):
350
+ endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
351
+ lit_model = create_lit_model_from_endpoint (
352
+ endpoint , feature_types , label_types , model_id
353
+ )
354
+ test_inputs = [
355
+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
356
+ ]
357
+ outputs = lit_model .predict_minibatch (test_inputs )
358
+
359
+ assert lit_model .input_spec () == dict (feature_types )
360
+ assert lit_model .output_spec () == dict (label_types )
361
+ assert len (outputs ) == 1
362
+ for item in outputs :
363
+ assert item .keys () == {"label" }
364
+ assert len (item .values ()) == 1
365
+
366
+ @pytest .mark .usefixtures (
367
+ "predict_client_explain_dict_mock" ,
368
+ "get_endpoint_with_models_with_explanation_mock" ,
369
+ )
370
+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
371
+ def test_create_lit_model_from_dict_endpoint_with_xai_returns_model (
372
+ self , feature_types , label_types , model_id
373
+ ):
374
+ endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
375
+ lit_model = create_lit_model_from_endpoint (
376
+ endpoint , feature_types , label_types , model_id
377
+ )
378
+ test_inputs = [
379
+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
380
+ ]
381
+ outputs = lit_model .predict_minibatch (test_inputs )
382
+
383
+ assert lit_model .input_spec () == dict (feature_types )
384
+ assert lit_model .output_spec () == dict (
385
+ {
386
+ ** label_types ,
387
+ "feature_attribution" : lit_types .FeatureSalience (signed = True ),
388
+ }
389
+ )
390
+ assert len (outputs ) == 1
391
+ for item in outputs :
392
+ assert item .keys () == {"label" , "feature_attribution" }
393
+ assert len (item .values ()) == 2
394
+
395
+ @pytest .mark .usefixtures (
396
+ "predict_client_predict_dict_mock" , "get_endpoint_with_models_mock"
397
+ )
398
+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
399
+ def test_create_lit_model_from_dict_endpoint_name_returns_model (
400
+ self , feature_types , label_types , model_id
401
+ ):
402
+ lit_model = create_lit_model_from_endpoint (
403
+ _TEST_ENDPOINT_NAME , feature_types , label_types , model_id
404
+ )
405
+ test_inputs = [
406
+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
407
+ ]
408
+ outputs = lit_model .predict_minibatch (test_inputs )
409
+
410
+ assert lit_model .input_spec () == dict (feature_types )
411
+ assert lit_model .output_spec () == dict (label_types )
412
+ assert len (outputs ) == 1
413
+ for item in outputs :
414
+ assert item .keys () == {"label" }
415
+ assert len (item .values ()) == 1
416
+
417
+ @pytest .mark .usefixtures (
418
+ "predict_client_explain_dict_mock" ,
419
+ "get_endpoint_with_models_with_explanation_mock" ,
420
+ )
421
+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
422
+ def test_create_lit_model_from_dict_endpoint_name_with_xai_returns_model (
423
+ self , feature_types , label_types , model_id
424
+ ):
425
+ lit_model = create_lit_model_from_endpoint (
426
+ _TEST_ENDPOINT_NAME , feature_types , label_types , model_id
427
+ )
428
+ test_inputs = [
429
+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
430
+ ]
431
+ outputs = lit_model .predict_minibatch (test_inputs )
432
+
433
+ assert lit_model .input_spec () == dict (feature_types )
434
+ assert lit_model .output_spec () == dict (
435
+ {
436
+ ** label_types ,
437
+ "feature_attribution" : lit_types .FeatureSalience (signed = True ),
438
+ }
439
+ )
440
+ assert len (outputs ) == 1
441
+ for item in outputs :
442
+ assert item .keys () == {"label" , "feature_attribution" }
443
+ assert len (item .values ()) == 2
444
+
445
+ @pytest .mark .usefixtures (
446
+ "predict_client_predict_list_mock" , "get_endpoint_with_models_mock"
316
447
)
317
448
@pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
318
- def test_create_lit_model_from_endpoint_returns_model (
449
+ def test_create_lit_model_from_list_endpoint_returns_model (
319
450
self , feature_types , label_types , model_id
320
451
):
321
452
endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
@@ -335,10 +466,11 @@ def test_create_lit_model_from_endpoint_returns_model(
335
466
assert len (item .values ()) == 1
336
467
337
468
@pytest .mark .usefixtures (
338
- "predict_client_explain_mock" , "get_endpoint_with_models_with_explanation_mock"
469
+ "predict_client_explain_list_mock" ,
470
+ "get_endpoint_with_models_with_explanation_mock" ,
339
471
)
340
472
@pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
341
- def test_create_lit_model_from_endpoint_with_xai_returns_model (
473
+ def test_create_lit_model_from_list_endpoint_with_xai_returns_model (
342
474
self , feature_types , label_types , model_id
343
475
):
344
476
endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
@@ -363,10 +495,10 @@ def test_create_lit_model_from_endpoint_with_xai_returns_model(
363
495
assert len (item .values ()) == 2
364
496
365
497
@pytest .mark .usefixtures (
366
- "predict_client_predict_mock " , "get_endpoint_with_models_mock"
498
+ "predict_client_predict_list_mock " , "get_endpoint_with_models_mock"
367
499
)
368
500
@pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
369
- def test_create_lit_model_from_endpoint_name_returns_model (
501
+ def test_create_lit_model_from_list_endpoint_name_returns_model (
370
502
self , feature_types , label_types , model_id
371
503
):
372
504
lit_model = create_lit_model_from_endpoint (
@@ -385,10 +517,11 @@ def test_create_lit_model_from_endpoint_name_returns_model(
385
517
assert len (item .values ()) == 1
386
518
387
519
@pytest .mark .usefixtures (
388
- "predict_client_explain_mock" , "get_endpoint_with_models_with_explanation_mock"
520
+ "predict_client_explain_list_mock" ,
521
+ "get_endpoint_with_models_with_explanation_mock" ,
389
522
)
390
523
@pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
391
- def test_create_lit_model_from_endpoint_name_with_xai_returns_model (
524
+ def test_create_lit_model_from_list_endpoint_name_with_xai_returns_model (
392
525
self , feature_types , label_types , model_id
393
526
):
394
527
lit_model = create_lit_model_from_endpoint (
0 commit comments