Skip to content

Commit cc59e60

Browse files
authored
feat: Add timeout arguments to Endpoint.predict and Endpoint.explain (#1094)
Fixes b/224990641 🦕
1 parent 25b546a commit cc59e60

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

google/cloud/aiplatform/models.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,12 @@ def _instantiate_prediction_client(
11671167
prediction_client=True,
11681168
)
11691169

1170-
def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
1170+
def predict(
1171+
self,
1172+
instances: List,
1173+
parameters: Optional[Dict] = None,
1174+
timeout: Optional[float] = None,
1175+
) -> Prediction:
11711176
"""Make a prediction against this Endpoint.
11721177
11731178
Args:
@@ -1190,13 +1195,17 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
11901195
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
11911196
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
11921197
``parameters_schema_uri``.
1198+
timeout (float): Optional. The timeout for this request in seconds.
11931199
Returns:
11941200
prediction: Prediction with returned predictions and Model Id.
11951201
"""
11961202
self.wait()
11971203

11981204
prediction_response = self._prediction_client.predict(
1199-
endpoint=self._gca_resource.name, instances=instances, parameters=parameters
1205+
endpoint=self._gca_resource.name,
1206+
instances=instances,
1207+
parameters=parameters,
1208+
timeout=timeout,
12001209
)
12011210

12021211
return Prediction(
@@ -1212,6 +1221,7 @@ def explain(
12121221
instances: List[Dict],
12131222
parameters: Optional[Dict] = None,
12141223
deployed_model_id: Optional[str] = None,
1224+
timeout: Optional[float] = None,
12151225
) -> Prediction:
12161226
"""Make a prediction with explanations against this Endpoint.
12171227
@@ -1242,6 +1252,7 @@ def explain(
12421252
deployed_model_id (str):
12431253
Optional. If specified, this ExplainRequest will be served by the
12441254
chosen DeployedModel, overriding this Endpoint's traffic split.
1255+
timeout (float): Optional. The timeout for this request in seconds.
12451256
Returns:
12461257
prediction: Prediction with returned predictions, explanations and Model Id.
12471258
"""
@@ -1252,6 +1263,7 @@ def explain(
12521263
instances=instances,
12531264
parameters=parameters,
12541265
deployed_model_id=deployed_model_id,
1266+
timeout=timeout,
12551267
)
12561268

12571269
return Prediction(

tests/system/aiplatform/test_e2e_tabular.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ def test_end_to_end_tabular(self, shared_state):
164164
is True
165165
)
166166

167-
custom_prediction = custom_endpoint.predict([_INSTANCE])
167+
custom_prediction = custom_endpoint.predict([_INSTANCE], timeout=180.0)
168168

169169
custom_batch_prediction_job.wait()
170170

171171
automl_endpoint.wait()
172172
automl_prediction = automl_endpoint.predict(
173-
[{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings
173+
[{k: str(v) for k, v in _INSTANCE.items()}], # Cast int values to strings
174+
timeout=180.0,
174175
)
175176

176177
# Test lazy loading of Endpoint, check getter was never called after predict()

tests/unit/aiplatform/test_end_to_end.py

+1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def test_dataset_create_to_model_predict(
174174
endpoint=test_endpoints._TEST_ENDPOINT_NAME,
175175
instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
176176
parameters={"param": 3.0},
177+
timeout=None,
177178
)
178179

179180
expected_dataset = gca_dataset.Dataset(

tests/unit/aiplatform/test_endpoints.py

+38
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,7 @@ def test_predict(self, get_endpoint_mock, predict_client_predict_mock):
11621162
endpoint=_TEST_ENDPOINT_NAME,
11631163
instances=_TEST_INSTANCES,
11641164
parameters={"param": 3.0},
1165+
timeout=None,
11651166
)
11661167

11671168
def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
@@ -1187,6 +1188,43 @@ def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
11871188
instances=_TEST_INSTANCES,
11881189
parameters={"param": 3.0},
11891190
deployed_model_id=_TEST_MODEL_ID,
1191+
timeout=None,
1192+
)
1193+
1194+
@pytest.mark.usefixtures("get_endpoint_mock")
1195+
def test_predict_with_timeout(self, predict_client_predict_mock):
1196+
1197+
test_endpoint = models.Endpoint(_TEST_ID)
1198+
1199+
test_endpoint.predict(
1200+
instances=_TEST_INSTANCES, parameters={"param": 3.0}, timeout=10.0
1201+
)
1202+
1203+
predict_client_predict_mock.assert_called_once_with(
1204+
endpoint=_TEST_ENDPOINT_NAME,
1205+
instances=_TEST_INSTANCES,
1206+
parameters={"param": 3.0},
1207+
timeout=10.0,
1208+
)
1209+
1210+
@pytest.mark.usefixtures("get_endpoint_mock")
1211+
def test_explain_with_timeout(self, predict_client_explain_mock):
1212+
1213+
test_endpoint = models.Endpoint(_TEST_ID)
1214+
1215+
test_endpoint.explain(
1216+
instances=_TEST_INSTANCES,
1217+
parameters={"param": 3.0},
1218+
deployed_model_id=_TEST_MODEL_ID,
1219+
timeout=10.0,
1220+
)
1221+
1222+
predict_client_explain_mock.assert_called_once_with(
1223+
endpoint=_TEST_ENDPOINT_NAME,
1224+
instances=_TEST_INSTANCES,
1225+
parameters={"param": 3.0},
1226+
deployed_model_id=_TEST_MODEL_ID,
1227+
timeout=10.0,
11901228
)
11911229

11921230
def test_list_models(self, get_endpoint_with_models_mock):

0 commit comments

Comments (0)