Skip to content

Commit 957703f

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: add explanationSpec to TrainingPipeline-based custom jobs
PiperOrigin-RevId: 492054553
1 parent 43a2679 commit 957703f

File tree

5 files changed

+381
-90
lines changed

5 files changed

+381
-90
lines changed

google/cloud/aiplatform/explain/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
ExplanationParameters = explanation_compat.ExplanationParameters
3636
FeatureNoiseSigma = explanation_compat.FeatureNoiseSigma
3737

38+
ExplanationSpec = explanation_compat.ExplanationSpec
39+
3840
# Classes used by ExplanationParameters
3941
IntegratedGradientsAttribution = explanation_compat.IntegratedGradientsAttribution
4042
SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
@@ -44,6 +46,7 @@
4446

4547
__all__ = (
4648
"Encoding",
49+
"ExplanationSpec",
4750
"ExplanationMetadata",
4851
"ExplanationParameters",
4952
"FeatureNoiseSigma",

google/cloud/aiplatform/models.py

+33-87
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
from google.cloud.aiplatform import models
4848
from google.cloud.aiplatform import utils
4949
from google.cloud.aiplatform.utils import gcs_utils
50+
from google.cloud.aiplatform.utils import _explanation_utils
5051
from google.cloud.aiplatform import model_evaluation
51-
5252
from google.cloud.aiplatform.compat.services import endpoint_service_client
5353

5454
from google.cloud.aiplatform.compat.types import (
@@ -617,10 +617,6 @@ def _validate_deploy_args(
617617
deployed_model_display_name: Optional[str],
618618
traffic_split: Optional[Dict[str, int]],
619619
traffic_percentage: Optional[int],
620-
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
621-
explanation_parameters: Optional[
622-
aiplatform.explain.ExplanationParameters
623-
] = None,
624620
):
625621
"""Helper method to validate deploy arguments.
626622
@@ -663,20 +659,10 @@ def _validate_deploy_args(
663659
not be provided. Traffic of previously deployed models at the endpoint
664660
will be scaled down to accommodate new deployed model's traffic.
665661
Should not be provided if traffic_split is provided.
666-
explanation_metadata (aiplatform.explain.ExplanationMetadata):
667-
Optional. Metadata describing the Model's input and output for explanation.
668-
`explanation_metadata` is optional while `explanation_parameters` must be
669-
specified when used.
670-
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
671-
explanation_parameters (aiplatform.explain.ExplanationParameters):
672-
Optional. Parameters to configure explaining for Model's predictions.
673-
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
674662
675663
Raises:
676664
ValueError: if Min or Max replica is negative. Traffic percentage > 100 or
677665
< 0. Or if traffic_split does not sum to 100.
678-
ValueError: if explanation_metadata is specified while explanation_parameters
679-
is not.
680666
"""
681667
if min_replica_count < 0:
682668
raise ValueError("Min replica cannot be negative.")
@@ -697,11 +683,6 @@ def _validate_deploy_args(
697683
"Sum of all traffic within traffic split needs to be 100."
698684
)
699685

700-
if bool(explanation_metadata) and not bool(explanation_parameters):
701-
raise ValueError(
702-
"To get model explanation, `explanation_parameters` must be specified."
703-
)
704-
705686
# Raises ValueError if invalid accelerator
706687
if accelerator_type:
707688
utils.validate_accelerator_type(accelerator_type)
@@ -817,6 +798,9 @@ def deploy(
817798
deployed_model_display_name=deployed_model_display_name,
818799
traffic_split=traffic_split,
819800
traffic_percentage=traffic_percentage,
801+
)
802+
803+
explanation_spec = _explanation_utils.create_and_validate_explanation_spec(
820804
explanation_metadata=explanation_metadata,
821805
explanation_parameters=explanation_parameters,
822806
)
@@ -832,8 +816,7 @@ def deploy(
832816
accelerator_type=accelerator_type,
833817
accelerator_count=accelerator_count,
834818
service_account=service_account,
835-
explanation_metadata=explanation_metadata,
836-
explanation_parameters=explanation_parameters,
819+
explanation_spec=explanation_spec,
837820
metadata=metadata,
838821
sync=sync,
839822
deploy_request_timeout=deploy_request_timeout,
@@ -854,10 +837,7 @@ def _deploy(
854837
accelerator_type: Optional[str] = None,
855838
accelerator_count: Optional[int] = None,
856839
service_account: Optional[str] = None,
857-
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
858-
explanation_parameters: Optional[
859-
aiplatform.explain.ExplanationParameters
860-
] = None,
840+
explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None,
861841
metadata: Optional[Sequence[Tuple[str, str]]] = (),
862842
sync=True,
863843
deploy_request_timeout: Optional[float] = None,
@@ -919,14 +899,8 @@ def _deploy(
919899
to the resource project.
920900
Users deploying the Model must have the `iam.serviceAccounts.actAs`
921901
permission on this service account.
922-
explanation_metadata (aiplatform.explain.ExplanationMetadata):
923-
Optional. Metadata describing the Model's input and output for explanation.
924-
`explanation_metadata` is optional while `explanation_parameters` must be
925-
specified when used.
926-
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
927-
explanation_parameters (aiplatform.explain.ExplanationParameters):
928-
Optional. Parameters to configure explaining for Model's predictions.
929-
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
902+
explanation_spec (aiplatform.explain.ExplanationSpec):
903+
Optional. Specification of Model explanation.
930904
metadata (Sequence[Tuple[str, str]]):
931905
Optional. Strings which should be sent along with the request as
932906
metadata.
@@ -963,8 +937,7 @@ def _deploy(
963937
accelerator_type=accelerator_type,
964938
accelerator_count=accelerator_count,
965939
service_account=service_account,
966-
explanation_metadata=explanation_metadata,
967-
explanation_parameters=explanation_parameters,
940+
explanation_spec=explanation_spec,
968941
metadata=metadata,
969942
deploy_request_timeout=deploy_request_timeout,
970943
autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,
@@ -992,10 +965,7 @@ def _deploy_call(
992965
accelerator_type: Optional[str] = None,
993966
accelerator_count: Optional[int] = None,
994967
service_account: Optional[str] = None,
995-
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
996-
explanation_parameters: Optional[
997-
aiplatform.explain.ExplanationParameters
998-
] = None,
968+
explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None,
999969
metadata: Optional[Sequence[Tuple[str, str]]] = (),
1000970
deploy_request_timeout: Optional[float] = None,
1001971
autoscaling_target_cpu_utilization: Optional[int] = None,
@@ -1066,14 +1036,8 @@ def _deploy_call(
10661036
to the resource project.
10671037
Users deploying the Model must have the `iam.serviceAccounts.actAs`
10681038
permission on this service account.
1069-
explanation_metadata (aiplatform.explain.ExplanationMetadata):
1070-
Optional. Metadata describing the Model's input and output for explanation.
1071-
`explanation_metadata` is optional while `explanation_parameters` must be
1072-
specified when used.
1073-
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
1074-
explanation_parameters (aiplatform.explain.ExplanationParameters):
1075-
Optional. Parameters to configure explaining for Model's predictions.
1076-
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
1039+
explanation_spec (aiplatform.explain.ExplanationSpec):
1040+
Optional. Specification of Model explanation.
10771041
metadata (Sequence[Tuple[str, str]]):
10781042
Optional. Strings which should be sent along with the request as
10791043
metadata.
@@ -1199,13 +1163,7 @@ def _deploy_call(
11991163
"See https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model.FIELDS.repeated.google.cloud.aiplatform.v1.Model.DeploymentResourcesType.google.cloud.aiplatform.v1.Model.supported_deployment_resources_types"
12001164
)
12011165

1202-
# Service will throw error if explanation_parameters is not provided
1203-
if explanation_parameters:
1204-
explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
1205-
explanation_spec.parameters = explanation_parameters
1206-
if explanation_metadata:
1207-
explanation_spec.metadata = explanation_metadata
1208-
deployed_model.explanation_spec = explanation_spec
1166+
deployed_model.explanation_spec = explanation_spec
12091167

12101168
# Checking if traffic percentage is valid
12111169
# TODO(b/221059294) PrivateEndpoint should support traffic split
@@ -2332,6 +2290,9 @@ def deploy(
23322290
deployed_model_display_name=deployed_model_display_name,
23332291
traffic_split=None,
23342292
traffic_percentage=100,
2293+
)
2294+
2295+
explanation_spec = _explanation_utils.create_and_validate_explanation_spec(
23352296
explanation_metadata=explanation_metadata,
23362297
explanation_parameters=explanation_parameters,
23372298
)
@@ -2347,8 +2308,7 @@ def deploy(
23472308
accelerator_type=accelerator_type,
23482309
accelerator_count=accelerator_count,
23492310
service_account=service_account,
2350-
explanation_metadata=explanation_metadata,
2351-
explanation_parameters=explanation_parameters,
2311+
explanation_spec=explanation_spec,
23522312
metadata=metadata,
23532313
sync=sync,
23542314
)
@@ -3004,11 +2964,6 @@ def upload(
30042964
if labels:
30052965
utils.validate_labels(labels)
30062966

3007-
if bool(explanation_metadata) and not bool(explanation_parameters):
3008-
raise ValueError(
3009-
"To get model explanation, `explanation_parameters` must be specified."
3010-
)
3011-
30122967
appended_user_agent = None
30132968
if local_model:
30142969
container_spec = local_model.get_serving_container_spec()
@@ -3109,13 +3064,12 @@ def upload(
31093064
if artifact_uri:
31103065
managed_model.artifact_uri = artifact_uri
31113066

3112-
# Override explanation_spec if required field is provided
3113-
if explanation_parameters:
3114-
explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
3115-
explanation_spec.parameters = explanation_parameters
3116-
if explanation_metadata:
3117-
explanation_spec.metadata = explanation_metadata
3118-
managed_model.explanation_spec = explanation_spec
3067+
managed_model.explanation_spec = (
3068+
_explanation_utils.create_and_validate_explanation_spec(
3069+
explanation_metadata=explanation_metadata,
3070+
explanation_parameters=explanation_parameters,
3071+
)
3072+
)
31193073

31203074
request = gca_model_service_compat.UploadModelRequest(
31213075
parent=initializer.global_config.common_location_path(project, location),
@@ -3283,8 +3237,6 @@ def deploy(
32833237
deployed_model_display_name=deployed_model_display_name,
32843238
traffic_split=traffic_split,
32853239
traffic_percentage=traffic_percentage,
3286-
explanation_metadata=explanation_metadata,
3287-
explanation_parameters=explanation_parameters,
32883240
)
32893241

32903242
if isinstance(endpoint, PrivateEndpoint):
@@ -3295,6 +3247,11 @@ def deploy(
32953247
"A maximum of one model can be deployed to each private Endpoint."
32963248
)
32973249

3250+
explanation_spec = _explanation_utils.create_and_validate_explanation_spec(
3251+
explanation_metadata=explanation_metadata,
3252+
explanation_parameters=explanation_parameters,
3253+
)
3254+
32983255
return self._deploy(
32993256
endpoint=endpoint,
33003257
deployed_model_display_name=deployed_model_display_name,
@@ -3306,8 +3263,7 @@ def deploy(
33063263
accelerator_type=accelerator_type,
33073264
accelerator_count=accelerator_count,
33083265
service_account=service_account,
3309-
explanation_metadata=explanation_metadata,
3310-
explanation_parameters=explanation_parameters,
3266+
explanation_spec=explanation_spec,
33113267
metadata=metadata,
33123268
encryption_spec_key_name=encryption_spec_key_name
33133269
or initializer.global_config.encryption_spec_key_name,
@@ -3331,10 +3287,7 @@ def _deploy(
33313287
accelerator_type: Optional[str] = None,
33323288
accelerator_count: Optional[int] = None,
33333289
service_account: Optional[str] = None,
3334-
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
3335-
explanation_parameters: Optional[
3336-
aiplatform.explain.ExplanationParameters
3337-
] = None,
3290+
explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None,
33383291
metadata: Optional[Sequence[Tuple[str, str]]] = (),
33393292
encryption_spec_key_name: Optional[str] = None,
33403293
network: Optional[str] = None,
@@ -3398,14 +3351,8 @@ def _deploy(
33983351
to the resource project.
33993352
Users deploying the Model must have the `iam.serviceAccounts.actAs`
34003353
permission on this service account.
3401-
explanation_metadata (aiplatform.explain.ExplanationMetadata):
3402-
Optional. Metadata describing the Model's input and output for explanation.
3403-
`explanation_metadata` is optional while `explanation_parameters` must be
3404-
specified when used.
3405-
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
3406-
explanation_parameters (aiplatform.explain.ExplanationParameters):
3407-
Optional. Parameters to configure explaining for Model's predictions.
3408-
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
3354+
explanation_spec (aiplatform.explain.ExplanationSpec):
3355+
Optional. Specification of Model explanation.
34093356
metadata (Sequence[Tuple[str, str]]):
34103357
Optional. Strings which should be sent along with the request as
34113358
metadata.
@@ -3483,8 +3430,7 @@ def _deploy(
34833430
accelerator_type=accelerator_type,
34843431
accelerator_count=accelerator_count,
34853432
service_account=service_account,
3486-
explanation_metadata=explanation_metadata,
3487-
explanation_parameters=explanation_parameters,
3433+
explanation_spec=explanation_spec,
34883434
metadata=metadata,
34893435
deploy_request_timeout=deploy_request_timeout,
34903436
autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,

0 commit comments

Comments
 (0)