
Commit b9a057d

fix: Fix create_lit_model_from_endpoint not accepting models that don't return a dictionary. (#1020)
Some models, like Keras Sequential models, don't return a dictionary for their predictions. We need to support these models as they are commonly used. Fixes b/220167889
1 parent e7d2719 commit b9a057d
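
The gist of the change: a prediction that comes back as a plain list (as a Keras Sequential model's serving output does) is now zipped against the output spec keys by position, while dict predictions keep the existing key lookup. A minimal standalone sketch of that mapping, outside the SDK; the helper name and the single-key output spec are illustrative only:

    from collections.abc import Mapping  # the SDK code imports Mapping from typing

    # Illustrative output spec with one "label" output; in the SDK the values
    # are LIT type objects, not None.
    output_types = {"label": None}

    def to_output_row(prediction, output_types):
        """Map one endpoint prediction onto the output spec keys."""
        if isinstance(prediction, Mapping):
            # Dict-style prediction, e.g. {"label": 1.0}: look up by key.
            return {key: prediction[key] for key in output_types}
        # List-style prediction, e.g. [1.0]: pair values with keys by position.
        return {key: prediction[i] for i, key in enumerate(output_types)}

    assert to_output_row({"label": 1.0}, output_types) == {"label": 1.0}
    assert to_output_row([1.0], output_types) == {"label": 1.0}

This relies on the iteration order of the output spec lining up with the order of values in a list prediction, which is the same assumption the fix makes.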

File tree

2 files changed: +153 −15 lines

google/cloud/aiplatform/explain/lit.py (+7, −2)

@@ -18,7 +18,7 @@
 import os
 
 from google.cloud import aiplatform
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Mapping, Optional, Tuple, Union
 
 try:
     from lit_nlp.api import dataset as lit_dataset
@@ -154,7 +154,12 @@ def predict_minibatch(
         prediction_object = self._endpoint.predict(instances)
         outputs = []
         for prediction in prediction_object.predictions:
-            outputs.append({key: prediction[key] for key in self._output_types})
+            if isinstance(prediction, Mapping):
+                outputs.append({key: prediction[key] for key in self._output_types})
+            else:
+                outputs.append(
+                    {key: prediction[i] for i, key in enumerate(self._output_types)}
+                )
         if self._explanation_enabled:
             for i, explanation in enumerate(prediction_object.explanations):
                 attributions = explanation.attributions
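
With the change above, an endpoint whose deployed model returns bare prediction lists can be wrapped for LIT the same way as one that returns dictionaries. A rough usage sketch, assuming a deployed endpoint reachable from the client and LIT Scalar/RegressionScore types; the endpoint resource name and feature names are placeholders, and model_id is passed as None as in the unit tests below:

    from google.cloud import aiplatform
    from google.cloud.aiplatform.explain.lit import create_lit_model_from_endpoint
    from lit_nlp.api import types as lit_types

    # Placeholder endpoint; its deployed model is assumed to return a plain
    # list such as [1.0] rather than {"label": 1.0}.
    endpoint = aiplatform.Endpoint("projects/123/locations/us-central1/endpoints/456")

    feature_types = {"feature_1": lit_types.Scalar(), "feature_2": lit_types.Scalar()}
    label_types = {"label": lit_types.RegressionScore()}

    lit_model = create_lit_model_from_endpoint(endpoint, feature_types, label_types, None)

    # Both prediction shapes now come back keyed by the output spec,
    # e.g. [{"label": 1.0}].
    outputs = lit_model.predict_minibatch([{"feature_1": 1.0, "feature_2": 2.0}])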

tests/unit/aiplatform/test_explain_lit.py (+146, −13)

@@ -105,7 +105,8 @@
     ),
 ]
 _TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}
-_TEST_PREDICTION = [{"label": 1.0}]
+_TEST_DICT_PREDICTION = [{"label": 1.0}]
+_TEST_LIST_PREDICTION = [[1.0]]
 _TEST_EXPLANATIONS = [gca_prediction_service.explanation.Explanation(attributions=[])]
 _TEST_ATTRIBUTIONS = [
     gca_prediction_service.explanation.Attribution(
@@ -218,26 +219,54 @@ def get_endpoint_with_models_with_explanation_mock():
 
 
 @pytest.fixture
-def predict_client_predict_mock():
+def predict_client_predict_dict_mock():
     with mock.patch.object(
         prediction_service_client.PredictionServiceClient, "predict"
     ) as predict_mock:
         predict_mock.return_value = gca_prediction_service.PredictResponse(
             deployed_model_id=_TEST_ID
         )
-        predict_mock.return_value.predictions.extend(_TEST_PREDICTION)
+        predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION)
         yield predict_mock
 
 
 @pytest.fixture
-def predict_client_explain_mock():
+def predict_client_explain_dict_mock():
     with mock.patch.object(
         prediction_service_client.PredictionServiceClient, "explain"
     ) as predict_mock:
         predict_mock.return_value = gca_prediction_service.ExplainResponse(
             deployed_model_id=_TEST_ID,
         )
-        predict_mock.return_value.predictions.extend(_TEST_PREDICTION)
+        predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION)
+        predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS)
+        predict_mock.return_value.explanations[0].attributions.extend(
+            _TEST_ATTRIBUTIONS
+        )
+        yield predict_mock
+
+
+@pytest.fixture
+def predict_client_predict_list_mock():
+    with mock.patch.object(
+        prediction_service_client.PredictionServiceClient, "predict"
+    ) as predict_mock:
+        predict_mock.return_value = gca_prediction_service.PredictResponse(
+            deployed_model_id=_TEST_ID
+        )
+        predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION)
+        yield predict_mock
+
+
+@pytest.fixture
+def predict_client_explain_list_mock():
+    with mock.patch.object(
+        prediction_service_client.PredictionServiceClient, "explain"
+    ) as predict_mock:
+        predict_mock.return_value = gca_prediction_service.ExplainResponse(
+            deployed_model_id=_TEST_ID,
+        )
+        predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION)
         predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS)
         predict_mock.return_value.explanations[0].attributions.extend(
             _TEST_ATTRIBUTIONS
@@ -312,10 +341,112 @@ def test_create_lit_model_from_tensorflow_with_xai_returns_model(
         assert len(item.values()) == 2
 
     @pytest.mark.usefixtures(
-        "predict_client_predict_mock", "get_endpoint_with_models_mock"
+        "predict_client_predict_dict_mock", "get_endpoint_with_models_mock"
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
+        lit_model = create_lit_model_from_endpoint(
+            endpoint, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(label_types)
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label"}
+            assert len(item.values()) == 1
+
+    @pytest.mark.usefixtures(
+        "predict_client_explain_dict_mock",
+        "get_endpoint_with_models_with_explanation_mock",
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_with_xai_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
+        lit_model = create_lit_model_from_endpoint(
+            endpoint, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(
+            {
+                **label_types,
+                "feature_attribution": lit_types.FeatureSalience(signed=True),
+            }
+        )
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label", "feature_attribution"}
+            assert len(item.values()) == 2
+
+    @pytest.mark.usefixtures(
+        "predict_client_predict_dict_mock", "get_endpoint_with_models_mock"
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_name_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        lit_model = create_lit_model_from_endpoint(
+            _TEST_ENDPOINT_NAME, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(label_types)
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label"}
+            assert len(item.values()) == 1
+
+    @pytest.mark.usefixtures(
+        "predict_client_explain_dict_mock",
+        "get_endpoint_with_models_with_explanation_mock",
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_name_with_xai_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        lit_model = create_lit_model_from_endpoint(
+            _TEST_ENDPOINT_NAME, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(
+            {
+                **label_types,
+                "feature_attribution": lit_types.FeatureSalience(signed=True),
+            }
+        )
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label", "feature_attribution"}
+            assert len(item.values()) == 2
+
+    @pytest.mark.usefixtures(
+        "predict_client_predict_list_mock", "get_endpoint_with_models_mock"
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_returns_model(
+    def test_create_lit_model_from_list_endpoint_returns_model(
         self, feature_types, label_types, model_id
     ):
         endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
@@ -335,10 +466,11 @@ def test_create_lit_model_from_endpoint_returns_model(
         assert len(item.values()) == 1
 
     @pytest.mark.usefixtures(
-        "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock"
+        "predict_client_explain_list_mock",
+        "get_endpoint_with_models_with_explanation_mock",
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_with_xai_returns_model(
+    def test_create_lit_model_from_list_endpoint_with_xai_returns_model(
         self, feature_types, label_types, model_id
     ):
         endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
@@ -363,10 +495,10 @@ def test_create_lit_model_from_endpoint_with_xai_returns_model(
         assert len(item.values()) == 2
 
     @pytest.mark.usefixtures(
-        "predict_client_predict_mock", "get_endpoint_with_models_mock"
+        "predict_client_predict_list_mock", "get_endpoint_with_models_mock"
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_name_returns_model(
+    def test_create_lit_model_from_list_endpoint_name_returns_model(
         self, feature_types, label_types, model_id
     ):
         lit_model = create_lit_model_from_endpoint(
@@ -385,10 +517,11 @@ def test_create_lit_model_from_endpoint_name_returns_model(
         assert len(item.values()) == 1
 
     @pytest.mark.usefixtures(
-        "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock"
+        "predict_client_explain_list_mock",
+        "get_endpoint_with_models_with_explanation_mock",
    )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_name_with_xai_returns_model(
+    def test_create_lit_model_from_list_endpoint_name_with_xai_returns_model(
         self, feature_types, label_types, model_id
     ):
         lit_model = create_lit_model_from_endpoint(
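
The new tests exercise both prediction shapes against mocked predict/explain responses. A plausible local invocation from a development checkout with the test dependencies installed is to select them by name with pytest's -k filter:

    pytest tests/unit/aiplatform/test_explain_lit.py -k "dict_endpoint or list_endpoint"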
