|
80 | 80 | _TEST_DESCRIPTION = "test-description"
|
81 | 81 | _TEST_REQUEST_METADATA = ()
|
82 | 82 | _TEST_TIMEOUT = None
|
| 83 | +_TEST_PREDICT_TIMEOUT = 100 |
83 | 84 |
|
84 | 85 | _TEST_ENDPOINT_NAME = test_constants.EndpointConstants._TEST_ENDPOINT_NAME
|
85 | 86 | _TEST_ENDPOINT_NAME_2 = test_constants.EndpointConstants._TEST_ENDPOINT_NAME_2
|
@@ -2387,6 +2388,34 @@ def test_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
|
2387 | 2388 | url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
|
2388 | 2389 | data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
|
2389 | 2390 | headers={"Content-Type": "application/json"},
|
| 2391 | + timeout=None, |
| 2392 | + ) |
| 2393 | + |
| 2394 | + @pytest.mark.usefixtures("get_dedicated_endpoint_mock") |
| 2395 | + def test_predict_dedicated_endpoint_with_timeout(self, predict_endpoint_http_mock): |
| 2396 | + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 2397 | + |
| 2398 | + test_prediction = test_endpoint.predict( |
| 2399 | + instances=_TEST_INSTANCES, |
| 2400 | + parameters={"param": 3.0}, |
| 2401 | + use_dedicated_endpoint=True, |
| 2402 | + timeout=_TEST_PREDICT_TIMEOUT, |
| 2403 | + ) |
| 2404 | + |
| 2405 | + true_prediction = models.Prediction( |
| 2406 | + predictions=_TEST_PREDICTION, |
| 2407 | + deployed_model_id=_TEST_ID, |
| 2408 | + metadata=_TEST_METADATA, |
| 2409 | + model_version_id=_TEST_VERSION_ID, |
| 2410 | + model_resource_name=_TEST_MODEL_NAME, |
| 2411 | + ) |
| 2412 | + |
| 2413 | + assert true_prediction == test_prediction |
| 2414 | + predict_endpoint_http_mock.assert_called_once_with( |
| 2415 | + url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict", |
| 2416 | + data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}', |
| 2417 | + headers={"Content-Type": "application/json"}, |
| 2418 | + timeout=_TEST_PREDICT_TIMEOUT, |
2390 | 2419 | )
|
2391 | 2420 |
|
2392 | 2421 | @pytest.mark.usefixtures("get_endpoint_mock")
|
@@ -2432,6 +2461,40 @@ def test_raw_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
|
2432 | 2461 | url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
|
2433 | 2462 | data=_TEST_RAW_INPUTS,
|
2434 | 2463 | headers={"Content-Type": "application/json"},
|
| 2464 | + timeout=None, |
| 2465 | + ) |
| 2466 | + |
| 2467 | + @pytest.mark.usefixtures("get_dedicated_endpoint_mock") |
| 2468 | + def test_raw_predict_dedicated_endpoint_with_timeout( |
| 2469 | + self, predict_endpoint_http_mock |
| 2470 | + ): |
| 2471 | + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 2472 | + |
| 2473 | + test_prediction = test_endpoint.raw_predict( |
| 2474 | + body=_TEST_RAW_INPUTS, |
| 2475 | + headers={"Content-Type": "application/json"}, |
| 2476 | + use_dedicated_endpoint=True, |
| 2477 | + timeout=_TEST_PREDICT_TIMEOUT, |
| 2478 | + ) |
| 2479 | + |
| 2480 | + true_prediction = requests.Response() |
| 2481 | + true_prediction.status_code = 200 |
| 2482 | + true_prediction._content = json.dumps( |
| 2483 | + { |
| 2484 | + "predictions": _TEST_PREDICTION, |
| 2485 | + "metadata": _TEST_METADATA, |
| 2486 | + "deployedModelId": _TEST_DEPLOYED_MODELS[0].id, |
| 2487 | + "model": _TEST_MODEL_NAME, |
| 2488 | + "modelVersionId": "1", |
| 2489 | + } |
| 2490 | + ).encode("utf-8") |
| 2491 | + assert true_prediction.status_code == test_prediction.status_code |
| 2492 | + assert true_prediction.text == test_prediction.text |
| 2493 | + predict_endpoint_http_mock.assert_called_once_with( |
| 2494 | + url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict", |
| 2495 | + data=_TEST_RAW_INPUTS, |
| 2496 | + headers={"Content-Type": "application/json"}, |
| 2497 | + timeout=_TEST_PREDICT_TIMEOUT, |
2435 | 2498 | )
|
2436 | 2499 |
|
2437 | 2500 | @pytest.mark.usefixtures("get_endpoint_mock")
|
|
0 commit comments