import re
import shutil
import tempfile
+ import requests
from typing import (
    Any,
    Dict,

from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
+ from google.auth.transport import requests as google_auth_requests

from google.cloud import aiplatform
from google.cloud.aiplatform import base
+ from google.cloud.aiplatform import constants
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import jobs

_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
+ _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
+ _RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"

_LOGGER = base.Logger(__name__)
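An illustrative sketch (not part of the diff) of what these two header keys recover from a rawPredict HTTP response; `response` stands for any requests.models.Response returned by the new Endpoint.raw_predict() added further down:

    # Header values the SDK reads back from the rawPredict response.
    deployed_model_id = response.headers["X-Vertex-AI-Deployed-Model-Id"]
    model_resource_name = response.headers["X-Vertex-AI-Model"]
    # Inside the SDK the same lookups go through the module constants:
    #   response.headers[_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY]
    #   response.headers[_RAW_PREDICT_MODEL_RESOURCE_KEY]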
@@ -200,6 +205,8 @@ def __init__(
            location=self.location,
            credentials=credentials,
        )
+         self.authorized_session = None
+         self.raw_predict_request_url = None

    def _skipped_getter_call(self) -> bool:
        """Check if GAPIC resource was populated by call to get/list API methods
@@ -1389,16 +1396,15 @@ def update(
        """Updates an endpoint.

        Example usage:
-
-         my_endpoint = my_endpoint.update(
-             display_name='my-updated-endpoint',
-             description='my updated description',
-             labels={'key': 'value'},
-             traffic_split={
-                 '123456': 20,
-                 '234567': 80,
-             },
-         )
+             my_endpoint = my_endpoint.update(
+                 display_name='my-updated-endpoint',
+                 description='my updated description',
+                 labels={'key': 'value'},
+                 traffic_split={
+                     '123456': 20,
+                     '234567': 80,
+                 },
+             )

        Args:
            display_name (str):
@@ -1481,6 +1487,7 @@ def predict(
        instances: List,
        parameters: Optional[Dict] = None,
        timeout: Optional[float] = None,
+         use_raw_predict: Optional[bool] = False,
    ) -> Prediction:
        """Make a prediction against this Endpoint.
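A hedged usage sketch for the new flag (the endpoint ID and instance payload are placeholders); as the docstring added in the next hunk notes, model version information is not returned on this path:

    from google.cloud import aiplatform

    endpoint = aiplatform.Endpoint("1234567890")  # placeholder endpoint ID
    prediction = endpoint.predict(
        instances=[{"feat_1": 1.0, "feat_2": 2.0}],  # placeholder instances
        use_raw_predict=True,  # route the call through Endpoint.raw_predict()
    )
    print(prediction.predictions, prediction.deployed_model_id)  # model_version_id is not populated here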
@@ -1505,29 +1512,80 @@ def predict(
                [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
                ``parameters_schema_uri``.
            timeout (float): Optional. The timeout for this request in seconds.
+             use_raw_predict (bool):
+                 Optional. Default value is False. If set to True, the underlying prediction call will be made
+                 against Endpoint.raw_predict(). Note that model version information will
+                 not be available in the prediction response when using raw_predict.

        Returns:
            prediction (aiplatform.Prediction):
                Prediction with returned predictions and Model ID.
        """
        self.wait()
+         if use_raw_predict:
+             raw_predict_response = self.raw_predict(
+                 body=json.dumps({"instances": instances, "parameters": parameters}),
+                 headers={"Content-Type": "application/json"},
+             )
+             json_response = json.loads(raw_predict_response.text)
+             return Prediction(
+                 predictions=json_response["predictions"],
+                 deployed_model_id=raw_predict_response.headers[
+                     _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
+                 ],
+                 model_resource_name=raw_predict_response.headers[
+                     _RAW_PREDICT_MODEL_RESOURCE_KEY
+                 ],
+             )
+         else:
+             prediction_response = self._prediction_client.predict(
+                 endpoint=self._gca_resource.name,
+                 instances=instances,
+                 parameters=parameters,
+                 timeout=timeout,
+             )

-         prediction_response = self._prediction_client.predict(
-             endpoint=self._gca_resource.name,
-             instances=instances,
-             parameters=parameters,
-             timeout=timeout,
-         )
+             return Prediction(
+                 predictions=[
+                     json_format.MessageToDict(item)
+                     for item in prediction_response.predictions.pb
+                 ],
+                 deployed_model_id=prediction_response.deployed_model_id,
+                 model_version_id=prediction_response.model_version_id,
+                 model_resource_name=prediction_response.model,
+             )

-         return Prediction(
-             predictions=[
-                 json_format.MessageToDict(item)
-                 for item in prediction_response.predictions.pb
-             ],
-             deployed_model_id=prediction_response.deployed_model_id,
-             model_version_id=prediction_response.model_version_id,
-             model_resource_name=prediction_response.model,
-         )
+     def raw_predict(
+         self, body: bytes, headers: Dict[str, str]
+     ) -> requests.models.Response:
+         """Makes a prediction request using arbitrary headers.
+
+         Example usage:
+             my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
+             response = my_endpoint.raw_predict(
+                 body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
+                 headers = {'Content-Type':'application/json'}
+             )
+             status_code = response.status_code
+             results = json.loads(response.text)
+
+         Args:
+             body (bytes):
+                 The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
+             headers (Dict[str, str]):
+                 The header of the request as a dictionary. There are no restrictions on the header.
+
+         Returns:
+             A requests.models.Response object containing the status code and prediction results.
+         """
+         if not self.authorized_session:
+             self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+             self.authorized_session = google_auth_requests.AuthorizedSession(
+                 self.credentials
+             )
+             self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
+
+         return self.authorized_session.post(self.raw_predict_request_url, body, headers)

    def explain(
        self,
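An end-to-end usage sketch for the new Endpoint.raw_predict() (endpoint ID and feature values are placeholders); it mirrors the docstring example, serializing the body with json.dumps and parsing the response text with json.loads:

    import json

    from google.cloud import aiplatform

    my_endpoint = aiplatform.Endpoint("1234567890")  # placeholder endpoint ID
    response = my_endpoint.raw_predict(
        body=json.dumps({"instances": [{"feat_1": 1.0, "feat_2": 2.0}]}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    if response.status_code < 300:  # the module's _SUCCESSFUL_HTTP_RESPONSE threshold
        predictions = json.loads(response.text)["predictions"]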
@@ -2004,7 +2062,7 @@ def _http_request(
    def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
        """Make a prediction against this PrivateEndpoint using a HTTP request.
        This method must be called within the network the PrivateEndpoint is peered to.
-         The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`.
+         Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.

        Example usage:
            response = my_private_endpoint.predict(instances=[...])
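A small sketch of the check the updated docstring points at (endpoint ID and payload are placeholders); `PrivateEndpoint.network` identifies the peered VPC, and calls made from outside that network fail with a 404:

    from google.cloud import aiplatform

    my_private_endpoint = aiplatform.PrivateEndpoint("1234567890")  # placeholder endpoint ID
    print(my_private_endpoint.network)  # the peered VPC network this endpoint is reachable from
    response = my_private_endpoint.predict(instances=[{"feat_1": 1.0}])  # must run inside that network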
@@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
            deployed_model_id=self._gca_resource.deployed_models[0].id,
        )

+     def raw_predict(
+         self, body: bytes, headers: Dict[str, str]
+     ) -> requests.models.Response:
+         """Make a prediction request using arbitrary headers.
+         This method must be called within the network the PrivateEndpoint is peered to.
+         Otherwise, the raw_predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
+
+         Example usage:
+             my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
+             response = my_endpoint.raw_predict(
+                 body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
+                 headers = {'Content-Type':'application/json'}
+             )
+             status_code = response.status_code
+             results = json.loads(response.text)
+
+         Args:
+             body (bytes):
+                 The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
+             headers (Dict[str, str]):
+                 The header of the request as a dictionary. There are no restrictions on the header.
+
+         Returns:
+             A requests.models.Response object containing the status code and prediction results.
+         """
+         self.wait()
+         return self._http_request(
+             method="POST",
+             url=self.predict_http_uri,
+             body=body,
+             headers=headers,
+         )
+
    def explain(self):
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not support 'explain' as of now."