@@ -556,6 +556,23 @@ def get_artifact_mock():
556
556
yield get_artifact_mock
557
557
558
558
559
+ @pytest .fixture
560
+ def get_artifact_mock_with_metadata ():
561
+ with patch .object (MetadataServiceClient , "get_artifact" ) as get_artifact_mock :
562
+ get_artifact_mock .return_value = GapicArtifact (
563
+ name = _TEST_ARTIFACT_NAME ,
564
+ display_name = _TEST_ARTIFACT_ID ,
565
+ schema_title = constants .SYSTEM_METRICS ,
566
+ schema_version = constants .SCHEMA_VERSIONS [constants .SYSTEM_METRICS ],
567
+ metadata = {
568
+ google .cloud .aiplatform .metadata .constants ._VERTEX_EXPERIMENT_TRACKING_LABEL : True ,
569
+ constants .GCP_ARTIFACT_RESOURCE_NAME_KEY : test_constants .TensorboardConstants ._TEST_TENSORBOARD_RUN_NAME ,
570
+ constants ._STATE_KEY : gca_execution .Execution .State .RUNNING ,
571
+ },
572
+ )
573
+ yield get_artifact_mock
574
+
575
+
559
576
@pytest .fixture
560
577
def get_artifact_not_found_mock ():
561
578
with patch .object (MetadataServiceClient , "get_artifact" ) as get_artifact_mock :
@@ -2026,6 +2043,27 @@ def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock):
2026
2043
retry = base ._DEFAULT_RETRY ,
2027
2044
)
2028
2045
2046
+ @pytest .mark .usefixtures (
2047
+ "get_metadata_store_mock" ,
2048
+ "get_experiment_mock" ,
2049
+ "get_experiment_run_mock" ,
2050
+ "get_context_mock" ,
2051
+ "list_contexts_mock" ,
2052
+ "list_executions_mock" ,
2053
+ "get_artifact_mock_with_metadata" ,
2054
+ "update_context_mock" ,
2055
+ )
2056
+ def test_update_experiment_run_after_list (
2057
+ self ,
2058
+ ):
2059
+ aiplatform .init (
2060
+ project = _TEST_PROJECT ,
2061
+ location = _TEST_LOCATION ,
2062
+ )
2063
+
2064
+ experiment_run_list = aiplatform .ExperimentRun .list (experiment = _TEST_EXPERIMENT )
2065
+ experiment_run_list [0 ].update_state (gca_execution .Execution .State .FAILED )
2066
+
2029
2067
2030
2068
class TestTensorboard :
2031
2069
def test_get_or_create_default_tb_with_existing_default (
0 commit comments