Skip to content

Commit 3d68777

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for Prediction dedicated endpoint. predict/rawPredict/streamRawPredict can use dedicated DNS to access the dedicated endpoint.
PiperOrigin-RevId: 667018843
1 parent a0d4ff2 commit 3d68777

File tree

2 files changed

+304
-6
lines changed

2 files changed

+304
-6
lines changed

google/cloud/aiplatform/models.py

+136-6
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,7 @@ def create(
782782
enable_request_response_logging=False,
783783
request_response_logging_sampling_rate: Optional[float] = None,
784784
request_response_logging_bq_destination_table: Optional[str] = None,
785+
dedicated_endpoint_enabled=False,
785786
) -> "Endpoint":
786787
"""Creates a new endpoint.
787788
@@ -849,6 +850,10 @@ def create(
849850
request_response_logging_bq_destination_table (str):
850851
Optional. The request response logging bigquery destination. If not set, will create a table with name:
851852
``bq://{project_id}.logging_{endpoint_display_name}_{endpoint_id}.request_response_logging``.
853+
dedicated_endpoint_enabled (bool):
854+
Optional. If enabled, a dedicated dns will be created and your
855+
traffic will be fully isolated from other customers' traffic and
856+
latency will be reduced.
852857
853858
Returns:
854859
endpoint (aiplatform.Endpoint):
@@ -893,6 +898,7 @@ def create(
893898
create_request_timeout=create_request_timeout,
894899
endpoint_id=endpoint_id,
895900
predict_request_response_logging_config=predict_request_response_logging_config,
901+
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
896902
)
897903

898904
@classmethod
@@ -918,6 +924,7 @@ def _create(
918924
private_service_connect_config: Optional[
919925
gca_service_networking.PrivateServiceConnectConfig
920926
] = None,
927+
dedicated_endpoint_enabled=False,
921928
) -> "Endpoint":
922929
"""Creates a new endpoint by calling the API client.
923930
@@ -984,6 +991,10 @@ def _create(
984991
private_service_connect_config (aiplatform.service_network.PrivateServiceConnectConfig):
985992
If enabled, the endpoint can be accessible via [Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect).
986993
Cannot be enabled when network is specified.
994+
dedicated_endpoint_enabled (bool):
995+
Optional. If enabled, a dedicated dns will be created and your
996+
traffic will be fully isolated from other customers' traffic and
997+
latency will be reduced.
987998
988999
Returns:
9891000
endpoint (aiplatform.Endpoint):
@@ -1002,6 +1013,7 @@ def _create(
10021013
network=network,
10031014
predict_request_response_logging_config=predict_request_response_logging_config,
10041015
private_service_connect_config=private_service_connect_config,
1016+
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
10051017
)
10061018

10071019
operation_future = api_client.create_endpoint(
@@ -2167,9 +2179,18 @@ def predict(
21672179
parameters: Optional[Dict] = None,
21682180
timeout: Optional[float] = None,
21692181
use_raw_predict: Optional[bool] = False,
2182+
*,
2183+
use_dedicated_endpoint: Optional[bool] = False,
21702184
) -> Prediction:
21712185
"""Make a prediction against this Endpoint.
21722186
2187+
For dedicated endpoint, set use_dedicated_endpoint = True:
2188+
```
2189+
response = my_endpoint.predict(instances=[...],
2190+
use_dedicated_endpoint=True)
2191+
my_predictions = response.predictions
2192+
```
2193+
21732194
Args:
21742195
instances (List):
21752196
Required. The instances that are the input to the
@@ -2194,6 +2215,9 @@ def predict(
21942215
use_raw_predict (bool):
21952216
Optional. Default value is False. If set to True, the underlying prediction call will be made
21962217
against Endpoint.raw_predict().
2218+
use_dedicated_endpoint (bool):
2219+
Optional. Default value is False. If set to True, the underlying prediction call will be made
2220+
using the dedicated endpoint dns.
21972221
21982222
Returns:
21992223
prediction (aiplatform.Prediction):
@@ -2204,6 +2228,7 @@ def predict(
22042228
raw_predict_response = self.raw_predict(
22052229
body=json.dumps({"instances": instances, "parameters": parameters}),
22062230
headers={"Content-Type": "application/json"},
2231+
use_dedicated_endpoint=use_dedicated_endpoint,
22072232
)
22082233
json_response = raw_predict_response.json()
22092234
return Prediction(
@@ -2219,6 +2244,51 @@ def predict(
22192244
_RAW_PREDICT_MODEL_VERSION_ID_KEY, None
22202245
),
22212246
)
2247+
2248+
if use_dedicated_endpoint:
2249+
self._sync_gca_resource_if_skipped()
2250+
if (
2251+
not self._gca_resource.dedicated_endpoint_enabled
2252+
or self._gca_resource.dedicated_endpoint_dns is None
2253+
):
2254+
raise ValueError(
2255+
"Dedicated endpoint is not enabled or DNS is empty."
2256+
"Please make sure endpoint has dedicated endpoint enabled"
2257+
"and model are ready before making a prediction."
2258+
)
2259+
2260+
if not self.authorized_session:
2261+
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
2262+
self.authorized_session = google_auth_requests.AuthorizedSession(
2263+
self.credentials
2264+
)
2265+
2266+
headers = {
2267+
"Content-Type": "application/json",
2268+
}
2269+
2270+
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:predict"
2271+
response = self.authorized_session.post(
2272+
url=url,
2273+
data=json.dumps(
2274+
{
2275+
"instances": instances,
2276+
"parameters": parameters,
2277+
}
2278+
),
2279+
headers=headers,
2280+
)
2281+
2282+
prediction_response = json.loads(response.text)
2283+
2284+
return Prediction(
2285+
predictions=prediction_response.get("predictions"),
2286+
metadata=prediction_response.get("metadata"),
2287+
deployed_model_id=prediction_response.get("deployedModelId"),
2288+
model_resource_name=prediction_response.get("model"),
2289+
model_version_id=prediction_response.get("modelVersionId"),
2290+
)
2291+
22222292
else:
22232293
prediction_response = self._prediction_client.predict(
22242294
endpoint=self._gca_resource.name,
@@ -2307,7 +2377,11 @@ async def predict_async(
23072377
)
23082378

23092379
def raw_predict(
2310-
self, body: bytes, headers: Dict[str, str]
2380+
self,
2381+
body: bytes,
2382+
headers: Dict[str, str],
2383+
*,
2384+
use_dedicated_endpoint: Optional[bool] = False,
23112385
) -> requests.models.Response:
23122386
"""Makes a prediction request using arbitrary headers.
23132387
@@ -2317,6 +2391,12 @@ def raw_predict(
23172391
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
23182392
headers = {'Content-Type':'application/json'}
23192393
)
2394+
# For dedicated endpoint:
2395+
response = my_endpoint.raw_predict(
2396+
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2397+
headers = {'Content-Type':'application/json'},
2398+
dedicated_endpoint=True,
2399+
)
23202400
status_code = response.status_code
23212401
results = json.dumps(response.text)
23222402
@@ -2325,6 +2405,9 @@ def raw_predict(
23252405
The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
23262406
headers (Dict[str, str]):
23272407
The header of the request as a dictionary. There are no restrictions on the header.
2408+
use_dedicated_endpoint (bool):
2409+
Optional. Default value is False. If set to True, the underlying prediction call will be made
2410+
using the dedicated endpoint dns.
23282411
23292412
Returns:
23302413
A requests.models.Response object containing the status code and prediction results.
@@ -2338,12 +2421,29 @@ def raw_predict(
23382421
if self.raw_predict_request_url is None:
23392422
self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
23402423

2341-
return self.authorized_session.post(
2342-
url=self.raw_predict_request_url, data=body, headers=headers
2343-
)
2424+
url = self.raw_predict_request_url
2425+
2426+
if use_dedicated_endpoint:
2427+
self._sync_gca_resource_if_skipped()
2428+
if (
2429+
not self._gca_resource.dedicated_endpoint_enabled
2430+
or self._gca_resource.dedicated_endpoint_dns is None
2431+
):
2432+
raise ValueError(
2433+
"Dedicated endpoint is not enabled or DNS is empty."
2434+
"Please make sure endpoint has dedicated endpoint enabled"
2435+
"and model are ready before making a prediction."
2436+
)
2437+
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"
2438+
2439+
return self.authorized_session.post(url=url, data=body, headers=headers)
23442440

23452441
def stream_raw_predict(
2346-
self, body: bytes, headers: Dict[str, str]
2442+
self,
2443+
body: bytes,
2444+
headers: Dict[str, str],
2445+
*,
2446+
use_dedicated_endpoint: Optional[bool] = False,
23472447
) -> Iterator[requests.models.Response]:
23482448
"""Makes a streaming prediction request using arbitrary headers.
23492449
@@ -2358,13 +2458,28 @@ def stream_raw_predict(
23582458
stream_result = json.dumps(response.text)
23592459
```
23602460
2461+
For dedicated endpoint:
2462+
```
2463+
my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
2464+
for stream_response in my_endpoint.stream_raw_predict(
2465+
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2466+
headers = {'Content-Type':'application/json'},
2467+
use_dedicated_endpoint=True,
2468+
):
2469+
status_code = response.status_code
2470+
stream_result = json.dumps(response.text)
2471+
```
2472+
23612473
Args:
23622474
body (bytes):
23632475
The body of the prediction request in bytes. This must not
23642476
exceed 10 mb per request.
23652477
headers (Dict[str, str]):
23662478
The header of the request as a dictionary. There are no
23672479
restrictions on the header.
2480+
use_dedicated_endpoint (bool):
2481+
Optional. Default value is False. If set to True, the underlying prediction call will be made
2482+
using the dedicated endpoint dns.
23682483
23692484
Yields:
23702485
predictions (Iterator[requests.models.Response]):
@@ -2379,8 +2494,23 @@ def stream_raw_predict(
23792494
if self.stream_raw_predict_request_url is None:
23802495
self.stream_raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
23812496

2497+
url = self.raw_predict_request_url
2498+
2499+
if use_dedicated_endpoint:
2500+
self._sync_gca_resource_if_skipped()
2501+
if (
2502+
not self._gca_resource.dedicated_endpoint_enabled
2503+
or self._gca_resource.dedicated_endpoint_dns is None
2504+
):
2505+
raise ValueError(
2506+
"Dedicated endpoint is not enabled or DNS is empty."
2507+
"Please make sure endpoint has dedicated endpoint enabled"
2508+
"and model are ready before making a prediction."
2509+
)
2510+
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict"
2511+
23822512
with self.authorized_session.post(
2383-
url=self.stream_raw_predict_request_url,
2513+
url=url,
23842514
data=body,
23852515
headers=headers,
23862516
stream=True,

0 commit comments

Comments
 (0)