Skip to content

Commit d06b22d

Browse files
ucdmktcopybara-github
authored andcommitted
fix: address broken unit tests in certain environments
PiperOrigin-RevId: 501875885
1 parent 9ffd173 commit d06b22d

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

google/cloud/aiplatform/vizier/pyvizier/study_config.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,36 @@ class SearchSpace(SearchSpace):
117117
@classmethod
118118
def from_proto(cls, proto: study_pb2.StudySpec) -> "SearchSpace":
119119
"""Extracts a SearchSpace object from a StudyConfig proto."""
120-
parameter_configs = []
120+
121+
# For google-vizier <= 0.0.15
122+
if hasattr(cls, "_factory"):
123+
parameter_configs = []
124+
for pc in proto.parameters:
125+
parameter_configs.append(
126+
proto_converters.ParameterConfigConverter.from_proto(pc)
127+
)
128+
return cls._factory(parameter_configs=parameter_configs)
129+
130+
result = cls()
121131
for pc in proto.parameters:
122-
parameter_configs.append(
123-
proto_converters.ParameterConfigConverter.from_proto(pc)
124-
)
125-
return cls._factory(parameter_configs=parameter_configs)
132+
result.add(proto_converters.ParameterConfigConverter.from_proto(pc))
133+
134+
return result
126135

127136
@property
128137
def parameter_protos(self) -> List[study_pb2.StudySpec.ParameterSpec]:
129138
"""Returns the search space as a List of ParameterConfig protos."""
139+
140+
# For google-vizier <= 0.0.15
141+
if isinstance(self._parameter_configs, list):
142+
return [
143+
proto_converters.ParameterConfigConverter.to_proto(pc)
144+
for pc in self._parameter_configs
145+
]
146+
130147
return [
131148
proto_converters.ParameterConfigConverter.to_proto(pc)
132-
for pc in self._parameter_configs
149+
for _, pc in self._parameter_configs.items()
133150
]
134151

135152

tests/unit/aiplatform/test_metadata_resources.py

+2
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ def list_artifact_empty_mock():
614614
yield list_artifacts_mock
615615

616616

617+
@pytest.mark.usefixtures("google_auth_mock")
617618
class TestExecution:
618619
def setup_method(self):
619620
reload(initializer)
@@ -893,6 +894,7 @@ def test_query_input_and_output_artifacts(
893894
assert artifact_list[0]._gca_resource == expected_artifact
894895

895896

897+
@pytest.mark.usefixtures("google_auth_mock")
896898
class TestArtifact:
897899
def setup_method(self):
898900
reload(initializer)

tests/unit/aiplatform/test_metadata_store.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def delete_metadata_store_mock():
134134
yield delete_metadata_store_mock
135135

136136

137+
@pytest.mark.usefixtures("google_auth_mock")
137138
class TestMetadataStore:
138139
def setup_method(self):
139140
reload(initializer)

tests/unit/aiplatform/test_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def test_wrapped_client():
390390
)
391391

392392

393+
@pytest.mark.usefixtures("google_auth_mock")
393394
def test_client_w_override_default_version():
394395

395396
test_client_info = gapic_v1.client_info.ClientInfo()
@@ -407,6 +408,7 @@ def test_client_w_override_default_version():
407408
)
408409

409410

411+
@pytest.mark.usefixtures("google_auth_mock")
410412
def test_client_w_override_select_version():
411413

412414
test_client_info = gapic_v1.client_info.ClientInfo()

0 commit comments

Comments
 (0)