
Commit e9eb159

Ark-kun authored and copybara-github committed

feat: Added async prediction and explanation support to the Endpoint class

* Added the `Endpoint.predict_async` method
* Added the `Endpoint.explain_async` method
* Made it possible to use async clients in classes derived from `VertexAiResourceNounWithFutureManager` that use `@optional_sync`.

PiperOrigin-RevId: 565472250
1 parent 8b0add1 commit e9eb159
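
For orientation, here is a minimal usage sketch of the new asynchronous API, based on the docstring examples added in this commit. The project, endpoint resource name, and instance payload below are placeholders, not values from the commit, and the snippet assumes credentials are already configured.

```python
import asyncio

from google.cloud import aiplatform


async def main() -> None:
    # Hypothetical endpoint resource name and payload, for illustration only.
    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )

    # New in this commit: awaitable prediction.
    prediction = await endpoint.predict_async(instances=[{"feature": 1.0}])
    print(prediction.predictions)

    # New in this commit: awaitable prediction with explanations.
    explanation = await endpoint.explain_async(instances=[{"feature": 1.0}])
    print(explanation.explanations)


asyncio.run(main())
```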

6 files changed (+295 -6 lines)

google/cloud/aiplatform/compat/__init__.py (+6)

@@ -39,6 +39,9 @@
 services.model_garden_service_client = services.model_garden_service_client_v1beta1
 services.pipeline_service_client = services.pipeline_service_client_v1beta1
 services.prediction_service_client = services.prediction_service_client_v1beta1
+services.prediction_service_async_client = (
+    services.prediction_service_async_client_v1beta1
+)
 services.schedule_service_client = services.schedule_service_client_v1beta1
 services.specialist_pool_service_client = (
     services.specialist_pool_service_client_v1beta1
@@ -144,6 +147,9 @@
 services.model_service_client = services.model_service_client_v1
 services.pipeline_service_client = services.pipeline_service_client_v1
 services.prediction_service_client = services.prediction_service_client_v1
+services.prediction_service_async_client = (
+    services.prediction_service_async_client_v1
+)
 services.schedule_service_client = services.schedule_service_client_v1
 services.specialist_pool_service_client = services.specialist_pool_service_client_v1
 services.tensorboard_service_client = services.tensorboard_service_client_v1
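
The aliasing above lets the rest of the SDK refer to `services.prediction_service_async_client` without hard-coding an API version. A simplified, hypothetical sketch of the same idea follows; the class names and `DEFAULT_VERSION` value are stand-ins, not the SDK's real compat module.

```python
# compat_sketch.py -- illustrative only, not the SDK's actual compat module.

class PredictionServiceAsyncClientV1:
    """Stand-in for the aiplatform_v1 async prediction client."""


class PredictionServiceAsyncClientV1Beta1:
    """Stand-in for the aiplatform_v1beta1 async prediction client."""


DEFAULT_VERSION = "v1"

if DEFAULT_VERSION == "v1":
    # Mirrors: services.prediction_service_async_client = services.prediction_service_async_client_v1
    prediction_service_async_client = PredictionServiceAsyncClientV1
else:
    # Mirrors the v1beta1 branch of the compat module.
    prediction_service_async_client = PredictionServiceAsyncClientV1Beta1

# Downstream code imports the unversioned alias and never hard-codes the version.
print(prediction_service_async_client.__name__)
```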

google/cloud/aiplatform/compat/services/__init__.py (+8)

@@ -60,6 +60,9 @@
 from google.cloud.aiplatform_v1beta1.services.prediction_service import (
     client as prediction_service_client_v1beta1,
 )
+from google.cloud.aiplatform_v1beta1.services.prediction_service import (
+    async_client as prediction_service_async_client_v1beta1,
+)
 from google.cloud.aiplatform_v1beta1.services.schedule_service import (
     client as schedule_service_client_v1beta1,
 )
@@ -109,6 +112,9 @@
 from google.cloud.aiplatform_v1.services.prediction_service import (
     client as prediction_service_client_v1,
 )
+from google.cloud.aiplatform_v1.services.prediction_service import (
+    async_client as prediction_service_async_client_v1,
+)
 from google.cloud.aiplatform_v1.services.schedule_service import (
     client as schedule_service_client_v1,
 )
@@ -136,6 +142,7 @@
     model_service_client_v1,
     pipeline_service_client_v1,
     prediction_service_client_v1,
+    prediction_service_async_client_v1,
     schedule_service_client_v1,
     specialist_pool_service_client_v1,
     tensorboard_service_client_v1,
@@ -155,6 +162,7 @@
     persistent_resource_service_client_v1beta1,
     pipeline_service_client_v1beta1,
     prediction_service_client_v1beta1,
+    prediction_service_async_client_v1beta1,
     schedule_service_client_v1beta1,
     specialist_pool_service_client_v1beta1,
     metadata_service_client_v1beta1,

google/cloud/aiplatform/models.py (+158 -6)

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import pathlib
 import re
@@ -226,7 +227,10 @@ def __init__(
         # Lazy load the Endpoint gca_resource until needed
         self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)

-        self._prediction_client = self._instantiate_prediction_client(
+        (
+            self._prediction_client,
+            self._prediction_async_client,
+        ) = self._instantiate_prediction_clients(
             location=self.location,
             credentials=credentials,
         )
@@ -572,7 +576,10 @@ def _construct_sdk_resource_from_gapic(
             credentials=credentials,
         )

-        endpoint._prediction_client = cls._instantiate_prediction_client(
+        (
+            endpoint._prediction_client,
+            endpoint._prediction_async_client,
+        ) = cls._instantiate_prediction_clients(
             location=endpoint.location,
             credentials=credentials,
         )
@@ -1384,10 +1391,12 @@ def _undeploy(
         self._sync_gca_resource()

     @staticmethod
-    def _instantiate_prediction_client(
+    def _instantiate_prediction_clients(
         location: Optional[str] = None,
         credentials: Optional[auth_credentials.Credentials] = None,
-    ) -> utils.PredictionClientWithOverride:
+    ) -> Tuple[
+        utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
+    ]:
         """Helper method to instantiates prediction client with optional
         overrides for this endpoint.

@@ -1399,14 +1408,34 @@ def _instantiate_prediction_client(

         Returns:
             prediction_client (prediction_service_client.PredictionServiceClient):
-                Initialized prediction client with optional overrides.
+            prediction_async_client (PredictionServiceAsyncClient):
+                Initialized prediction clients with optional overrides.
         """
-        return initializer.global_config.create_client(
+
+        # Creating an event loop if needed.
+        # PredictionServiceAsyncClient constructor calls `asyncio.get_event_loop`,
+        # which fails when there is no event loop (which does not exist by default
+        # in non-main threads in thread pool used when `sync=False`).
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            asyncio.set_event_loop(asyncio.new_event_loop())
+
+        async_client = initializer.global_config.create_client(
+            client_class=utils.PredictionAsyncClientWithOverride,
+            credentials=credentials,
+            location_override=location,
+            prediction_client=True,
+        )
+        # We could use `client = async_client._client`, but then client would be
+        # a concrete `PredictionServiceClient`, not `PredictionClientWithOverride`.
+        client = initializer.global_config.create_client(
             client_class=utils.PredictionClientWithOverride,
             credentials=credentials,
             location_override=location,
             prediction_client=True,
         )
+        return (client, async_client)

     def update(
         self,
@@ -1581,6 +1610,65 @@ def predict(
             model_resource_name=prediction_response.model,
         )

+    async def predict_async(
+        self,
+        instances: List,
+        *,
+        parameters: Optional[Dict] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
+        """Make an asynchronous prediction against this Endpoint.
+        Example usage:
+        ```
+        response = await my_endpoint.predict_async(instances=[...])
+        my_predictions = response.predictions
+        ```
+
+        Args:
+            instances (List):
+                Required. The instances that are the input to the
+                prediction call. A DeployedModel may have an upper limit
+                on the number of instances it supports per request, and
+                when it is exceeded the prediction call errors in case
+                of AutoML Models, or, in case of customer created
+                Models, the behaviour is as documented by that Model.
+                The schema of any single instance may be specified via
+                Endpoint's DeployedModels'
+                [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``instance_schema_uri``.
+            parameters (Dict):
+                Optional. The parameters that govern the prediction. The schema of
+                the parameters may be specified via Endpoint's
+                DeployedModels' [Model's
+                ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``parameters_schema_uri``.
+            timeout (float): Optional. The timeout for this request in seconds.
+
+        Returns:
+            prediction (aiplatform.Prediction):
+                Prediction with returned predictions and Model ID.
+        """
+        self.wait()
+
+        prediction_response = await self._prediction_async_client.predict(
+            endpoint=self._gca_resource.name,
+            instances=instances,
+            parameters=parameters,
+            timeout=timeout,
+        )
+
+        return Prediction(
+            predictions=[
+                json_format.MessageToDict(item)
+                for item in prediction_response.predictions.pb
+            ],
+            deployed_model_id=prediction_response.deployed_model_id,
+            model_version_id=prediction_response.model_version_id,
+            model_resource_name=prediction_response.model,
+        )
+
     def raw_predict(
         self, body: bytes, headers: Dict[str, str]
     ) -> requests.models.Response:
@@ -1676,6 +1764,70 @@ def explain(
             explanations=explain_response.explanations,
         )

+    async def explain_async(
+        self,
+        instances: List[Dict],
+        *,
+        parameters: Optional[Dict] = None,
+        deployed_model_id: Optional[str] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
+        """Make a prediction with explanations against this Endpoint.
+
+        Example usage:
+        ```
+        response = await my_endpoint.explain_async(instances=[...])
+        my_explanations = response.explanations
+        ```
+
+        Args:
+            instances (List):
+                Required. The instances that are the input to the
+                prediction call. A DeployedModel may have an upper limit
+                on the number of instances it supports per request, and
+                when it is exceeded the prediction call errors in case
+                of AutoML Models, or, in case of customer created
+                Models, the behaviour is as documented by that Model.
+                The schema of any single instance may be specified via
+                Endpoint's DeployedModels'
+                [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``instance_schema_uri``.
+            parameters (Dict):
+                The parameters that govern the prediction. The schema of
+                the parameters may be specified via Endpoint's
+                DeployedModels' [Model's
+                ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``parameters_schema_uri``.
+            deployed_model_id (str):
+                Optional. If specified, this ExplainRequest will be served by the
+                chosen DeployedModel, overriding this Endpoint's traffic split.
+            timeout (float): Optional. The timeout for this request in seconds.
+
+        Returns:
+            prediction (aiplatform.Prediction):
+                Prediction with returned predictions, explanations, and Model ID.
+        """
+        self.wait()
+
+        explain_response = await self._prediction_async_client.explain(
+            endpoint=self.resource_name,
+            instances=instances,
+            parameters=parameters,
+            deployed_model_id=deployed_model_id,
+            timeout=timeout,
+        )
+
+        return Prediction(
+            predictions=[
+                json_format.MessageToDict(item)
+                for item in explain_response.predictions.pb
+            ],
+            deployed_model_id=explain_response.deployed_model_id,
+            explanations=explain_response.explanations,
+        )
+
     @classmethod
     def list(
         cls,
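
The event-loop guard in `_instantiate_prediction_clients` exists because `asyncio.get_event_loop()` raises `RuntimeError` in a thread that has no event loop set, which is exactly the situation in the worker threads used when `sync=False`. A standalone sketch of the same guard, under the assumption that it runs in a plain worker thread (the helper name below is illustrative):

```python
import asyncio
import threading


def ensure_event_loop() -> asyncio.AbstractEventLoop:
    # Same guard as in _instantiate_prediction_clients: reuse the thread's
    # registered loop if there is one, otherwise create and register a new one.
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


def worker() -> None:
    # In a non-main thread there is no default event loop, so calling
    # asyncio.get_event_loop() directly here would raise RuntimeError.
    loop = ensure_event_loop()
    print("worker thread loop:", loop)


thread = threading.Thread(target=worker)
thread.start()
thread.join()
```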

google/cloud/aiplatform/utils/__init__.py (+16)

@@ -49,6 +49,7 @@
     model_service_client_v1beta1,
     pipeline_service_client_v1beta1,
     prediction_service_client_v1beta1,
+    prediction_service_async_client_v1beta1,
     schedule_service_client_v1beta1,
     tensorboard_service_client_v1beta1,
     vizier_service_client_v1beta1,
@@ -68,6 +69,7 @@
     model_service_client_v1,
     pipeline_service_client_v1,
     prediction_service_client_v1,
+    prediction_service_async_client_v1,
     schedule_service_client_v1,
     tensorboard_service_client_v1,
     vizier_service_client_v1,
@@ -89,6 +91,7 @@
     index_endpoint_service_client_v1beta1.IndexEndpointServiceClient,
     model_service_client_v1beta1.ModelServiceClient,
     prediction_service_client_v1beta1.PredictionServiceClient,
+    prediction_service_async_client_v1beta1.PredictionServiceAsyncClient,
     pipeline_service_client_v1beta1.PipelineServiceClient,
     job_service_client_v1beta1.JobServiceClient,
     match_service_client_v1beta1.MatchServiceClient,
@@ -104,6 +107,7 @@
     metadata_service_client_v1.MetadataServiceClient,
     model_service_client_v1.ModelServiceClient,
     prediction_service_client_v1.PredictionServiceClient,
+    prediction_service_async_client_v1.PredictionServiceAsyncClient,
     pipeline_service_client_v1.PipelineServiceClient,
     job_service_client_v1.JobServiceClient,
     schedule_service_client_v1.ScheduleServiceClient,
@@ -616,6 +620,18 @@ class PredictionClientWithOverride(ClientWithOverride):
     )


+class PredictionAsyncClientWithOverride(ClientWithOverride):
+    _is_temporary = False
+    _default_version = compat.DEFAULT_VERSION
+    _version_map = (
+        (compat.V1, prediction_service_async_client_v1.PredictionServiceAsyncClient),
+        (
+            compat.V1BETA1,
+            prediction_service_async_client_v1beta1.PredictionServiceAsyncClient,
+        ),
+    )
+
+
 class MatchClientWithOverride(ClientWithOverride):
     _is_temporary = False
     _default_version = compat.V1BETA1
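
For readers unfamiliar with the `ClientWithOverride` pattern, the `_version_map` above pairs each compat version constant with the client class to instantiate for it. A simplified, hypothetical illustration of that lookup (this is not the SDK's actual implementation; the classes and map below are stand-ins):

```python
from typing import Tuple, Type


class FakeAsyncClientV1:
    """Stand-in for prediction_service_async_client_v1.PredictionServiceAsyncClient."""


class FakeAsyncClientV1Beta1:
    """Stand-in for prediction_service_async_client_v1beta1.PredictionServiceAsyncClient."""


# Hypothetical analogue of the _version_map tuple on PredictionAsyncClientWithOverride.
VERSION_MAP: Tuple[Tuple[str, Type], ...] = (
    ("v1", FakeAsyncClientV1),
    ("v1beta1", FakeAsyncClientV1Beta1),
)


def client_class_for(version: str) -> Type:
    # First matching version wins, mirroring how a version map is scanned.
    for candidate, client_cls in VERSION_MAP:
        if candidate == version:
            return client_cls
    raise ValueError(f"Unsupported version: {version}")


print(client_class_for("v1beta1").__name__)  # FakeAsyncClientV1Beta1
```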

tests/system/aiplatform/test_model_interactions.py (+9)

@@ -16,6 +16,7 @@
 #

 import json
+import pytest

 from google.cloud import aiplatform

@@ -64,3 +65,11 @@ def test_prediction(self):
         )
         assert raw_prediction_response.status_code == 200
         assert len(json.loads(raw_prediction_response.text)) == 1
+
+    @pytest.mark.asyncio
+    async def test_endpoint_predict_async(self):
+        # Test the Endpoint.predict_async method.
+        prediction_response = await self.endpoint.predict_async(
+            instances=[_PREDICTION_INSTANCE]
+        )
+        assert len(prediction_response.predictions) == 1
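
Note that `@pytest.mark.asyncio` relies on the pytest-asyncio plugin being available in the test environment. A minimal, self-contained example of the same test shape, using a hypothetical fake endpoint rather than a deployed Vertex AI endpoint:

```python
import pytest


class FakeEndpoint:
    """Hypothetical stand-in for aiplatform.Endpoint, used only for illustration."""

    async def predict_async(self, instances):
        class _Prediction:
            predictions = [0.5 for _ in instances]

        return _Prediction()


@pytest.mark.asyncio
async def test_predict_async_returns_one_prediction():
    endpoint = FakeEndpoint()
    response = await endpoint.predict_async(instances=[{"feature": 1.0}])
    assert len(response.predictions) == 1
```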
