Skip to content

Commit 677b311

Browse files
andrewferlitschkweinmeisterrosiezounayaknishant
authored
docs(samples): add AutoML image classification sample (#923)
* Create predict_image_classification_sample.py * feat: new sample and test * lint: fix wsp * lint: import order * lint: fix import * tags: fixed start tag * samples: change tabular to image in sample function name. * samples: replace TF version of reading in binary file with Python version * samples: delete tf import, move other imports within region tags * Update predict_image_classification_sample.py * samples: move imports for lint * Update predict_image_classification_sample.py * Update predict_image_classification_sample_test.py Co-authored-by: Karl Weinmeister <[email protected]> Co-authored-by: Rosie Zou <[email protected]> Co-authored-by: nayaknishant <[email protected]>
1 parent 406ed84 commit 677b311

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
# [START aiplatform_sdk_predict_image_classification_sample]
17+
import base64
18+
19+
from typing import List
20+
21+
from google.cloud import aiplatform
22+
23+
24+
def predict_image_classification_sample(
25+
project: str,
26+
location: str,
27+
endpoint_name: str,
28+
images: List,
29+
):
30+
'''
31+
Args
32+
project: Your project ID or project number.
33+
location: Region where Endpoint is located. For example, 'us-central1'.
34+
endpoint_name: A fully qualified endpoint name or endpoint ID. Example: "projects/123/locations/us-central1/endpoints/456" or
35+
"456" when project and location are initialized or passed.
36+
images: A list of one or more images to return a prediction for.
37+
'''
38+
aiplatform.init(project=project, location=location)
39+
40+
endpoint = aiplatform.Endpoint(endpoint_name)
41+
42+
instances = []
43+
for image in images:
44+
with open(image, "rb") as f:
45+
content = f.read()
46+
instances.append({"content": base64.b64encode(content).decode("utf-8")})
47+
48+
response = endpoint.predict(instances=instances)
49+
50+
for prediction_ in response.predictions:
51+
print(prediction_)
52+
53+
54+
# [END aiplatform_sdk_predict_image_classification_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import predict_image_classification_sample
17+
import test_constants as constants
18+
19+
20+
def test_predict_image_classification_sample(mock_sdk_init, mock_get_endpoint):
21+
22+
predict_image_classification_sample.predict_image_classification_sample(
23+
project=constants.PROJECT,
24+
location=constants.LOCATION,
25+
endpoint_name=constants.ENDPOINT_NAME,
26+
images=[]
27+
)
28+
29+
mock_sdk_init.assert_called_once_with(
30+
project=constants.PROJECT, location=constants.LOCATION
31+
)
32+
33+
mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)

0 commit comments

Comments
 (0)