Skip to content

Commit 6383a52

Browse files
vertex-sdk-bot authored and copybara-github committed
chore: Enable keepalive to avoid connection timeout for dedicated endpoints.
PiperOrigin-RevId: 740860105
1 parent 51dbe94 commit 6383a52

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

google/cloud/aiplatform/models.py

+29
Original file line numberDiff line numberDiff line change
@@ -2353,6 +2353,9 @@ def predict(
23532353
Returns:
23542354
prediction (aiplatform.Prediction):
23552355
Prediction with returned predictions and Model ID.
2356+
2357+
Raises:
2358+
ImportError: If the `requests-toolbelt` package, which provides `TCPKeepAliveAdapter`, cannot be imported.
23562359
"""
23572360
self.wait()
23582361
if use_raw_predict:
@@ -2388,6 +2391,14 @@ def predict(
23882391
"Please make sure endpoint has dedicated endpoint enabled"
23892392
"and model are ready before making a prediction."
23902393
)
2394+
try:
2395+
from requests_toolbelt.adapters.socket_options import (
2396+
TCPKeepAliveAdapter,
2397+
)
2398+
except ImportError:
2399+
raise ImportError(
2400+
"Cannot import the requests-toolbelt library. Please install requests-toolbelt."
2401+
)
23912402

23922403
if not self.authorized_session:
23932404
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
@@ -2400,6 +2411,9 @@ def predict(
24002411
}
24012412

24022413
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:predict"
2414+
# count * interval needs to be larger than 1 hr (3600s)
2415+
keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
2416+
self.authorized_session.mount("https://", keep_alive)
24032417
response = self.authorized_session.post(
24042418
url=url,
24052419
data=json.dumps(
@@ -2546,6 +2560,9 @@ def raw_predict(
25462560
25472561
Returns:
25482562
A requests.models.Response object containing the status code and prediction results.
2563+
2564+
Raises:
2565+
ImportError: If the `requests-toolbelt` package, which provides `TCPKeepAliveAdapter`, cannot be imported.
25492566
"""
25502567
if not self.authorized_session:
25512568
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
@@ -2559,6 +2576,14 @@ def raw_predict(
25592576
url = self.raw_predict_request_url
25602577

25612578
if use_dedicated_endpoint:
2579+
try:
2580+
from requests_toolbelt.adapters.socket_options import (
2581+
TCPKeepAliveAdapter,
2582+
)
2583+
except ImportError:
2584+
raise ImportError(
2585+
"Cannot import the requests-toolbelt library. Please install requests-toolbelt."
2586+
)
25622587
self._sync_gca_resource_if_skipped()
25632588
if (
25642589
not self._gca_resource.dedicated_endpoint_enabled
@@ -2570,6 +2595,10 @@ def raw_predict(
25702595
"and model are ready before making a prediction."
25712596
)
25722597
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"
2598+
# count * interval needs to be larger than 1 hr (3600s)
2599+
keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
2600+
self.authorized_session.mount("https://", keep_alive)
2601+
25732602
return self.authorized_session.post(
25742603
url=url, data=body, headers=headers, timeout=timeout
25752604
)

setup.py

+2 -2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
"uvicorn[standard] >= 0.16.0",
9090
]
9191

92-
endpoint_extra_require = ["requests >= 2.28.1"]
92+
endpoint_extra_require = ["requests >= 2.28.1", "requests-toolbelt <= 1.0.0"]
9393

9494
private_endpoints_extra_require = [
9595
"urllib3 >=1.21.1, <1.27",
@@ -258,7 +258,7 @@
258258
# future versions fix this issue
259259
"torch >= 2.0.0, < 2.1.0; python_version<='3.11'",
260260
"torch >= 2.2.0; python_version>'3.11'",
261-
"requests-toolbelt < 1.0.0",
261+
"requests-toolbelt <= 1.0.0",
262262
"immutabledict",
263263
"xgboost",
264264
]

0 commit comments

Comments
 (0)