Skip to content

Commit 8cb4839

Browse files
authored
fix: Support multiple instances in custom predict sample (#857)
1 parent dd1f650 commit 8cb4839

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

samples/snippets/prediction_service/predict_custom_trained_model_sample.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# [START aiplatform_predict_custom_trained_model_sample]
16-
from typing import Dict
16+
from typing import Dict, List, Union
1717

1818
from google.cloud import aiplatform
1919
from google.protobuf import json_format
@@ -23,18 +23,24 @@
2323
def predict_custom_trained_model_sample(
2424
project: str,
2525
endpoint_id: str,
26-
instance_dict: Dict,
26+
instances: Union[Dict, List[Dict]],
2727
location: str = "us-central1",
2828
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
2929
):
30+
"""
31+
`instances` can be either single instance of type dict or a list
32+
of instances.
33+
"""
3034
# The AI Platform services require regional API endpoints.
3135
client_options = {"api_endpoint": api_endpoint}
3236
# Initialize client that will be used to create and send requests.
3337
# This client only needs to be created once, and can be reused for multiple requests.
3438
client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
3539
# The format of each instance should conform to the deployed model's prediction input schema.
36-
instance = json_format.ParseDict(instance_dict, Value())
37-
instances = [instance]
40+
instances = instances if type(instances) == list else [instances]
41+
instances = [
42+
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
43+
]
3844
parameters_dict = {}
3945
parameters = json_format.ParseDict(parameters_dict, Value())
4046
endpoint = client.endpoint_path(

samples/snippets/prediction_service/predict_custom_trained_model_sample_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,20 @@ def test_ucaip_generated_predict_custom_trained_model_sample(capsys):
3232

3333
instance_dict = {"image_bytes": {"b64": encoded_content}, "key": "0"}
3434

35+
# Single instance as a dict
3536
predict_custom_trained_model_sample.predict_custom_trained_model_sample(
36-
instance_dict=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID
37+
instances=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID
38+
)
39+
40+
# Multiple instances in a list
41+
predict_custom_trained_model_sample.predict_custom_trained_model_sample(
42+
instances=[instance_dict, instance_dict],
43+
project=PROJECT_ID,
44+
endpoint_id=ENDPOINT_ID,
3745
)
3846

3947
out, _ = capsys.readouterr()
4048
assert "1.0" in out
49+
50+
# Two sets of scores for multi-instance, one score for single instance
51+
assert out.count("scores") == 3

0 commit comments

Comments
 (0)