|
18 | 18 | import copy
|
19 | 19 | from datetime import datetime, timedelta
|
20 | 20 | from importlib import reload
|
21 |
| -import requests |
22 | 21 | import json
|
| 22 | +import requests |
23 | 23 | from unittest import mock
|
24 | 24 |
|
25 | 25 | from google.api_core import operation as ga_operation
|
@@ -920,6 +920,49 @@ def predict_private_endpoint_mock():
|
920 | 920 | yield predict_mock
|
921 | 921 |
|
922 | 922 |
|
| 923 | +@pytest.fixture |
| 924 | +def stream_raw_predict_private_endpoint_mock(): |
| 925 | + with mock.patch.object( |
| 926 | + google_auth_requests.AuthorizedSession, "post" |
| 927 | + ) as stream_raw_predict_mock: |
| 928 | + # Create a mock response object |
| 929 | + mock_response = mock.Mock(spec=requests.Response) |
| 930 | + |
| 931 | + # Configure the mock to be used as a context manager |
| 932 | + stream_raw_predict_mock.return_value.__enter__.return_value = mock_response |
| 933 | + |
| 934 | + # Set the status code to 200 (OK) |
| 935 | + mock_response.status_code = 200 |
| 936 | + |
| 937 | + # Simulate streaming data with iter_lines |
| 938 | + mock_response.iter_lines = mock.Mock( |
| 939 | + return_value=iter( |
| 940 | + [ |
| 941 | + json.dumps( |
| 942 | + { |
| 943 | + "predictions": [1.0, 2.0, 3.0], |
| 944 | + "metadata": {"key": "value"}, |
| 945 | + "deployedModelId": "model-id-123", |
| 946 | + "model": "model-name", |
| 947 | + "modelVersionId": "1", |
| 948 | + } |
| 949 | + ).encode("utf-8"), |
| 950 | + json.dumps( |
| 951 | + { |
| 952 | + "predictions": [4.0, 5.0, 6.0], |
| 953 | + "metadata": {"key": "value"}, |
| 954 | + "deployedModelId": "model-id-123", |
| 955 | + "model": "model-name", |
| 956 | + "modelVersionId": "1", |
| 957 | + } |
| 958 | + ).encode("utf-8"), |
| 959 | + ] |
| 960 | + ) |
| 961 | + ) |
| 962 | + |
| 963 | + yield stream_raw_predict_mock |
| 964 | + |
| 965 | + |
923 | 966 | @pytest.fixture
|
924 | 967 | def health_check_private_endpoint_mock():
|
925 | 968 | with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock:
|
@@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock):
|
3195 | 3238 | },
|
3196 | 3239 | )
|
3197 | 3240 |
|
| 3241 | + @pytest.mark.usefixtures("get_psc_private_endpoint_mock") |
| 3242 | + def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock): |
| 3243 | + test_endpoint = models.PrivateEndpoint( |
| 3244 | + project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID |
| 3245 | + ) |
| 3246 | + |
| 3247 | + test_prediction_iterator = test_endpoint.stream_raw_predict( |
| 3248 | + body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', |
| 3249 | + headers={ |
| 3250 | + "Content-Type": "application/json", |
| 3251 | + "Authorization": "Bearer None", |
| 3252 | + }, |
| 3253 | + endpoint_override=_TEST_ENDPOINT_OVERRIDE, |
| 3254 | + ) |
| 3255 | + |
| 3256 | + test_prediction = list(test_prediction_iterator) |
| 3257 | + |
| 3258 | + stream_raw_predict_private_endpoint_mock.assert_called_once_with( |
| 3259 | + url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict", |
| 3260 | + data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', |
| 3261 | + headers={ |
| 3262 | + "Content-Type": "application/json", |
| 3263 | + "Authorization": "Bearer None", |
| 3264 | + }, |
| 3265 | + stream=True, |
| 3266 | + verify=False, |
| 3267 | + ) |
| 3268 | + |
| 3269 | + # Validate the content of the returned predictions |
| 3270 | + expected_predictions = [ |
| 3271 | + json.dumps( |
| 3272 | + { |
| 3273 | + "predictions": [1.0, 2.0, 3.0], |
| 3274 | + "metadata": {"key": "value"}, |
| 3275 | + "deployedModelId": "model-id-123", |
| 3276 | + "model": "model-name", |
| 3277 | + "modelVersionId": "1", |
| 3278 | + } |
| 3279 | + ).encode("utf-8"), |
| 3280 | + json.dumps( |
| 3281 | + { |
| 3282 | + "predictions": [4.0, 5.0, 6.0], |
| 3283 | + "metadata": {"key": "value"}, |
| 3284 | + "deployedModelId": "model-id-123", |
| 3285 | + "model": "model-name", |
| 3286 | + "modelVersionId": "1", |
| 3287 | + } |
| 3288 | + ).encode("utf-8"), |
| 3289 | + ] |
| 3290 | + assert test_prediction == expected_predictions |
| 3291 | + |
3198 | 3292 | @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
|
3199 | 3293 | def test_psc_predict_without_endpoint_override(self):
|
3200 | 3294 | test_endpoint = models.PrivateEndpoint(
|
|
0 commit comments