Skip to content

Commit eb6071f

Browse files
Ark-kuncopybara-github
authored andcommitted
fix: Made the Endpoint prediction client initialization lazy
This is mainly to avoid issues with the `PredictionAsyncClient` which is based on `asyncio` and conflicts with other asynchronous solutions. Fixes #2620 PiperOrigin-RevId: 574978613
1 parent 98ab2f9 commit eb6071f

File tree

3 files changed

+30
-88
lines changed

3 files changed

+30
-88
lines changed

google/cloud/aiplatform/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,6 @@ def _sync_object_with_future_result(
986986
"credentials",
987987
]
988988
optional_sync_attributes = [
989-
"_prediction_client",
990989
"_authorized_session",
991990
"_raw_predict_request_url",
992991
]

google/cloud/aiplatform/models.py

+30-63
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
import asyncio
1817
import json
1918
import pathlib
2019
import re
@@ -227,16 +226,39 @@ def __init__(
227226
# Lazy load the Endpoint gca_resource until needed
228227
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)
229228

230-
(
231-
self._prediction_client,
232-
self._prediction_async_client,
233-
) = self._instantiate_prediction_clients(
234-
location=self.location,
235-
credentials=credentials,
236-
)
237229
self.authorized_session = None
238230
self.raw_predict_request_url = None
239231

232+
@property
233+
def _prediction_client(self) -> utils.PredictionClientWithOverride:
234+
# The attribute might not exist due to issues in
235+
# `VertexAiResourceNounWithFutureManager._sync_object_with_future_result`
236+
# We should switch to @functools.cached_property once its available.
237+
if not getattr(self, "_prediction_client_value", None):
238+
self._prediction_client_value = initializer.global_config.create_client(
239+
client_class=utils.PredictionClientWithOverride,
240+
credentials=self.credentials,
241+
location_override=self.location,
242+
prediction_client=True,
243+
)
244+
return self._prediction_client_value
245+
246+
@property
247+
def _prediction_async_client(self) -> utils.PredictionAsyncClientWithOverride:
248+
# The attribute might not exist due to issues in
249+
# `VertexAiResourceNounWithFutureManager._sync_object_with_future_result`
250+
# We should switch to @functools.cached_property once its available.
251+
if not getattr(self, "_prediction_async_client_value", None):
252+
self._prediction_async_client_value = (
253+
initializer.global_config.create_client(
254+
client_class=utils.PredictionAsyncClientWithOverride,
255+
credentials=self.credentials,
256+
location_override=self.location,
257+
prediction_client=True,
258+
)
259+
)
260+
return self._prediction_async_client_value
261+
240262
def _skipped_getter_call(self) -> bool:
241263
"""Check if GAPIC resource was populated by call to get/list API methods
242264
@@ -575,14 +597,6 @@ def _construct_sdk_resource_from_gapic(
575597
location=location,
576598
credentials=credentials,
577599
)
578-
579-
(
580-
endpoint._prediction_client,
581-
endpoint._prediction_async_client,
582-
) = cls._instantiate_prediction_clients(
583-
location=endpoint.location,
584-
credentials=credentials,
585-
)
586600
endpoint.authorized_session = None
587601
endpoint.raw_predict_request_url = None
588602

@@ -1390,53 +1404,6 @@ def _undeploy(
13901404
# update local resource
13911405
self._sync_gca_resource()
13921406

1393-
@staticmethod
1394-
def _instantiate_prediction_clients(
1395-
location: Optional[str] = None,
1396-
credentials: Optional[auth_credentials.Credentials] = None,
1397-
) -> Tuple[
1398-
utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
1399-
]:
1400-
"""Helper method to instantiates prediction client with optional
1401-
overrides for this endpoint.
1402-
1403-
Args:
1404-
location (str): The location of this endpoint.
1405-
credentials (google.auth.credentials.Credentials):
1406-
Optional custom credentials to use when accessing interacting with
1407-
the prediction client.
1408-
1409-
Returns:
1410-
prediction_client (prediction_service_client.PredictionServiceClient):
1411-
prediction_async_client (PredictionServiceAsyncClient):
1412-
Initialized prediction clients with optional overrides.
1413-
"""
1414-
1415-
# Creating an event loop if needed.
1416-
# PredictionServiceAsyncClient constructor calls `asyncio.get_event_loop`,
1417-
# which fails when there is no event loop (which does not exist by default
1418-
# in non-main threads in thread pool used when `sync=False`).
1419-
try:
1420-
asyncio.get_event_loop()
1421-
except RuntimeError:
1422-
asyncio.set_event_loop(asyncio.new_event_loop())
1423-
1424-
async_client = initializer.global_config.create_client(
1425-
client_class=utils.PredictionAsyncClientWithOverride,
1426-
credentials=credentials,
1427-
location_override=location,
1428-
prediction_client=True,
1429-
)
1430-
# We could use `client = async_client._client`, but then client would be
1431-
# a concrete `PredictionServiceClient`, not `PredictionClientWithOverride`.
1432-
client = initializer.global_config.create_client(
1433-
client_class=utils.PredictionClientWithOverride,
1434-
credentials=credentials,
1435-
location_override=location,
1436-
prediction_client=True,
1437-
)
1438-
return (client, async_client)
1439-
14401407
def update(
14411408
self,
14421409
display_name: Optional[str] = None,

tests/unit/aiplatform/test_endpoints.py

-24
Original file line numberDiff line numberDiff line change
@@ -658,18 +658,6 @@ def test_constructor(self, create_endpoint_client_mock):
658658
location_override=_TEST_LOCATION,
659659
appended_user_agent=None,
660660
),
661-
mock.call(
662-
client_class=utils.PredictionAsyncClientWithOverride,
663-
credentials=None,
664-
location_override=_TEST_LOCATION,
665-
prediction_client=True,
666-
),
667-
mock.call(
668-
client_class=utils.PredictionClientWithOverride,
669-
credentials=None,
670-
location_override=_TEST_LOCATION,
671-
prediction_client=True,
672-
),
673661
]
674662
)
675663

@@ -754,18 +742,6 @@ def test_constructor_with_custom_credentials(self, create_endpoint_client_mock):
754742
location_override=_TEST_LOCATION,
755743
appended_user_agent=None,
756744
),
757-
mock.call(
758-
client_class=utils.PredictionAsyncClientWithOverride,
759-
credentials=creds,
760-
location_override=_TEST_LOCATION,
761-
prediction_client=True,
762-
),
763-
mock.call(
764-
client_class=utils.PredictionClientWithOverride,
765-
credentials=creds,
766-
location_override=_TEST_LOCATION,
767-
prediction_client=True,
768-
),
769745
]
770746
)
771747

0 commit comments

Comments
 (0)