Skip to content

Commit a4d4e46

Browse files
matthew29tang authored and copybara-github committed
feat: Release API key support for GenerateContent to Public Preview
PiperOrigin-RevId: 689448497
1 parent b04196b commit a4d4e46

File tree

3 files changed

+87
-25
lines changed

3 files changed

+87
-25
lines changed

google/cloud/aiplatform/initializer.py

+22
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ def init(
221221
f"{api_transport} is not a valid transport type. "
222222
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
223223
)
224+
else:
225+
# Raise error if api_transport other than rest is specified for usage with API key.
226+
if not project and not api_transport:
227+
api_transport = "rest"
228+
elif not project and api_transport != "rest":
229+
raise ValueError(f"{api_transport} is not supported with API keys. ")
224230
if location:
225231
utils.validate_region(location)
226232
if experiment_description and experiment is None:
@@ -236,6 +242,11 @@ def init(
236242
logging.info("project/location updated, reset Experiment config.")
237243
metadata._experiment_tracker.reset()
238244

245+
if project and api_key:
246+
logging.info(
247+
"Both a project and API key have been provided. The project will take precedence over the API key."
248+
)
249+
239250
# Then we change the main state
240251
if api_endpoint is not None:
241252
self._api_endpoint = api_endpoint
@@ -438,7 +449,18 @@ def get_client_options(
438449

439450
api_endpoint = self.api_endpoint
440451

452+
if (
453+
api_endpoint is None
454+
and not self._project
455+
and not self._location
456+
and not location_override
457+
):
458+
# Default endpoint is location invariant if using API key
459+
api_endpoint = "aiplatform.googleapis.com"
460+
461+
# If both project and API key are passed in, project takes precedence.
441462
if api_endpoint is None:
463+
# Form the default endpoint to use with no API key.
442464
if not (self.location or location_override):
443465
raise ValueError(
444466
"No location found. Provide or initialize SDK with a location."

tests/unit/vertex_rag/test_rag_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ class TestRagDataManagement:
306306
def setup_method(self):
307307
importlib.reload(aiplatform.initializer)
308308
importlib.reload(aiplatform)
309-
aiplatform.init()
309+
aiplatform.init(project=tc.TEST_PROJECT, location=tc.TEST_REGION)
310310

311311
def teardown_method(self):
312312
aiplatform.initializer.global_pool.shutdown(wait=True)

vertexai/generative_models/_generative_models.py

+64-24
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,25 @@ def __init__(
388388
def _prediction_client(self) -> prediction_service.PredictionServiceClient:
389389
# Switch to @functools.cached_property once its available.
390390
if not getattr(self, "_prediction_client_value", None):
391-
self._prediction_client_value = (
392-
aiplatform_initializer.global_config.create_client(
393-
client_class=prediction_service.PredictionServiceClient,
394-
location_override=self._location,
395-
prediction_client=True,
391+
if (
392+
aiplatform_initializer.global_config.api_key
393+
and not aiplatform_initializer.global_config.project
394+
):
395+
self._prediction_client_value = (
396+
aiplatform_initializer.global_config.create_client(
397+
client_class=prediction_service.PredictionServiceClient,
398+
api_key=aiplatform_initializer.global_config.api_key,
399+
prediction_client=True,
400+
)
401+
)
402+
else:
403+
self._prediction_client_value = (
404+
aiplatform_initializer.global_config.create_client(
405+
client_class=prediction_service.PredictionServiceClient,
406+
location_override=self._location,
407+
prediction_client=True,
408+
)
396409
)
397-
)
398410
return self._prediction_client_value
399411

400412
@property
@@ -403,26 +415,46 @@ def _prediction_async_client(
403415
) -> prediction_service.PredictionServiceAsyncClient:
404416
# Switch to @functools.cached_property once its available.
405417
if not getattr(self, "_prediction_async_client_value", None):
406-
self._prediction_async_client_value = (
407-
aiplatform_initializer.global_config.create_client(
408-
client_class=prediction_service.PredictionServiceAsyncClient,
409-
location_override=self._location,
410-
prediction_client=True,
418+
if (
419+
aiplatform_initializer.global_config.api_key
420+
and not aiplatform_initializer.global_config.project
421+
):
422+
raise RuntimeError(
423+
"Using an api key is not supported yet for async clients."
424+
)
425+
else:
426+
self._prediction_async_client_value = (
427+
aiplatform_initializer.global_config.create_client(
428+
client_class=prediction_service.PredictionServiceAsyncClient,
429+
location_override=self._location,
430+
prediction_client=True,
431+
)
411432
)
412-
)
413433
return self._prediction_async_client_value
414434

415435
@property
416436
def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
417437
# Switch to @functools.cached_property once its available.
418438
if not getattr(self, "_llm_utility_client_value", None):
419-
self._llm_utility_client_value = (
420-
aiplatform_initializer.global_config.create_client(
421-
client_class=llm_utility_service.LlmUtilityServiceClient,
422-
location_override=self._location,
423-
prediction_client=True,
439+
if (
440+
aiplatform_initializer.global_config.api_key
441+
and not aiplatform_initializer.global_config.project
442+
):
443+
self._llm_utility_client_value = (
444+
aiplatform_initializer.global_config.create_client(
445+
client_class=llm_utility_service.LlmUtilityServiceClient,
446+
api_key=aiplatform_initializer.global_config.api_key,
447+
prediction_client=True,
448+
)
449+
)
450+
else:
451+
self._llm_utility_client_value = (
452+
aiplatform_initializer.global_config.create_client(
453+
client_class=llm_utility_service.LlmUtilityServiceClient,
454+
location_override=self._location,
455+
prediction_client=True,
456+
)
424457
)
425-
)
426458
return self._llm_utility_client_value
427459

428460
@property
@@ -431,13 +463,21 @@ def _llm_utility_async_client(
431463
) -> llm_utility_service.LlmUtilityServiceAsyncClient:
432464
# Switch to @functools.cached_property once its available.
433465
if not getattr(self, "_llm_utility_async_client_value", None):
434-
self._llm_utility_async_client_value = (
435-
aiplatform_initializer.global_config.create_client(
436-
client_class=llm_utility_service.LlmUtilityServiceAsyncClient,
437-
location_override=self._location,
438-
prediction_client=True,
466+
if (
467+
aiplatform_initializer.global_config.api_key
468+
and not aiplatform_initializer.global_config.project
469+
):
470+
raise RuntimeError(
471+
"Using an api key is not supported yet for async clients."
472+
)
473+
else:
474+
self._llm_utility_async_client_value = (
475+
aiplatform_initializer.global_config.create_client(
476+
client_class=llm_utility_service.LlmUtilityServiceAsyncClient,
477+
location_override=self._location,
478+
prediction_client=True,
479+
)
439480
)
440-
)
441481
return self._llm_utility_async_client_value
442482

443483
def _prepare_request(

0 commit comments

Comments (0)