Skip to content

Commit 9dccb51

Browse files
kazukitakamatsuiamjoel
authored andcommitted
Fix model provider of vertex ai (#11437)
1 parent 65262b3 commit 9dccb51

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,14 @@ def _generate_anthropic(
104104
"""
105105
# use Anthropic official SDK references
106106
# - https://github.com/anthropics/anthropic-sdk-python
107-
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
107+
service_account_key = credentials.get("vertex_service_account_key", "")
108108
project_id = credentials["vertex_project_id"]
109109
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
110110
token = ""
111111

112112
# get access token from service account credential
113-
if service_account_info:
113+
if service_account_key:
114+
service_account_info = json.loads(base64.b64decode(service_account_key))
114115
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
115116
request = google.auth.transport.requests.Request()
116117
credentials.refresh(request)
@@ -478,10 +479,11 @@ def _generate(
478479
if stop:
479480
config_kwargs["stop_sequences"] = stop
480481

481-
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
482+
service_account_key = credentials.get("vertex_service_account_key", "")
482483
project_id = credentials["vertex_project_id"]
483484
location = credentials["vertex_location"]
484-
if service_account_info:
485+
if service_account_key:
486+
service_account_info = json.loads(base64.b64decode(service_account_key))
485487
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
486488
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
487489
else:

api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def _invoke(
4848
:param input_type: input type
4949
:return: embeddings result
5050
"""
51-
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
51+
service_account_key = credentials.get("vertex_service_account_key", "")
5252
project_id = credentials["vertex_project_id"]
5353
location = credentials["vertex_location"]
54-
if service_account_info:
54+
if service_account_key:
55+
service_account_info = json.loads(base64.b64decode(service_account_key))
5556
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
5657
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
5758
else:
@@ -100,10 +101,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
100101
:return:
101102
"""
102103
try:
103-
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
104+
service_account_key = credentials.get("vertex_service_account_key", "")
104105
project_id = credentials["vertex_project_id"]
105106
location = credentials["vertex_location"]
106-
if service_account_info:
107+
if service_account_key:
108+
service_account_info = json.loads(base64.b64decode(service_account_key))
107109
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
108110
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
109111
else:

0 commit comments

Comments
 (0)