Skip to content

Commit cb255ec

Browse files
sararobcopybara-github
authored andcommitted
fix: fix error when calling update_state() after ExperimentRun.list()
PiperOrigin-RevId: 544159034
1 parent d6476d0 commit cb255ec

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

google/cloud/aiplatform/metadata/resource.py

+2
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,8 @@ def update(
297297
Custom credentials to use to update this resource. Overrides
298298
credentials set in aiplatform.init.
299299
"""
300+
if not hasattr(self, "_threading_lock"):
301+
self._threading_lock = threading.Lock()
300302

301303
with self._threading_lock:
302304
gca_resource = deepcopy(self._gca_resource)

tests/unit/aiplatform/test_metadata.py

+38
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,23 @@ def get_artifact_mock():
556556
yield get_artifact_mock
557557

558558

559+
@pytest.fixture
560+
def get_artifact_mock_with_metadata():
561+
with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock:
562+
get_artifact_mock.return_value = GapicArtifact(
563+
name=_TEST_ARTIFACT_NAME,
564+
display_name=_TEST_ARTIFACT_ID,
565+
schema_title=constants.SYSTEM_METRICS,
566+
schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS],
567+
metadata={
568+
google.cloud.aiplatform.metadata.constants._VERTEX_EXPERIMENT_TRACKING_LABEL: True,
569+
constants.GCP_ARTIFACT_RESOURCE_NAME_KEY: test_constants.TensorboardConstants._TEST_TENSORBOARD_RUN_NAME,
570+
constants._STATE_KEY: gca_execution.Execution.State.RUNNING,
571+
},
572+
)
573+
yield get_artifact_mock
574+
575+
559576
@pytest.fixture
560577
def get_artifact_not_found_mock():
561578
with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock:
@@ -2026,6 +2043,27 @@ def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock):
20262043
retry=base._DEFAULT_RETRY,
20272044
)
20282045

2046+
@pytest.mark.usefixtures(
2047+
"get_metadata_store_mock",
2048+
"get_experiment_mock",
2049+
"get_experiment_run_mock",
2050+
"get_context_mock",
2051+
"list_contexts_mock",
2052+
"list_executions_mock",
2053+
"get_artifact_mock_with_metadata",
2054+
"update_context_mock",
2055+
)
2056+
def test_update_experiment_run_after_list(
2057+
self,
2058+
):
2059+
aiplatform.init(
2060+
project=_TEST_PROJECT,
2061+
location=_TEST_LOCATION,
2062+
)
2063+
2064+
experiment_run_list = aiplatform.ExperimentRun.list(experiment=_TEST_EXPERIMENT)
2065+
experiment_run_list[0].update_state(gca_execution.Execution.State.FAILED)
2066+
20292067

20302068
class TestTensorboard:
20312069
def test_get_or_create_default_tb_with_existing_default(

0 commit comments

Comments
 (0)