|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | # [START aiplatform_predict_custom_trained_model_sample]
|
16 |
| -from typing import Dict |
| 16 | +from typing import Dict, List, Union |
17 | 17 |
|
18 | 18 | from google.cloud import aiplatform
|
19 | 19 | from google.protobuf import json_format
|
|
23 | 23 | def predict_custom_trained_model_sample(
|
24 | 24 | project: str,
|
25 | 25 | endpoint_id: str,
|
26 |
| - instance_dict: Dict, |
| 26 | + instances: Union[Dict, List[Dict]], |
27 | 27 | location: str = "us-central1",
|
28 | 28 | api_endpoint: str = "us-central1-aiplatform.googleapis.com",
|
29 | 29 | ):
|
| 30 | + """ |
| 31 | + `instances` can be either single instance of type dict or a list |
| 32 | + of instances. |
| 33 | + """ |
30 | 34 | # The AI Platform services require regional API endpoints.
|
31 | 35 | client_options = {"api_endpoint": api_endpoint}
|
32 | 36 | # Initialize client that will be used to create and send requests.
|
33 | 37 | # This client only needs to be created once, and can be reused for multiple requests.
|
34 | 38 | client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
|
35 | 39 | # The format of each instance should conform to the deployed model's prediction input schema.
|
36 |
| - instance = json_format.ParseDict(instance_dict, Value()) |
37 |
| - instances = [instance] |
| 40 | + instances = instances if type(instances) == list else [instances] |
| 41 | + instances = [ |
| 42 | + json_format.ParseDict(instance_dict, Value()) for instance_dict in instances |
| 43 | + ] |
38 | 44 | parameters_dict = {}
|
39 | 45 | parameters = json_format.ParseDict(parameters_dict, Value())
|
40 | 46 | endpoint = client.endpoint_path(
|
|
0 commit comments