Skip to content

Commit 5baf5f8

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Use colab enterprise enviroment variables to infer project_id and region
PiperOrigin-RevId: 615076478
1 parent e004e87 commit 5baf5f8

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

google/cloud/aiplatform/initializer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def _set_project_as_env_var_or_google_auth_default(self):
6969
# See https://github.com/googleapis/google-auth-library-python/issues/924
7070
# TODO: Remove when google.auth.default() learns the
7171
# CLOUD_ML_PROJECT_ID env variable or Vertex AI starts setting GOOGLE_CLOUD_PROJECT env variable.
72-
project_number = os.environ.get("CLOUD_ML_PROJECT_ID")
72+
project_number = os.environ.get("GOOGLE_CLOUD_PROJECT") or os.environ.get(
73+
"CLOUD_ML_PROJECT_ID"
74+
)
7375
if project_number:
7476
if not self._credentials:
7577
credentials, _ = google.auth.default()
@@ -312,7 +314,7 @@ def location(self) -> str:
312314
if self._location:
313315
return self._location
314316

315-
location = os.getenv("CLOUD_ML_REGION")
317+
location = os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("CLOUD_ML_REGION")
316318
if location:
317319
utils.validate_region(location)
318320
return location

tests/unit/aiplatform/test_initializer.py

+22
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,28 @@ def mock_get_project_id(project_number: str, **_):
8686
):
8787
assert initializer.global_config.project == _TEST_PROJECT
8888

89+
def test_infer_project_id_with_precedence(self):
90+
lower_precedence_cloud_project_number = "456"
91+
higher_precedence_cloud_project_number = "123"
92+
93+
def mock_get_project_id(project_number: str, **_):
94+
assert project_number == higher_precedence_cloud_project_number
95+
return _TEST_PROJECT
96+
97+
with mock.patch.object(
98+
target=resource_manager_utils,
99+
attribute="get_project_id",
100+
new=mock_get_project_id,
101+
), mock.patch.dict(
102+
os.environ,
103+
{
104+
"GOOGLE_CLOUD_PROJECT": higher_precedence_cloud_project_number,
105+
"CLOUD_ML_PROJECT_ID": lower_precedence_cloud_project_number,
106+
},
107+
clear=True,
108+
):
109+
assert initializer.global_config.project == _TEST_PROJECT
110+
89111
def test_init_location_sets_location(self):
90112
initializer.global_config.init(location=_TEST_LOCATION)
91113
assert initializer.global_config.location == _TEST_LOCATION

0 commit comments

Comments
 (0)