Skip to content

Commit f917269

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add sdk support to inference timeout on cloud-based endpoints (dedicated or PSC).
PiperOrigin-RevId: 699325577
1 parent 1487846 commit f917269

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

google/cloud/aiplatform/models.py

+31
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ def create(
783783
request_response_logging_sampling_rate: Optional[float] = None,
784784
request_response_logging_bq_destination_table: Optional[str] = None,
785785
dedicated_endpoint_enabled=False,
786+
inference_timeout: Optional[int] = None,
786787
) -> "Endpoint":
787788
"""Creates a new endpoint.
788789
@@ -854,6 +855,8 @@ def create(
854855
Optional. If enabled, a dedicated dns will be created and your
855856
traffic will be fully isolated from other customers' traffic and
856857
latency will be reduced.
858+
inference_timeout (int):
859+
Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true.
857860
858861
Returns:
859862
endpoint (aiplatform.Endpoint):
@@ -882,6 +885,17 @@ def create(
882885
),
883886
)
884887
)
888+
889+
client_connection_config = None
890+
if (
891+
inference_timeout is not None
892+
and inference_timeout > 0
893+
and dedicated_endpoint_enabled
894+
):
895+
client_connection_config = gca_endpoint_compat.ClientConnectionConfig(
896+
inference_timeout=duration_pb2.Duration(seconds=inference_timeout)
897+
)
898+
885899
return cls._create(
886900
api_client=api_client,
887901
display_name=display_name,
@@ -899,6 +913,7 @@ def create(
899913
endpoint_id=endpoint_id,
900914
predict_request_response_logging_config=predict_request_response_logging_config,
901915
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
916+
client_connection_config=client_connection_config,
902917
)
903918

904919
@classmethod
@@ -925,6 +940,9 @@ def _create(
925940
gca_service_networking.PrivateServiceConnectConfig
926941
] = None,
927942
dedicated_endpoint_enabled=False,
943+
client_connection_config: Optional[
944+
gca_endpoint_compat.ClientConnectionConfig
945+
] = None,
928946
) -> "Endpoint":
929947
"""Creates a new endpoint by calling the API client.
930948
@@ -995,6 +1013,8 @@ def _create(
9951013
Optional. If enabled, a dedicated dns will be created and your
9961014
traffic will be fully isolated from other customers' traffic and
9971015
latency will be reduced.
1016+
client_connection_config (aiplatform.endpoint.ClientConnectionConfig):
1017+
Optional. The inference timeout which is applied on cloud-based (PSC, or dedicated) endpoints for online prediction.
9981018
9991019
Returns:
10001020
endpoint (aiplatform.Endpoint):
@@ -1014,6 +1034,7 @@ def _create(
10141034
predict_request_response_logging_config=predict_request_response_logging_config,
10151035
private_service_connect_config=private_service_connect_config,
10161036
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
1037+
client_connection_config=client_connection_config,
10171038
)
10181039

10191040
operation_future = api_client.create_endpoint(
@@ -3253,6 +3274,7 @@ def create(
32533274
encryption_spec_key_name: Optional[str] = None,
32543275
sync=True,
32553276
private_service_connect_config: Optional[PrivateServiceConnectConfig] = None,
3277+
inference_timeout: Optional[int] = None,
32563278
) -> "PrivateEndpoint":
32573279
"""Creates a new PrivateEndpoint.
32583280
@@ -3338,6 +3360,8 @@ def create(
33383360
private_service_connect_config (aiplatform.PrivateEndpoint.PrivateServiceConnectConfig):
33393361
[Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect) configuration for the endpoint.
33403362
Cannot be set when network is specified.
3363+
inference_timeout (int):
3364+
Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true.
33413365
33423366
Returns:
33433367
endpoint (aiplatform.PrivateEndpoint):
@@ -3374,6 +3398,12 @@ def create(
33743398
private_service_connect_config._gapic_private_service_connect_config
33753399
)
33763400

3401+
client_connection_config = None
3402+
if private_service_connect_config and inference_timeout:
3403+
client_connection_config = gca_endpoint_compat.ClientConnectionConfig(
3404+
inference_timeout=duration_pb2.Duration(seconds=inference_timeout)
3405+
)
3406+
33773407
return cls._create(
33783408
api_client=api_client,
33793409
display_name=display_name,
@@ -3388,6 +3418,7 @@ def create(
33883418
network=network,
33893419
sync=sync,
33903420
private_service_connect_config=config,
3421+
client_connection_config=client_connection_config,
33913422
)
33923423

33933424
@classmethod

tests/unit/aiplatform/test_endpoints.py

+67
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import json
2222
import requests
2323
from unittest import mock
24+
from google.protobuf import duration_pb2
2425

2526
from google.api_core import operation as ga_operation
2627
from google.auth import credentials as auth_credentials
@@ -269,6 +270,11 @@
269270
)
270271
)
271272

273+
_TEST_INFERENCE_TIMEOUT = 100
274+
_TEST_CLIENT_CONNECTION_CONFIG = gca_endpoint.ClientConnectionConfig(
275+
inference_timeout=duration_pb2.Duration(seconds=_TEST_INFERENCE_TIMEOUT)
276+
)
277+
272278
"""
273279
----------------------------------------------------------------------------
274280
Endpoint Fixtures
@@ -1258,6 +1264,34 @@ def test_create_dedicated_endpoint(self, create_dedicated_endpoint_mock, sync):
12581264
endpoint_id=None,
12591265
)
12601266

1267+
@pytest.mark.parametrize("sync", [True, False])
1268+
def test_create_dedicated_endpoint_with_timeout(
1269+
self, create_dedicated_endpoint_mock, sync
1270+
):
1271+
my_endpoint = models.Endpoint.create(
1272+
display_name=_TEST_DISPLAY_NAME,
1273+
project=_TEST_PROJECT,
1274+
location=_TEST_LOCATION,
1275+
dedicated_endpoint_enabled=True,
1276+
sync=sync,
1277+
inference_timeout=_TEST_INFERENCE_TIMEOUT,
1278+
)
1279+
if not sync:
1280+
my_endpoint.wait()
1281+
1282+
expected_endpoint = gca_endpoint.Endpoint(
1283+
display_name=_TEST_DISPLAY_NAME,
1284+
dedicated_endpoint_enabled=True,
1285+
client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG,
1286+
)
1287+
create_dedicated_endpoint_mock.assert_called_once_with(
1288+
parent=_TEST_PARENT,
1289+
endpoint=expected_endpoint,
1290+
metadata=(),
1291+
timeout=None,
1292+
endpoint_id=None,
1293+
)
1294+
12611295
@pytest.mark.usefixtures("get_empty_endpoint_mock")
12621296
def test_accessing_properties_with_no_resource_raises(
12631297
self,
@@ -3441,6 +3475,39 @@ def test_create_psc(self, create_psc_private_endpoint_mock, sync):
34413475
endpoint_id=None,
34423476
)
34433477

3478+
@pytest.mark.parametrize("sync", [True, False])
3479+
def test_create_psc_with_timeout(self, create_psc_private_endpoint_mock, sync):
3480+
test_endpoint = models.PrivateEndpoint.create(
3481+
display_name=_TEST_DISPLAY_NAME,
3482+
project=_TEST_PROJECT,
3483+
location=_TEST_LOCATION,
3484+
private_service_connect_config=models.PrivateEndpoint.PrivateServiceConnectConfig(
3485+
project_allowlist=_TEST_PROJECT_ALLOWLIST
3486+
),
3487+
sync=sync,
3488+
inference_timeout=_TEST_INFERENCE_TIMEOUT,
3489+
)
3490+
3491+
if not sync:
3492+
test_endpoint.wait()
3493+
3494+
expected_endpoint = gca_endpoint.Endpoint(
3495+
display_name=_TEST_DISPLAY_NAME,
3496+
private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
3497+
enable_private_service_connect=True,
3498+
project_allowlist=_TEST_PROJECT_ALLOWLIST,
3499+
),
3500+
client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG,
3501+
)
3502+
3503+
create_psc_private_endpoint_mock.assert_called_once_with(
3504+
parent=_TEST_PARENT,
3505+
endpoint=expected_endpoint,
3506+
metadata=(),
3507+
timeout=None,
3508+
endpoint_id=None,
3509+
)
3510+
34443511
@pytest.mark.usefixtures("get_psa_private_endpoint_with_model_mock")
34453512
def test_psa_predict(self, predict_private_endpoint_mock):
34463513
test_endpoint = models.PrivateEndpoint(_TEST_ID)

0 commit comments

Comments
 (0)