@@ -782,6 +782,7 @@ def create(
782
782
enable_request_response_logging = False ,
783
783
request_response_logging_sampling_rate : Optional [float ] = None ,
784
784
request_response_logging_bq_destination_table : Optional [str ] = None ,
785
+ dedicated_endpoint_enabled = False ,
785
786
) -> "Endpoint" :
786
787
"""Creates a new endpoint.
787
788
@@ -849,6 +850,10 @@ def create(
849
850
request_response_logging_bq_destination_table (str):
850
851
Optional. The request response logging bigquery destination. If not set, will create a table with name:
851
852
``bq://{project_id}.logging_{endpoint_display_name}_{endpoint_id}.request_response_logging``.
853
+ dedicated_endpoint_enabled (bool):
854
+ Optional. If enabled, a dedicated DNS will be created and your
855
+ traffic will be fully isolated from other customers' traffic and
856
+ latency will be reduced.
852
857
853
858
Returns:
854
859
endpoint (aiplatform.Endpoint):
@@ -893,6 +898,7 @@ def create(
893
898
create_request_timeout = create_request_timeout ,
894
899
endpoint_id = endpoint_id ,
895
900
predict_request_response_logging_config = predict_request_response_logging_config ,
901
+ dedicated_endpoint_enabled = dedicated_endpoint_enabled ,
896
902
)
897
903
898
904
@classmethod
@@ -918,6 +924,7 @@ def _create(
918
924
private_service_connect_config : Optional [
919
925
gca_service_networking .PrivateServiceConnectConfig
920
926
] = None ,
927
+ dedicated_endpoint_enabled = False ,
921
928
) -> "Endpoint" :
922
929
"""Creates a new endpoint by calling the API client.
923
930
@@ -984,6 +991,10 @@ def _create(
984
991
private_service_connect_config (aiplatform.service_network.PrivateServiceConnectConfig):
985
992
If enabled, the endpoint can be accessible via [Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect).
986
993
Cannot be enabled when network is specified.
994
+ dedicated_endpoint_enabled (bool):
995
+ Optional. If enabled, a dedicated DNS will be created and your
996
+ traffic will be fully isolated from other customers' traffic and
997
+ latency will be reduced.
987
998
988
999
Returns:
989
1000
endpoint (aiplatform.Endpoint):
@@ -1002,6 +1013,7 @@ def _create(
1002
1013
network = network ,
1003
1014
predict_request_response_logging_config = predict_request_response_logging_config ,
1004
1015
private_service_connect_config = private_service_connect_config ,
1016
+ dedicated_endpoint_enabled = dedicated_endpoint_enabled ,
1005
1017
)
1006
1018
1007
1019
operation_future = api_client .create_endpoint (
@@ -2167,9 +2179,18 @@ def predict(
2167
2179
parameters : Optional [Dict ] = None ,
2168
2180
timeout : Optional [float ] = None ,
2169
2181
use_raw_predict : Optional [bool ] = False ,
2182
+ * ,
2183
+ use_dedicated_endpoint : Optional [bool ] = False ,
2170
2184
) -> Prediction :
2171
2185
"""Make a prediction against this Endpoint.
2172
2186
2187
+ For dedicated endpoint, set use_dedicated_endpoint = True:
2188
+ ```
2189
+ response = my_endpoint.predict(instances=[...],
2190
+ use_dedicated_endpoint=True)
2191
+ my_predictions = response.predictions
2192
+ ```
2193
+
2173
2194
Args:
2174
2195
instances (List):
2175
2196
Required. The instances that are the input to the
@@ -2194,6 +2215,9 @@ def predict(
2194
2215
use_raw_predict (bool):
2195
2216
Optional. Default value is False. If set to True, the underlying prediction call will be made
2196
2217
against Endpoint.raw_predict().
2218
+ use_dedicated_endpoint (bool):
2219
+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2220
+ using the dedicated endpoint DNS.
2197
2221
2198
2222
Returns:
2199
2223
prediction (aiplatform.Prediction):
@@ -2204,6 +2228,7 @@ def predict(
2204
2228
raw_predict_response = self .raw_predict (
2205
2229
body = json .dumps ({"instances" : instances , "parameters" : parameters }),
2206
2230
headers = {"Content-Type" : "application/json" },
2231
+ use_dedicated_endpoint = use_dedicated_endpoint ,
2207
2232
)
2208
2233
json_response = raw_predict_response .json ()
2209
2234
return Prediction (
@@ -2219,6 +2244,51 @@ def predict(
2219
2244
_RAW_PREDICT_MODEL_VERSION_ID_KEY , None
2220
2245
),
2221
2246
)
2247
+
2248
+ if use_dedicated_endpoint :
2249
+ self ._sync_gca_resource_if_skipped ()
2250
+ if (
2251
+ not self ._gca_resource .dedicated_endpoint_enabled
2252
+ or self ._gca_resource .dedicated_endpoint_dns is None
2253
+ ):
2254
+ raise ValueError (
2255
+ "Dedicated endpoint is not enabled or DNS is empty."
2256
+ "Please make sure endpoint has dedicated endpoint enabled"
2257
+ "and model are ready before making a prediction."
2258
+ )
2259
+
2260
+ if not self .authorized_session :
2261
+ self .credentials ._scopes = constants .base .DEFAULT_AUTHED_SCOPES
2262
+ self .authorized_session = google_auth_requests .AuthorizedSession (
2263
+ self .credentials
2264
+ )
2265
+
2266
+ headers = {
2267
+ "Content-Type" : "application/json" ,
2268
+ }
2269
+
2270
+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :predict"
2271
+ response = self .authorized_session .post (
2272
+ url = url ,
2273
+ data = json .dumps (
2274
+ {
2275
+ "instances" : instances ,
2276
+ "parameters" : parameters ,
2277
+ }
2278
+ ),
2279
+ headers = headers ,
2280
+ )
2281
+
2282
+ prediction_response = json .loads (response .text )
2283
+
2284
+ return Prediction (
2285
+ predictions = prediction_response .get ("predictions" ),
2286
+ metadata = prediction_response .get ("metadata" ),
2287
+ deployed_model_id = prediction_response .get ("deployedModelId" ),
2288
+ model_resource_name = prediction_response .get ("model" ),
2289
+ model_version_id = prediction_response .get ("modelVersionId" ),
2290
+ )
2291
+
2222
2292
else :
2223
2293
prediction_response = self ._prediction_client .predict (
2224
2294
endpoint = self ._gca_resource .name ,
@@ -2307,7 +2377,11 @@ async def predict_async(
2307
2377
)
2308
2378
2309
2379
def raw_predict (
2310
- self , body : bytes , headers : Dict [str , str ]
2380
+ self ,
2381
+ body : bytes ,
2382
+ headers : Dict [str , str ],
2383
+ * ,
2384
+ use_dedicated_endpoint : Optional [bool ] = False ,
2311
2385
) -> requests .models .Response :
2312
2386
"""Makes a prediction request using arbitrary headers.
2313
2387
@@ -2317,6 +2391,12 @@ def raw_predict(
2317
2391
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
2318
2392
headers = {'Content-Type':'application/json'}
2319
2393
)
2394
+ # For dedicated endpoint:
2395
+ response = my_endpoint.raw_predict(
2396
+ body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2397
+ headers = {'Content-Type':'application/json'},
2398
+ use_dedicated_endpoint=True,
2399
+ )
2320
2400
status_code = response.status_code
2321
2401
results = json.dumps(response.text)
2322
2402
@@ -2325,6 +2405,9 @@ def raw_predict(
2325
2405
The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
2326
2406
headers (Dict[str, str]):
2327
2407
The header of the request as a dictionary. There are no restrictions on the header.
2408
+ use_dedicated_endpoint (bool):
2409
+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2410
+ using the dedicated endpoint DNS.
2328
2411
2329
2412
Returns:
2330
2413
A requests.models.Response object containing the status code and prediction results.
@@ -2338,12 +2421,29 @@ def raw_predict(
2338
2421
if self .raw_predict_request_url is None :
2339
2422
self .raw_predict_request_url = f"https://{ self .location } -{ constants .base .API_BASE_PATH } /v1/projects/{ self .project } /locations/{ self .location } /endpoints/{ self .name } :rawPredict"
2340
2423
2341
- return self .authorized_session .post (
2342
- url = self .raw_predict_request_url , data = body , headers = headers
2343
- )
2424
+ url = self .raw_predict_request_url
2425
+
2426
+ if use_dedicated_endpoint :
2427
+ self ._sync_gca_resource_if_skipped ()
2428
+ if (
2429
+ not self ._gca_resource .dedicated_endpoint_enabled
2430
+ or self ._gca_resource .dedicated_endpoint_dns is None
2431
+ ):
2432
+ raise ValueError (
2433
+ "Dedicated endpoint is not enabled or DNS is empty."
2434
+ "Please make sure endpoint has dedicated endpoint enabled"
2435
+ "and model are ready before making a prediction."
2436
+ )
2437
+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :rawPredict"
2438
+
2439
+ return self .authorized_session .post (url = url , data = body , headers = headers )
2344
2440
2345
2441
def stream_raw_predict (
2346
- self , body : bytes , headers : Dict [str , str ]
2442
+ self ,
2443
+ body : bytes ,
2444
+ headers : Dict [str , str ],
2445
+ * ,
2446
+ use_dedicated_endpoint : Optional [bool ] = False ,
2347
2447
) -> Iterator [requests .models .Response ]:
2348
2448
"""Makes a streaming prediction request using arbitrary headers.
2349
2449
@@ -2358,13 +2458,28 @@ def stream_raw_predict(
2358
2458
stream_result = json.dumps(response.text)
2359
2459
```
2360
2460
2461
+ For dedicated endpoint:
2462
+ ```
2463
+ my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
2464
+ for stream_response in my_endpoint.stream_raw_predict(
2465
+ body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2466
+ headers = {'Content-Type':'application/json'},
2467
+ use_dedicated_endpoint=True,
2468
+ ):
2469
+ status_code = stream_response.status_code
2470
+ stream_result = json.dumps(stream_response.text)
2471
+ ```
2472
+
2361
2473
Args:
2362
2474
body (bytes):
2363
2475
The body of the prediction request in bytes. This must not
2364
2476
exceed 10 mb per request.
2365
2477
headers (Dict[str, str]):
2366
2478
The header of the request as a dictionary. There are no
2367
2479
restrictions on the header.
2480
+ use_dedicated_endpoint (bool):
2481
+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2482
+ using the dedicated endpoint DNS.
2368
2483
2369
2484
Yields:
2370
2485
predictions (Iterator[requests.models.Response]):
@@ -2379,8 +2494,23 @@ def stream_raw_predict(
2379
2494
if self .stream_raw_predict_request_url is None :
2380
2495
self .stream_raw_predict_request_url = f"https://{ self .location } -{ constants .base .API_BASE_PATH } /v1/projects/{ self .project } /locations/{ self .location } /endpoints/{ self .name } :streamRawPredict"
2381
2496
2497
+ url = self .raw_predict_request_url
2498
+
2499
+ if use_dedicated_endpoint :
2500
+ self ._sync_gca_resource_if_skipped ()
2501
+ if (
2502
+ not self ._gca_resource .dedicated_endpoint_enabled
2503
+ or self ._gca_resource .dedicated_endpoint_dns is None
2504
+ ):
2505
+ raise ValueError (
2506
+ "Dedicated endpoint is not enabled or DNS is empty."
2507
+ "Please make sure endpoint has dedicated endpoint enabled"
2508
+ "and model are ready before making a prediction."
2509
+ )
2510
+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :streamRawPredict"
2511
+
2382
2512
with self .authorized_session .post (
2383
- url = self . stream_raw_predict_request_url ,
2513
+ url = url ,
2384
2514
data = body ,
2385
2515
headers = headers ,
2386
2516
stream = True ,
0 commit comments