Skip to content

Commit f9ca1d5

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: add sdk support for xai example-based explanations
PiperOrigin-RevId: 545803031
1 parent 718f04b commit f9ca1d5

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

google/cloud/aiplatform/explain/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
4343
SmoothGradConfig = explanation_compat.SmoothGradConfig
4444
XraiAttribution = explanation_compat.XraiAttribution
45+
Presets = explanation_compat.Presets
46+
Examples = explanation_compat.Examples
4547

4648

4749
__all__ = (
@@ -58,4 +60,6 @@
5860
"SmoothGradConfig",
5961
"Visualization",
6062
"XraiAttribution",
63+
"Presets",
64+
"Examples",
6165
)

tests/unit/aiplatform/test_models.py

+127
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,59 @@
152152
_TEST_EXPLANATION_PARAMETERS = (
153153
test_constants.ModelConstants._TEST_EXPLANATION_PARAMETERS
154154
)
155+
_TEST_EXPLANATION_METADATA_EXAMPLES = explain.ExplanationMetadata(
156+
outputs={"embedding": {"output_tensor_name": "embedding"}},
157+
inputs={
158+
"my_input": {
159+
"input_tensor_name": "bytes_inputs",
160+
"encoding": "IDENTITY",
161+
"modality": "image",
162+
},
163+
"id": {"input_tensor_name": "id", "encoding": "IDENTITY"},
164+
},
165+
)
166+
_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS = explain.ExplanationParameters(
167+
{
168+
"examples": {
169+
"example_gcs_source": {
170+
"gcs_source": {
171+
"uris": ["gs://example-bucket/folder/instance1.jsonl"],
172+
},
173+
},
174+
"neighbor_count": 10,
175+
"presets": {"query": "FAST", "modality": "TEXT"},
176+
}
177+
}
178+
)
179+
_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG = explain.ExplanationParameters(
180+
{
181+
"examples": {
182+
"example_gcs_source": {
183+
"gcs_source": {
184+
"uris": ["gs://example-bucket/folder/instance1.jsonl"],
185+
},
186+
},
187+
"neighbor_count": 10,
188+
"nearest_neighbor_search_config": [
189+
{
190+
"contentsDeltaUri": "",
191+
"config": {
192+
"dimensions": 50,
193+
"approximateNeighborsCount": 10,
194+
"distanceMeasureType": "SQUARED_L2_DISTANCE",
195+
"featureNormType": "NONE",
196+
"algorithmConfig": {
197+
"treeAhConfig": {
198+
"leafNodeEmbeddingCount": 1000,
199+
"leafNodesToSearchPercent": 100,
200+
}
201+
},
202+
},
203+
}
204+
],
205+
}
206+
}
207+
)
155208

156209
# CMEK encryption
157210
_TEST_ENCRYPTION_KEY_NAME = "key_1234"
@@ -1119,6 +1172,80 @@ def test_upload_with_parameters_without_metadata(
11191172
timeout=None,
11201173
)
11211174

1175+
@pytest.mark.parametrize("sync", [True, False])
1176+
def test_upload_with_parameters_for_examples_presets(
1177+
self, upload_model_mock, get_model_mock, sync
1178+
):
1179+
my_model = models.Model.upload(
1180+
display_name=_TEST_MODEL_NAME,
1181+
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1182+
explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS,
1183+
explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
1184+
sync=sync,
1185+
)
1186+
1187+
if not sync:
1188+
my_model.wait()
1189+
1190+
container_spec = gca_model.ModelContainerSpec(
1191+
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1192+
)
1193+
1194+
managed_model = gca_model.Model(
1195+
display_name=_TEST_MODEL_NAME,
1196+
container_spec=container_spec,
1197+
explanation_spec=gca_model.explanation.ExplanationSpec(
1198+
metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
1199+
parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS,
1200+
),
1201+
version_aliases=["default"],
1202+
)
1203+
1204+
upload_model_mock.assert_called_once_with(
1205+
request=gca_model_service.UploadModelRequest(
1206+
parent=initializer.global_config.common_location_path(),
1207+
model=managed_model,
1208+
),
1209+
timeout=None,
1210+
)
1211+
1212+
@pytest.mark.parametrize("sync", [True, False])
1213+
def test_upload_with_parameters_for_examples_full_config(
1214+
self, upload_model_mock, get_model_mock, sync
1215+
):
1216+
my_model = models.Model.upload(
1217+
display_name=_TEST_MODEL_NAME,
1218+
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1219+
explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG,
1220+
explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
1221+
sync=sync,
1222+
)
1223+
1224+
if not sync:
1225+
my_model.wait()
1226+
1227+
container_spec = gca_model.ModelContainerSpec(
1228+
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1229+
)
1230+
1231+
managed_model = gca_model.Model(
1232+
display_name=_TEST_MODEL_NAME,
1233+
container_spec=container_spec,
1234+
explanation_spec=gca_model.explanation.ExplanationSpec(
1235+
metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
1236+
parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG,
1237+
),
1238+
version_aliases=["default"],
1239+
)
1240+
1241+
upload_model_mock.assert_called_once_with(
1242+
request=gca_model_service.UploadModelRequest(
1243+
parent=initializer.global_config.common_location_path(),
1244+
model=managed_model,
1245+
),
1246+
timeout=None,
1247+
)
1248+
11221249
@pytest.mark.parametrize("sync", [True, False])
11231250
def test_upload_uploads_and_gets_model_with_all_args(
11241251
self, upload_model_mock, get_model_mock, sync

0 commit comments

Comments
 (0)