Skip to content

Commit 8e2ad75

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Adding support for concurrent explanations
PiperOrigin-RevId: 586740015
1 parent ae3677c commit 8e2ad75

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed

google/cloud/aiplatform/preview/models.py

+198
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,36 @@
3232
endpoint_service_client,
3333
)
3434
from google.cloud.aiplatform.compat.types import (
35+
prediction_service_v1beta1 as gca_prediction_service_compat,
3536
deployed_model_ref_v1beta1 as gca_deployed_model_ref_compat,
3637
deployment_resource_pool_v1beta1 as gca_deployment_resource_pool_compat,
38+
explanation_v1beta1 as gca_explanation_compat,
3739
endpoint_v1beta1 as gca_endpoint_compat,
3840
machine_resources_v1beta1 as gca_machine_resources_compat,
3941
model_v1 as gca_model_compat,
4042
)
43+
from google.protobuf import json_format
4144

4245
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
4346

4447
_LOGGER = base.Logger(__name__)
4548

4649

50+
class Prediction(models.Prediction):
    """Prediction class envelopes returned Model predictions and the Model id.

    Attributes:
        concurrent_explanations:
            Map of explanations that were requested concurrently in addition to
            the default explanation for the Model's predictions. It has the same
            number of elements as instances to be explained. Default is None.
    """

    # Keyed by the name given in the concurrent explanation spec override;
    # each value is the sequence of per-instance explanations for that spec.
    concurrent_explanations: Optional[
        Dict[str, Sequence[gca_explanation_compat.Explanation]]
    ] = None
63+
64+
4765
class DeploymentResourcePool(base.VertexAiResourceNounWithFutureManager):
4866
client_class = utils.DeploymentResourcePoolClientWithOverride
4967
_resource_noun = "deploymentResourcePools"
@@ -1013,6 +1031,186 @@ def _deploy_call(
10131031

10141032
operation_future.result(timeout=None)
10151033

1034+
def explain(
1035+
self,
1036+
instances: List[Dict],
1037+
parameters: Optional[Dict] = None,
1038+
deployed_model_id: Optional[str] = None,
1039+
timeout: Optional[float] = None,
1040+
explanation_spec_override: Optional[Dict] = None,
1041+
concurrent_explanation_spec_override: Optional[Dict] = None,
1042+
) -> Prediction:
1043+
"""Make a prediction with explanations against this Endpoint.
1044+
1045+
Example usage:
1046+
response = my_endpoint.explain(instances=[...])
1047+
my_explanations = response.explanations
1048+
1049+
Args:
1050+
instances (List):
1051+
Required. The instances that are the input to the
1052+
prediction call. A DeployedModel may have an upper limit
1053+
on the number of instances it supports per request, and
1054+
when it is exceeded the prediction call errors in case
1055+
of AutoML Models, or, in case of customer created
1056+
Models, the behaviour is as documented by that Model.
1057+
The schema of any single instance may be specified via
1058+
Endpoint's DeployedModels'
1059+
[Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
1060+
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
1061+
``instance_schema_uri``.
1062+
parameters (Dict):
1063+
The parameters that govern the prediction. The schema of
1064+
the parameters may be specified via Endpoint's
1065+
DeployedModels' [Model's
1066+
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
1067+
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
1068+
``parameters_schema_uri``.
1069+
deployed_model_id (str):
1070+
Optional. If specified, this ExplainRequest will be served by the
1071+
chosen DeployedModel, overriding this Endpoint's traffic split.
1072+
timeout (float): Optional. The timeout for this request in seconds.
1073+
explanation_spec_override (Dict):
1074+
Optional. Represents overrides to the explaination
1075+
specification used when the model was deployed.
1076+
The Explanation Override will
1077+
be merged with model's existing [Explanation Spec
1078+
][google.cloud.aiplatform.v1beta1.ExplanationSpec].
1079+
concurrent_explanation_spec_override (Dict):
1080+
Optional. The ``explain`` endpoint supports multiple
1081+
explanations in parallel. To request concurrent explanation in
1082+
addition to the configured explaination method, use this field.
1083+
1084+
Returns:
1085+
prediction (aiplatform.Prediction):
1086+
Prediction with returned predictions, explanations, and Model ID.
1087+
"""
1088+
self.wait()
1089+
request = gca_prediction_service_compat.ExplainRequest()
1090+
1091+
if instances is not None:
1092+
request.instances.extend(instances)
1093+
if parameters is not None:
1094+
request.parameters = parameters
1095+
if deployed_model_id is not None:
1096+
request.deployed_model_id = deployed_model_id
1097+
if explanation_spec_override is not None:
1098+
request.explanation_spec_override = explanation_spec_override
1099+
if concurrent_explanation_spec_override is not None:
1100+
request.concurrent_explanation_spec_override = (
1101+
concurrent_explanation_spec_override
1102+
)
1103+
1104+
explain_response = self._prediction_client.select_version("v1beta1").explain(
1105+
request, timeout=timeout
1106+
)
1107+
1108+
prediction = Prediction(
1109+
predictions=[
1110+
json_format.MessageToDict(item)
1111+
for item in explain_response.predictions.pb
1112+
],
1113+
deployed_model_id=explain_response.deployed_model_id,
1114+
explanations=explain_response.explanations,
1115+
)
1116+
1117+
concurrent_explanation = {}
1118+
for k, e in explain_response.concurrent_explanations.items():
1119+
concurrent_explanation[k] = e.explanations
1120+
1121+
prediction.concurrent_explanations = concurrent_explanation
1122+
1123+
return prediction
1124+
1125+
async def explain_async(
1126+
self,
1127+
instances: List[Dict],
1128+
*,
1129+
parameters: Optional[Dict] = None,
1130+
deployed_model_id: Optional[str] = None,
1131+
timeout: Optional[float] = None,
1132+
explanation_spec_override: Optional[Dict] = None,
1133+
concurrent_explanation_spec_override: Optional[Dict] = None,
1134+
) -> Prediction:
1135+
"""Make a prediction with explanations against this Endpoint.
1136+
1137+
Example usage:
1138+
```
1139+
response = await my_endpoint.explain_async(instances=[...])
1140+
my_explanations = response.explanations
1141+
```
1142+
1143+
Args:
1144+
instances (List):
1145+
Required. The instances that are the input to the
1146+
prediction call. A DeployedModel may have an upper limit
1147+
on the number of instances it supports per request, and
1148+
when it is exceeded the prediction call errors in case
1149+
of AutoML Models, or, in case of customer created
1150+
Models, the behaviour is as documented by that Model.
1151+
The schema of any single instance may be specified via
1152+
Endpoint's DeployedModels'
1153+
[Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
1154+
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
1155+
``instance_schema_uri``.
1156+
parameters (Dict):
1157+
The parameters that govern the prediction. The schema of
1158+
the parameters may be specified via Endpoint's
1159+
DeployedModels' [Model's
1160+
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
1161+
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
1162+
``parameters_schema_uri``.
1163+
deployed_model_id (str):
1164+
Optional. If specified, this ExplainRequest will be served by the
1165+
chosen DeployedModel, overriding this Endpoint's traffic split.
1166+
timeout (float): Optional. The timeout for this request in seconds.
1167+
explanation_spec_override (Dict):
1168+
Optional. Represents overrides to the explaination
1169+
specification used when the model was deployed.
1170+
The Explanation Override will
1171+
be merged with model's existing [Explanation Spec
1172+
][google.cloud.aiplatform.v1beta1.ExplanationSpec].
1173+
concurrent_explanation_spec_override (Dict):
1174+
Optional. The ``explain`` endpoint supports multiple
1175+
explanations in parallel. To request concurrent explanation in
1176+
addition to the configured explaination method, use this field.
1177+
1178+
Returns:
1179+
prediction (aiplatform.Prediction):
1180+
Prediction with returned predictions, explanations, and Model ID.
1181+
"""
1182+
self.wait()
1183+
1184+
request = gca_prediction_service_compat.ExplainRequest(
1185+
endpoint=self.resource_name,
1186+
instances=instances,
1187+
parameters=parameters,
1188+
deployed_model_id=deployed_model_id,
1189+
explanation_spec_override=explanation_spec_override,
1190+
concurrent_explanation_spec_override=concurrent_explanation_spec_override,
1191+
)
1192+
1193+
explain_response = await self._prediction_async_client.select_version(
1194+
"v1beta1"
1195+
).explain(request, timeout=timeout)
1196+
1197+
prediction = Prediction(
1198+
predictions=[
1199+
json_format.MessageToDict(item)
1200+
for item in explain_response.predictions.pb
1201+
],
1202+
deployed_model_id=explain_response.deployed_model_id,
1203+
explanations=explain_response.explanations,
1204+
)
1205+
1206+
concurrent_explanation = {}
1207+
for k, e in explain_response.concurrent_explanations.items():
1208+
concurrent_explanation[k] = e.explanations
1209+
1210+
prediction.concurrent_explanations = concurrent_explanation
1211+
1212+
return prediction
1213+
10161214

10171215
class Model(aiplatform.Model):
10181216
def deploy(

0 commit comments

Comments
 (0)