Skip to content

Commit ea16849

Browse files
authored
feat: Open LIT with a deployed model (#963)
1 parent 7a7f0d4 commit ea16849

File tree

2 files changed

+589
-189
lines changed

2 files changed

+589
-189
lines changed

google/cloud/aiplatform/explain/lit.py

+205-70
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import os
1919

20+
from google.cloud import aiplatform
2021
from typing import Dict, List, Optional, Tuple, Union
2122

2223
try:
@@ -61,11 +62,11 @@ def __init__(
6162
):
6263
"""Construct a VertexLitDataset.
6364
Args:
64-
dataset:
65-
Required. A Pandas DataFrame that includes feature column names and data.
66-
column_types:
67-
Required. An OrderedDict of string names matching the columns of the dataset
68-
as the key, and the associated LitType of the column.
65+
dataset:
66+
Required. A Pandas DataFrame that includes feature column names and data.
67+
column_types:
68+
Required. An OrderedDict of string names matching the columns of the dataset
69+
as the key, and the associated LitType of the column.
6970
"""
7071
self._examples = dataset.to_dict(orient="records")
7172
self._column_types = column_types
@@ -75,8 +76,109 @@ def spec(self):
7576
return dict(self._column_types)
7677

7778

78-
class _VertexLitModel(lit_model.Model):
79-
"""LIT model class for the Vertex LIT integration.
79+
class _EndpointLitModel(lit_model.Model):
    """LIT model class for the Vertex LIT integration with a model deployed to an endpoint.

    This is used in the create_lit_model_from_endpoint function.
    """

    def __init__(
        self,
        endpoint: Union[str, aiplatform.Endpoint],
        input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
        output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
        model_id: Optional[str] = None,
    ):
        """Construct an _EndpointLitModel.

        Args:
            endpoint:
                Required. The name of the Endpoint resource or an Endpoint
                instance. Endpoint name format:
                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
            input_types:
                Required. An OrderedDict of string names matching the features of the model
                as the key, and the associated LitType of the feature.
            output_types:
                Required. An OrderedDict of string names matching the labels of the model
                as the key, and the associated LitType of the label.
            model_id:
                Optional. A string of the specific model in the endpoint to create the
                LIT model from. If this is not set, any usable model in the endpoint is
                used to create the LIT model.
        Raises:
            ValueError if the model_id was not found in the endpoint.
        """
        if isinstance(endpoint, str):
            self._endpoint = aiplatform.Endpoint(endpoint)
        else:
            self._endpoint = endpoint
        self._model_id = model_id
        self._input_types = input_types
        self._output_types = output_types
        if model_id:
            # Explanations are only available if the specific deployed model
            # was configured with an explanation_spec.
            deployed_model = next(
                filter(
                    lambda model: model.id == model_id, self._endpoint.list_models()
                ),
                None,
            )
            if not deployed_model:
                raise ValueError(
                    "A model with id {model_id} was not found in the endpoint {endpoint}.".format(
                        model_id=model_id, endpoint=endpoint
                    )
                )
            self._explanation_enabled = bool(deployed_model.explanation_spec)
        else:
            # No specific model requested: explanations are only usable when
            # every deployed model in the endpoint supports them.
            self._explanation_enabled = all(
                model.explanation_spec for model in self._endpoint.list_models()
            )

    def predict_minibatch(
        self, inputs: List[lit_types.JsonDict]
    ) -> List[lit_types.JsonDict]:
        """Return predictions based on a batch of inputs.

        Args:
            inputs: Required. A list of instances to predict on based on the input spec.
        Returns:
            A list of predictions based on the output spec.
        """
        # Convert each LIT JsonDict into a positional instance, ordered by
        # the keys of the input spec.
        instances = []
        for example in inputs:
            instance = [example[feature] for feature in self._input_types]
            instances.append(instance)
        if self._explanation_enabled:
            prediction_object = self._endpoint.explain(instances)
        else:
            prediction_object = self._endpoint.predict(instances)
        outputs = []
        for prediction in prediction_object.predictions:
            outputs.append({key: prediction[key] for key in self._output_types})
        if self._explanation_enabled:
            # Attach per-example feature attributions for LIT's salience view.
            for i, explanation in enumerate(prediction_object.explanations):
                attributions = explanation.attributions
                outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
                    attributions
                )
        return outputs

    def input_spec(self) -> lit_types.Spec:
        """Return a spec describing model inputs."""
        return dict(self._input_types)

    def output_spec(self) -> lit_types.Spec:
        """Return a spec describing model outputs."""
        output_spec_dict = dict(self._output_types)
        if self._explanation_enabled:
            # Expose attributions in the output spec so LIT renders salience.
            output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
                signed=True
            )
        return output_spec_dict
178+
179+
180+
class _TensorFlowLitModel(lit_model.Model):
181+
"""LIT model class for the Vertex LIT integration with a TensorFlow saved model.
80182
81183
This is used in the create_lit_model function.
82184
"""
@@ -90,19 +192,19 @@ def __init__(
90192
):
91193
"""Construct a VertexLitModel.
92194
Args:
93-
model:
94-
Required. A string reference to a local TensorFlow saved model directory.
95-
The model must have at most one input and one output tensor.
96-
input_types:
97-
Required. An OrderedDict of string names matching the features of the model
98-
as the key, and the associated LitType of the feature.
99-
output_types:
100-
Required. An OrderedDict of string names matching the labels of the model
101-
as the key, and the associated LitType of the label.
102-
attribution_method:
103-
Optional. A string to choose what attribution configuration to
104-
set up the explainer with. Valid options are 'sampled_shapley'
105-
or 'integrated_gradients'.
195+
model:
196+
Required. A string reference to a local TensorFlow saved model directory.
197+
The model must have at most one input and one output tensor.
198+
input_types:
199+
Required. An OrderedDict of string names matching the features of the model
200+
as the key, and the associated LitType of the feature.
201+
output_types:
202+
Required. An OrderedDict of string names matching the labels of the model
203+
as the key, and the associated LitType of the label.
204+
attribution_method:
205+
Optional. A string to choose what attribution configuration to
206+
set up the explainer with. Valid options are 'sampled_shapley'
207+
or 'integrated_gradients'.
106208
"""
107209
self._load_model(model)
108210
self._input_types = input_types
@@ -120,6 +222,12 @@ def attribution_explainer(self,) -> Optional["AttributionExplainer"]: # noqa: F
120222
def predict_minibatch(
121223
self, inputs: List[lit_types.JsonDict]
122224
) -> List[lit_types.JsonDict]:
225+
"""Return predictions based on a batch of inputs.
226+
Args:
227+
inputs: Required. A list of instances to predict on based on the input spec.
228+
Returns:
229+
A list of predictions based on the output spec.
230+
"""
123231
instances = []
124232
for input in inputs:
125233
instance = [input[feature] for feature in self._input_types]
@@ -166,7 +274,7 @@ def output_spec(self) -> lit_types.Spec:
166274
def _load_model(self, model: str):
167275
"""Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
168276
Args:
169-
model: Required. A string reference to a TensorFlow saved model directory.
277+
model: Required. A string reference to a TensorFlow saved model directory.
170278
Raises:
171279
ValueError if the model has more than one input tensor or more than one output tensor.
172280
"""
@@ -188,11 +296,11 @@ def _set_up_attribution_explainer(
188296
):
189297
"""Populates the attribution explainer attribute of the class.
190298
Args:
191-
model: Required. A string reference to a TensorFlow saved model directory.
299+
model: Required. A string reference to a TensorFlow saved model directory.
192300
attribution_method:
193-
Optional. A string to choose what attribution configuration to
194-
set up the explainer with. Valid options are 'sampled_shapley'
195-
or 'integrated_gradients'.
301+
Optional. A string to choose what attribution configuration to
302+
set up the explainer with. Valid options are 'sampled_shapley'
303+
or 'integrated_gradients'.
196304
"""
197305
try:
198306
import explainable_ai_sdk
@@ -228,17 +336,44 @@ def create_lit_dataset(
228336
) -> lit_dataset.Dataset:
229337
"""Creates a LIT Dataset object.
230338
Args:
231-
dataset:
232-
Required. A Pandas DataFrame that includes feature column names and data.
233-
column_types:
234-
Required. An OrderedDict of string names matching the columns of the dataset
235-
as the key, and the associated LitType of the column.
339+
dataset:
340+
Required. A Pandas DataFrame that includes feature column names and data.
341+
column_types:
342+
Required. An OrderedDict of string names matching the columns of the dataset
343+
as the key, and the associated LitType of the column.
236344
Returns:
237345
A LIT Dataset object that has the data from the dataset provided.
238346
"""
239347
return _VertexLitDataset(dataset, column_types)
240348

241349

350+
def create_lit_model_from_endpoint(
    endpoint: Union[str, aiplatform.Endpoint],
    input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    model_id: Optional[str] = None,
) -> lit_model.Model:
    """Creates a LIT Model object from a deployed Vertex endpoint.

    Args:
        endpoint:
            Required. The name of the Endpoint resource or an Endpoint instance.
            Endpoint name format: ``projects/{project}/locations/{location}/endpoints/{endpoint}``
        input_types:
            Required. An OrderedDict of string names matching the features of the model
            as the key, and the associated LitType of the feature.
        output_types:
            Required. An OrderedDict of string names matching the labels of the model
            as the key, and the associated LitType of the label.
        model_id:
            Optional. A string of the specific model in the endpoint to create the
            LIT model from. If this is not set, any usable model in the endpoint is
            used to create the LIT model.
    Returns:
        A LIT Model object that has the same functionality as the model provided.
    """
    return _EndpointLitModel(endpoint, input_types, output_types, model_id)
375+
376+
242377
def create_lit_model(
243378
model: str,
244379
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
@@ -247,23 +382,23 @@ def create_lit_model(
247382
) -> lit_model.Model:
248383
"""Creates a LIT Model object.
249384
Args:
250-
model:
251-
Required. A string reference to a local TensorFlow saved model directory.
252-
The model must have at most one input and one output tensor.
253-
input_types:
254-
Required. An OrderedDict of string names matching the features of the model
255-
as the key, and the associated LitType of the feature.
256-
output_types:
257-
Required. An OrderedDict of string names matching the labels of the model
258-
as the key, and the associated LitType of the label.
259-
attribution_method:
260-
Optional. A string to choose what attribution configuration to
261-
set up the explainer with. Valid options are 'sampled_shapley'
262-
or 'integrated_gradients'.
385+
model:
386+
Required. A string reference to a local TensorFlow saved model directory.
387+
The model must have at most one input and one output tensor.
388+
input_types:
389+
Required. An OrderedDict of string names matching the features of the model
390+
as the key, and the associated LitType of the feature.
391+
output_types:
392+
Required. An OrderedDict of string names matching the labels of the model
393+
as the key, and the associated LitType of the label.
394+
attribution_method:
395+
Optional. A string to choose what attribution configuration to
396+
set up the explainer with. Valid options are 'sampled_shapley'
397+
or 'integrated_gradients'.
263398
Returns:
264399
A LIT Model object that has the same functionality as the model provided.
265400
"""
266-
return _VertexLitModel(model, input_types, output_types, attribution_method)
401+
return _TensorFlowLitModel(model, input_types, output_types, attribution_method)
267402

268403

269404
def open_lit(
@@ -273,12 +408,12 @@ def open_lit(
273408
):
274409
"""Open LIT from the provided models and datasets.
275410
Args:
276-
models:
277-
Required. A list of LIT models to open LIT with.
278-
input_types:
279-
Required. A lit of LIT datasets to open LIT with.
280-
open_in_new_tab:
281-
Optional. A boolean to choose if LIT open in a new tab or not.
411+
models:
412+
Required. A list of LIT models to open LIT with.
413+
input_types:
414+
Required. A list of LIT datasets to open LIT with.
415+
open_in_new_tab:
416+
Optional. A boolean to choose if LIT open in a new tab or not.
282417
Raises:
283418
ImportError if LIT is not installed.
284419
"""
@@ -297,26 +432,26 @@ def set_up_and_open_lit(
297432
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
298433
"""Creates a LIT dataset and model and opens LIT.
299434
Args:
300-
dataset:
301-
Required. A Pandas DataFrame that includes feature column names and data.
302-
column_types:
303-
Required. An OrderedDict of string names matching the columns of the dataset
304-
as the key, and the associated LitType of the column.
305-
model:
306-
Required. A string reference to a TensorFlow saved model directory.
307-
The model must have at most one input and one output tensor.
308-
input_types:
309-
Required. An OrderedDict of string names matching the features of the model
310-
as the key, and the associated LitType of the feature.
311-
output_types:
312-
Required. An OrderedDict of string names matching the labels of the model
313-
as the key, and the associated LitType of the label.
314-
attribution_method:
315-
Optional. A string to choose what attribution configuration to
316-
set up the explainer with. Valid options are 'sampled_shapley'
317-
or 'integrated_gradients'.
318-
open_in_new_tab:
319-
Optional. A boolean to choose if LIT open in a new tab or not.
435+
dataset:
436+
Required. A Pandas DataFrame that includes feature column names and data.
437+
column_types:
438+
Required. An OrderedDict of string names matching the columns of the dataset
439+
as the key, and the associated LitType of the column.
440+
model:
441+
Required. A string reference to a TensorFlow saved model directory.
442+
The model must have at most one input and one output tensor.
443+
input_types:
444+
Required. An OrderedDict of string names matching the features of the model
445+
as the key, and the associated LitType of the feature.
446+
output_types:
447+
Required. An OrderedDict of string names matching the labels of the model
448+
as the key, and the associated LitType of the label.
449+
attribution_method:
450+
Optional. A string to choose what attribution configuration to
451+
set up the explainer with. Valid options are 'sampled_shapley'
452+
or 'integrated_gradients'.
453+
open_in_new_tab:
454+
Optional. A boolean to choose if LIT open in a new tab or not.
320455
Returns:
321456
A Tuple of the LIT dataset and model created.
322457
Raises:

0 commit comments

Comments
 (0)