# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import json
import pathlib
import re
@@ -226,7 +227,10 @@ def __init__(
        # Lazy load the Endpoint gca_resource until needed
        self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)

-        self._prediction_client = self._instantiate_prediction_client(
+        (
+            self._prediction_client,
+            self._prediction_async_client,
+        ) = self._instantiate_prediction_clients(
            location=self.location,
            credentials=credentials,
        )
@@ -572,7 +576,10 @@ def _construct_sdk_resource_from_gapic(
            credentials=credentials,
        )

-        endpoint._prediction_client = cls._instantiate_prediction_client(
+        (
+            endpoint._prediction_client,
+            endpoint._prediction_async_client,
+        ) = cls._instantiate_prediction_clients(
            location=endpoint.location,
            credentials=credentials,
        )
@@ -1384,10 +1391,12 @@ def _undeploy(
        self._sync_gca_resource()

    @staticmethod
-    def _instantiate_prediction_client(
+    def _instantiate_prediction_clients(
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
-    ) -> utils.PredictionClientWithOverride:
+    ) -> Tuple[
+        utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
+    ]:
        """Helper method to instantiates prediction client with optional
        overrides for this endpoint.

@@ -1399,14 +1408,34 @@ def _instantiate_prediction_client(

        Returns:
            prediction_client (prediction_service_client.PredictionServiceClient):
-                Initialized prediction client with optional overrides.
+            prediction_async_client (PredictionServiceAsyncClient):
+                Initialized prediction clients with optional overrides.
        """
-        return initializer.global_config.create_client(
+
+        # Creating an event loop if needed.
+        # The PredictionServiceAsyncClient constructor calls `asyncio.get_event_loop`,
+        # which fails when there is no event loop (none exists by default in the
+        # non-main threads of the thread pool used when `sync=False`).
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            asyncio.set_event_loop(asyncio.new_event_loop())
+
+        async_client = initializer.global_config.create_client(
+            client_class=utils.PredictionAsyncClientWithOverride,
+            credentials=credentials,
+            location_override=location,
+            prediction_client=True,
+        )
+        # We could use `client = async_client._client`, but then `client` would be
+        # a concrete `PredictionServiceClient`, not `PredictionClientWithOverride`.
+        client = initializer.global_config.create_client(
            client_class=utils.PredictionClientWithOverride,
            credentials=credentials,
            location_override=location,
            prediction_client=True,
        )
+        return (client, async_client)

    def update(
        self,
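Review note on the event-loop guard in the hunk above: `asyncio.get_event_loop()` raises `RuntimeError` in a non-main thread that has no loop set, which is exactly the situation inside the thread pool used for `sync=False` calls, so a loop must be created and set before the async client is constructed. A minimal, standard-library-only sketch of the same pattern in isolation (the `_ensure_event_loop` and `worker` names are illustrative, not part of the SDK):

```python
import asyncio
import threading


def _ensure_event_loop() -> None:
    # In a non-main thread there is no event loop by default, so
    # get_event_loop() raises RuntimeError there; create and set one.
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        asyncio.set_event_loop(asyncio.new_event_loop())


def worker() -> None:
    # Hypothetical stand-in for constructing an async gRPC client,
    # whose constructor would call asyncio.get_event_loop() internally.
    _ensure_event_loop()
    loop = asyncio.get_event_loop()
    print("worker thread now has a loop:", loop is not None)


threading.Thread(target=worker).start()
```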
@@ -1581,6 +1610,65 @@ def predict(
            model_resource_name=prediction_response.model,
        )

+    async def predict_async(
+        self,
+        instances: List,
+        *,
+        parameters: Optional[Dict] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
+        """Make an asynchronous prediction against this Endpoint.
+        Example usage:
+        ```
+        response = await my_endpoint.predict_async(instances=[...])
+        my_predictions = response.predictions
+        ```
+
+        Args:
+            instances (List):
+                Required. The instances that are the input to the
+                prediction call. A DeployedModel may have an upper limit
+                on the number of instances it supports per request, and
+                when it is exceeded the prediction call errors in case
+                of AutoML Models, or, in case of customer created
+                Models, the behaviour is as documented by that Model.
+                The schema of any single instance may be specified via
+                Endpoint's DeployedModels'
+                [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``instance_schema_uri``.
+            parameters (Dict):
+                Optional. The parameters that govern the prediction. The schema of
+                the parameters may be specified via Endpoint's
+                DeployedModels' [Model's
+                ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``parameters_schema_uri``.
+            timeout (float): Optional. The timeout for this request in seconds.
+
+        Returns:
+            prediction (aiplatform.Prediction):
+                Prediction with returned predictions and Model ID.
+        """
+        self.wait()
+
+        prediction_response = await self._prediction_async_client.predict(
+            endpoint=self._gca_resource.name,
+            instances=instances,
+            parameters=parameters,
+            timeout=timeout,
+        )
+
+        return Prediction(
+            predictions=[
+                json_format.MessageToDict(item)
+                for item in prediction_response.predictions.pb
+            ],
+            deployed_model_id=prediction_response.deployed_model_id,
+            model_version_id=prediction_response.model_version_id,
+            model_resource_name=prediction_response.model,
+        )
+
    def raw_predict(
        self, body: bytes, headers: Dict[str, str]
    ) -> requests.models.Response:
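A hedged usage sketch for the new coroutine added above (the endpoint resource name and instance payloads are placeholders, and the instance schema depends on the deployed model): from synchronous code the call is driven with `asyncio.run`, and several requests can share the endpoint's single async client via `asyncio.gather`.

```python
import asyncio

from google.cloud import aiplatform


async def main() -> None:
    # Placeholder endpoint resource name; substitute a real one.
    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )

    # Single asynchronous prediction.
    response = await endpoint.predict_async(instances=[{"feature": 1.0}])
    print(response.predictions)

    # Fan out several requests concurrently on the shared async client.
    responses = await asyncio.gather(
        endpoint.predict_async(instances=[{"feature": 1.0}]),
        endpoint.predict_async(instances=[{"feature": 2.0}]),
    )
    print(len(responses))


asyncio.run(main())
```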
@@ -1676,6 +1764,70 @@ def explain(
            explanations=explain_response.explanations,
        )

+    async def explain_async(
+        self,
+        instances: List[Dict],
+        *,
+        parameters: Optional[Dict] = None,
+        deployed_model_id: Optional[str] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
+        """Make a prediction with explanations against this Endpoint.
+
+        Example usage:
+        ```
+        response = await my_endpoint.explain_async(instances=[...])
+        my_explanations = response.explanations
+        ```
+
+        Args:
+            instances (List):
+                Required. The instances that are the input to the
+                prediction call. A DeployedModel may have an upper limit
+                on the number of instances it supports per request, and
+                when it is exceeded the prediction call errors in case
+                of AutoML Models, or, in case of customer created
+                Models, the behaviour is as documented by that Model.
+                The schema of any single instance may be specified via
+                Endpoint's DeployedModels'
+                [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``instance_schema_uri``.
+            parameters (Dict):
+                The parameters that govern the prediction. The schema of
+                the parameters may be specified via Endpoint's
+                DeployedModels' [Model's
+                ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
+                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
+                ``parameters_schema_uri``.
+            deployed_model_id (str):
+                Optional. If specified, this ExplainRequest will be served by the
+                chosen DeployedModel, overriding this Endpoint's traffic split.
+            timeout (float): Optional. The timeout for this request in seconds.
+
+        Returns:
+            prediction (aiplatform.Prediction):
+                Prediction with returned predictions, explanations, and Model ID.
+        """
+        self.wait()
+
+        explain_response = await self._prediction_async_client.explain(
+            endpoint=self.resource_name,
+            instances=instances,
+            parameters=parameters,
+            deployed_model_id=deployed_model_id,
+            timeout=timeout,
+        )
+
+        return Prediction(
+            predictions=[
+                json_format.MessageToDict(item)
+                for item in explain_response.predictions.pb
+            ],
+            deployed_model_id=explain_response.deployed_model_id,
+            explanations=explain_response.explanations,
+        )
+
    @classmethod
    def list(
        cls,
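Similarly, a sketch for `explain_async`, assuming the deployed model was uploaded with an explanation spec (otherwise the backend rejects explain requests); the endpoint resource name and instance payload are again placeholders:

```python
import asyncio

from google.cloud import aiplatform


async def main() -> None:
    # Placeholder endpoint; its deployed model is assumed to carry
    # an explanation spec so that explanations can be returned.
    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )

    prediction = await endpoint.explain_async(
        instances=[{"feature": 1.0}],
        # deployed_model_id="..." can pin the request to one DeployedModel
        # instead of following the Endpoint's traffic split.
        timeout=60.0,
    )
    print(prediction.predictions)
    print(prediction.explanations)


asyncio.run(main())
```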