Skip to content

Commit 197f333

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: PrivateEndpoint.stream_raw_predict
PiperOrigin-RevId: 669479012
1 parent efbcb54 commit 197f333

File tree

2 files changed

+184
-1
lines changed

2 files changed

+184
-1
lines changed

google/cloud/aiplatform/models.py

+89
Original file line numberDiff line numberDiff line change
@@ -3666,6 +3666,95 @@ def raw_predict(
36663666
headers=headers_with_token,
36673667
)
36683668

3669+
def stream_raw_predict(
    self,
    body: bytes,
    headers: Dict[str, str],
    endpoint_override: Optional[str] = None,
) -> Iterator[bytes]:
    """Make a streaming prediction request using arbitrary headers.

    This is a generator function: none of the checks or network calls below
    run until the returned iterator is first advanced, so the ``ValueError``s
    listed under ``Raises`` surface on the first iteration, not at call time.

    Example usage:
        psc_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)

        # Prepare the request body
        request_body = json.dumps({...}).encode('utf-8')

        # Define the headers
        headers = {
            'Content-Type': 'application/json',
        }

        # Use stream_raw_predict to send the request and process the response
        for stream_response in psc_endpoint.stream_raw_predict(
            body=request_body,
            headers=headers,
            endpoint_override="10.128.0.26"  # Replace with your actual endpoint
        ):
            stream_response_text = stream_response.decode('utf-8')

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not
            exceed 10 MB per request.
        headers (Dict[str, str]):
            The header of the request as a dictionary. There are no
            restrictions on the header. The dictionary is copied, and an
            ``Authorization: Bearer <token>`` entry is set on the copy
            (replacing any caller-supplied ``Authorization`` value).
        endpoint_override (Optional[str]):
            The Private Service Connect endpoint's IP address or DNS that
            points to the endpoint's service attachment.

    Yields:
        predictions (Iterator[bytes]):
            The streaming prediction results as lines of bytes.

    Raises:
        ValueError: If the endpoint is PSA based (``network`` is set or no
            Private Service Connect config is present); streaming is not
            supported there.
        ValueError: If an endpoint override is not provided for PSC based
            endpoint.
        ValueError: If an endpoint override is invalid for PSC based endpoint.
    """
    self.wait()
    if self.network or not self.private_service_connect_config:
        raise ValueError(
            "PSA based private endpoint does not support streaming prediction."
        )

    # The guard above guarantees a PSC based endpoint from here on, so no
    # further check of private_service_connect_config is needed.
    if not endpoint_override:
        raise ValueError(
            "Cannot make a predict request because endpoint override is "
            "not provided. Please ensure an endpoint override is "
            "provided."
        )
    if not self._validate_endpoint_override(endpoint_override):
        raise ValueError(
            "Invalid endpoint override provided. Please only use IP "
            "address or DNS."
        )

    if not self.credentials.valid:
        self.credentials.refresh(google_auth_requests.Request())

    # Work on a copy so the caller's headers dict is never mutated.
    token = self.credentials.token
    headers_with_token = dict(headers)
    headers_with_token["Authorization"] = f"Bearer {token}"

    # Lazily create the session once and reuse it on later calls.
    if not self.authorized_session:
        self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
        self.authorized_session = google_auth_requests.AuthorizedSession(
            self.credentials
        )

    url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
    # NOTE(review): verify=False disables TLS certificate verification.
    # Presumably needed because the override is a raw IP/DNS name that the
    # served certificate does not match -- confirm this is intentional.
    with self.authorized_session.post(
        url=url,
        data=body,
        headers=headers_with_token,
        stream=True,
        verify=False,
    ) as resp:
        # The context manager closes the response even if the consumer
        # stops iterating early and the generator is closed.
        for line in resp.iter_lines():
            yield line
3757+
36693758
def explain(self):
36703759
raise NotImplementedError(
36713760
f"{self.__class__.__name__} class does not support 'explain' as of now."

tests/unit/aiplatform/test_endpoints.py

+95-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import copy
1919
from datetime import datetime, timedelta
2020
from importlib import reload
21-
import requests
2221
import json
22+
import requests
2323
from unittest import mock
2424

2525
from google.api_core import operation as ga_operation
@@ -920,6 +920,49 @@ def predict_private_endpoint_mock():
920920
yield predict_mock
921921

922922

923+
@pytest.fixture
def stream_raw_predict_private_endpoint_mock():
    """Patch ``AuthorizedSession.post`` with a canned streaming response.

    The patched ``post`` works as a context manager whose entered value is a
    ``requests.Response``-spec'd mock with status 200 and an ``iter_lines``
    that hands back two UTF-8 encoded JSON prediction payloads.

    Yields:
        The mock that replaces ``AuthorizedSession.post``.
    """
    # Two streamed chunks that differ only in their "predictions" values.
    streamed_lines = [
        json.dumps(
            {
                "predictions": values,
                "metadata": {"key": "value"},
                "deployedModelId": "model-id-123",
                "model": "model-name",
                "modelVersionId": "1",
            }
        ).encode("utf-8")
        for values in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
    ]

    fake_response = mock.Mock(spec=requests.Response)
    fake_response.status_code = 200
    # A single iterator is returned, so the lines can be consumed only once.
    fake_response.iter_lines = mock.Mock(return_value=iter(streamed_lines))

    with mock.patch.object(
        google_auth_requests.AuthorizedSession, "post"
    ) as stream_raw_predict_mock:
        # `with session.post(...) as resp:` must bind the fake response.
        stream_raw_predict_mock.return_value.__enter__.return_value = fake_response
        yield stream_raw_predict_mock
964+
965+
923966
@pytest.fixture
924967
def health_check_private_endpoint_mock():
925968
with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock:
@@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock):
31953238
},
31963239
)
31973240

3241+
@pytest.mark.usefixtures("get_psc_private_endpoint_mock")
def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
    """Check that stream_raw_predict posts to the PSC URL and yields each line.

    The request body is passed as bytes, matching the ``body: bytes``
    parameter of ``PrivateEndpoint.stream_raw_predict``.
    """
    request_body = b'{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}'

    test_endpoint = models.PrivateEndpoint(
        project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
    )

    test_prediction_iterator = test_endpoint.stream_raw_predict(
        body=request_body,
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer None",
        },
        endpoint_override=_TEST_ENDPOINT_OVERRIDE,
    )

    # stream_raw_predict is a generator, so the POST only happens once the
    # iterator is consumed; drain it before asserting on the mock.
    test_prediction = list(test_prediction_iterator)

    stream_raw_predict_private_endpoint_mock.assert_called_once_with(
        url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict",
        data=request_body,
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer None",
        },
        stream=True,
        verify=False,
    )

    # Validate the content of the returned predictions
    expected_predictions = [
        json.dumps(
            {
                "predictions": values,
                "metadata": {"key": "value"},
                "deployedModelId": "model-id-123",
                "model": "model-name",
                "modelVersionId": "1",
            }
        ).encode("utf-8")
        for values in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
    ]
    assert test_prediction == expected_predictions
3291+
31983292
@pytest.mark.usefixtures("get_psc_private_endpoint_mock")
31993293
def test_psc_predict_without_endpoint_override(self):
32003294
test_endpoint = models.PrivateEndpoint(

0 commit comments

Comments
 (0)