Skip to content

Commit b7de16a

Browse files
vertex-sdk-bot authored and copybara-github committed
fix: Add timeout to prediction rawPredict/streamRawPredict
PiperOrigin-RevId: 695761143
1 parent a1857ed commit b7de16a

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

google/cloud/aiplatform/models.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -2229,6 +2229,7 @@ def predict(
22292229
body=json.dumps({"instances": instances, "parameters": parameters}),
22302230
headers={"Content-Type": "application/json"},
22312231
use_dedicated_endpoint=use_dedicated_endpoint,
2232+
timeout=timeout,
22322233
)
22332234
json_response = raw_predict_response.json()
22342235
return Prediction(
@@ -2277,6 +2278,7 @@ def predict(
22772278
}
22782279
),
22792280
headers=headers,
2281+
timeout=timeout,
22802282
)
22812283

22822284
prediction_response = json.loads(response.text)
@@ -2382,6 +2384,7 @@ def raw_predict(
23822384
headers: Dict[str, str],
23832385
*,
23842386
use_dedicated_endpoint: Optional[bool] = False,
2387+
timeout: Optional[float] = None,
23852388
) -> requests.models.Response:
23862389
"""Makes a prediction request using arbitrary headers.
23872390
@@ -2408,6 +2411,7 @@ def raw_predict(
24082411
use_dedicated_endpoint (bool):
24092412
Optional. Default value is False. If set to True, the underlying prediction call will be made
24102413
using the dedicated endpoint dns.
2414+
timeout (float): Optional. The timeout for this request in seconds.
24112415
24122416
Returns:
24132417
A requests.models.Response object containing the status code and prediction results.
@@ -2435,15 +2439,17 @@ def raw_predict(
24352439
"and model are ready before making a prediction."
24362440
)
24372441
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"
2438-
2439-
return self.authorized_session.post(url=url, data=body, headers=headers)
2442+
return self.authorized_session.post(
2443+
url=url, data=body, headers=headers, timeout=timeout
2444+
)
24402445

24412446
def stream_raw_predict(
24422447
self,
24432448
body: bytes,
24442449
headers: Dict[str, str],
24452450
*,
24462451
use_dedicated_endpoint: Optional[bool] = False,
2452+
timeout: Optional[float] = None,
24472453
) -> Iterator[requests.models.Response]:
24482454
"""Makes a streaming prediction request using arbitrary headers.
24492455
@@ -2480,6 +2486,7 @@ def stream_raw_predict(
24802486
use_dedicated_endpoint (bool):
24812487
Optional. Default value is False. If set to True, the underlying prediction call will be made
24822488
using the dedicated endpoint dns.
2489+
timeout (float): Optional. The timeout for this request in seconds.
24832490
24842491
Yields:
24852492
predictions (Iterator[requests.models.Response]):
@@ -2513,6 +2520,7 @@ def stream_raw_predict(
25132520
url=url,
25142521
data=body,
25152522
headers=headers,
2523+
timeout=timeout,
25162524
stream=True,
25172525
) as resp:
25182526
for line in resp.iter_lines():

tests/unit/aiplatform/test_endpoints.py

+63
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
_TEST_DESCRIPTION = "test-description"
8181
_TEST_REQUEST_METADATA = ()
8282
_TEST_TIMEOUT = None
83+
_TEST_PREDICT_TIMEOUT = 100
8384

8485
_TEST_ENDPOINT_NAME = test_constants.EndpointConstants._TEST_ENDPOINT_NAME
8586
_TEST_ENDPOINT_NAME_2 = test_constants.EndpointConstants._TEST_ENDPOINT_NAME_2
@@ -2387,6 +2388,34 @@ def test_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
23872388
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
23882389
data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
23892390
headers={"Content-Type": "application/json"},
2391+
timeout=None,
2392+
)
2393+
2394+
@pytest.mark.usefixtures("get_dedicated_endpoint_mock")
2395+
def test_predict_dedicated_endpoint_with_timeout(self, predict_endpoint_http_mock):
2396+
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
2397+
2398+
test_prediction = test_endpoint.predict(
2399+
instances=_TEST_INSTANCES,
2400+
parameters={"param": 3.0},
2401+
use_dedicated_endpoint=True,
2402+
timeout=_TEST_PREDICT_TIMEOUT,
2403+
)
2404+
2405+
true_prediction = models.Prediction(
2406+
predictions=_TEST_PREDICTION,
2407+
deployed_model_id=_TEST_ID,
2408+
metadata=_TEST_METADATA,
2409+
model_version_id=_TEST_VERSION_ID,
2410+
model_resource_name=_TEST_MODEL_NAME,
2411+
)
2412+
2413+
assert true_prediction == test_prediction
2414+
predict_endpoint_http_mock.assert_called_once_with(
2415+
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
2416+
data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
2417+
headers={"Content-Type": "application/json"},
2418+
timeout=_TEST_PREDICT_TIMEOUT,
23902419
)
23912420

23922421
@pytest.mark.usefixtures("get_endpoint_mock")
@@ -2432,6 +2461,40 @@ def test_raw_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
24322461
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
24332462
data=_TEST_RAW_INPUTS,
24342463
headers={"Content-Type": "application/json"},
2464+
timeout=None,
2465+
)
2466+
2467+
@pytest.mark.usefixtures("get_dedicated_endpoint_mock")
2468+
def test_raw_predict_dedicated_endpoint_with_timeout(
2469+
self, predict_endpoint_http_mock
2470+
):
2471+
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
2472+
2473+
test_prediction = test_endpoint.raw_predict(
2474+
body=_TEST_RAW_INPUTS,
2475+
headers={"Content-Type": "application/json"},
2476+
use_dedicated_endpoint=True,
2477+
timeout=_TEST_PREDICT_TIMEOUT,
2478+
)
2479+
2480+
true_prediction = requests.Response()
2481+
true_prediction.status_code = 200
2482+
true_prediction._content = json.dumps(
2483+
{
2484+
"predictions": _TEST_PREDICTION,
2485+
"metadata": _TEST_METADATA,
2486+
"deployedModelId": _TEST_DEPLOYED_MODELS[0].id,
2487+
"model": _TEST_MODEL_NAME,
2488+
"modelVersionId": "1",
2489+
}
2490+
).encode("utf-8")
2491+
assert true_prediction.status_code == test_prediction.status_code
2492+
assert true_prediction.text == test_prediction.text
2493+
predict_endpoint_http_mock.assert_called_once_with(
2494+
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
2495+
data=_TEST_RAW_INPUTS,
2496+
headers={"Content-Type": "application/json"},
2497+
timeout=_TEST_PREDICT_TIMEOUT,
24352498
)
24362499

24372500
@pytest.mark.usefixtures("get_endpoint_mock")

tests/unit/aiplatform/test_models.py

+1
Original file line numberDiff line numberDiff line change
@@ -3900,6 +3900,7 @@ def test_raw_predict(self, raw_predict_mock):
39003900
url=_TEST_RAW_PREDICT_URL,
39013901
data=_TEST_RAW_PREDICT_DATA,
39023902
headers=_TEST_RAW_PREDICT_HEADER,
3903+
timeout=None,
39033904
)
39043905

39053906
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)