@@ -2353,6 +2353,9 @@ def predict(
2353
2353
Returns:
2354
2354
prediction (aiplatform.Prediction):
2355
2355
Prediction with returned predictions and Model ID.
2356
+
2357
+ Raises:
2358
+ ImportError: If there is an issue importing the `TCPKeepAliveAdapter` package.
2356
2359
"""
2357
2360
self .wait ()
2358
2361
if use_raw_predict :
@@ -2388,6 +2391,14 @@ def predict(
2388
2391
"Please make sure endpoint has dedicated endpoint enabled"
2389
2392
"and model are ready before making a prediction."
2390
2393
)
2394
+ try :
2395
+ from requests_toolbelt .adapters .socket_options import (
2396
+ TCPKeepAliveAdapter ,
2397
+ )
2398
+ except ImportError :
2399
+ raise ImportError (
2400
+ "Cannot import the requests-toolbelt library. Please install requests-toolbelt."
2401
+ )
2391
2402
2392
2403
if not self .authorized_session :
2393
2404
self .credentials ._scopes = constants .base .DEFAULT_AUTHED_SCOPES
@@ -2400,6 +2411,9 @@ def predict(
2400
2411
}
2401
2412
2402
2413
url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :predict"
2414
+ # count * interval need to be larger than 1 hr (3600s)
2415
+ keep_alive = TCPKeepAliveAdapter (idle = 120 , count = 100 , interval = 100 )
2416
+ self .authorized_session .mount ("https://" , keep_alive )
2403
2417
response = self .authorized_session .post (
2404
2418
url = url ,
2405
2419
data = json .dumps (
@@ -2546,6 +2560,9 @@ def raw_predict(
2546
2560
2547
2561
Returns:
2548
2562
A requests.models.Response object containing the status code and prediction results.
2563
+
2564
+ Raises:
2565
+ ImportError: If there is an issue importing the `TCPKeepAliveAdapter` package.
2549
2566
"""
2550
2567
if not self .authorized_session :
2551
2568
self .credentials ._scopes = constants .base .DEFAULT_AUTHED_SCOPES
@@ -2559,6 +2576,14 @@ def raw_predict(
2559
2576
url = self .raw_predict_request_url
2560
2577
2561
2578
if use_dedicated_endpoint :
2579
+ try :
2580
+ from requests_toolbelt .adapters .socket_options import (
2581
+ TCPKeepAliveAdapter ,
2582
+ )
2583
+ except ImportError :
2584
+ raise ImportError (
2585
+ "Cannot import the requests-toolbelt library. Please install requests-toolbelt."
2586
+ )
2562
2587
self ._sync_gca_resource_if_skipped ()
2563
2588
if (
2564
2589
not self ._gca_resource .dedicated_endpoint_enabled
@@ -2570,6 +2595,10 @@ def raw_predict(
2570
2595
"and model are ready before making a prediction."
2571
2596
)
2572
2597
url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :rawPredict"
2598
+ # count * interval need to be larger than 1 hr (3600s)
2599
+ keep_alive = TCPKeepAliveAdapter (idle = 120 , count = 100 , interval = 100 )
2600
+ self .authorized_session .mount ("https://" , keep_alive )
2601
+
2573
2602
return self .authorized_session .post (
2574
2603
url = url , data = body , headers = headers , timeout = timeout
2575
2604
)
0 commit comments