1
1
# -*- coding: utf-8 -*-
2
2
3
- # Copyright 2022 Google LLC
3
+ # Copyright 2023 Google LLC
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
@@ -280,6 +280,17 @@ def get_execution_mock():
280
280
yield get_execution_mock
281
281
282
282
283
+ @pytest .fixture
284
+ def get_execution_not_found_mock ():
285
+ with patch .object (
286
+ MetadataServiceClient , "get_execution"
287
+ ) as get_execution_not_found_mock :
288
+ get_execution_not_found_mock .side_effect = exceptions .NotFound (
289
+ "test: not found"
290
+ )
291
+ yield get_execution_not_found_mock
292
+
293
+
283
294
@pytest .fixture
284
295
def get_execution_wrong_schema_mock ():
285
296
with patch .object (
@@ -681,6 +692,13 @@ def get_experiment_mock():
681
692
yield get_context_mock
682
693
683
694
695
+ @pytest .fixture
696
+ def get_experiment_not_found_mock ():
697
+ with patch .object (MetadataServiceClient , "get_context" ) as get_context_mock :
698
+ get_context_mock .side_effect = exceptions .NotFound ("test: not found" )
699
+ yield get_context_mock
700
+
701
+
684
702
@pytest .fixture
685
703
def get_experiment_run_run_mock ():
686
704
with patch .object (MetadataServiceClient , "get_context" ) as get_context_mock :
@@ -704,6 +722,17 @@ def get_experiment_run_mock():
704
722
yield get_context_mock
705
723
706
724
725
+ @pytest .fixture
726
+ def get_experiment_run_not_found_mock ():
727
+ with patch .object (MetadataServiceClient , "get_context" ) as get_context_mock :
728
+ get_context_mock .side_effect = [
729
+ _EXPERIMENT_MOCK ,
730
+ exceptions .NotFound ("test: not found" ),
731
+ ]
732
+
733
+ yield get_context_mock
734
+
735
+
707
736
@pytest .fixture
708
737
def create_experiment_context_mock ():
709
738
with patch .object (MetadataServiceClient , "create_context" ) as create_context_mock :
@@ -1125,6 +1154,66 @@ def test_init_experiment_wrong_schema(self):
1125
1154
experiment = _TEST_EXPERIMENT ,
1126
1155
)
1127
1156
1157
+ def test_get_experiment (self , get_experiment_mock ):
1158
+ aiplatform .init (
1159
+ project = _TEST_PROJECT ,
1160
+ location = _TEST_LOCATION ,
1161
+ )
1162
+
1163
+ exp = aiplatform .Experiment .get (_TEST_EXPERIMENT )
1164
+
1165
+ assert exp .name == _TEST_EXPERIMENT
1166
+ get_experiment_mock .assert_called_with (
1167
+ name = _TEST_CONTEXT_NAME , retry = base ._DEFAULT_RETRY
1168
+ )
1169
+
1170
+ def test_get_experiment_not_found (self , get_experiment_not_found_mock ):
1171
+ aiplatform .init (
1172
+ project = _TEST_PROJECT ,
1173
+ location = _TEST_LOCATION ,
1174
+ )
1175
+
1176
+ exp = aiplatform .Experiment .get (_TEST_EXPERIMENT )
1177
+
1178
+ assert exp is None
1179
+ get_experiment_not_found_mock .assert_called_with (
1180
+ name = _TEST_CONTEXT_NAME , retry = base ._DEFAULT_RETRY
1181
+ )
1182
+
1183
+ @pytest .mark .usefixtures (
1184
+ "get_metadata_store_mock" , "get_tensorboard_run_artifact_not_found_mock"
1185
+ )
1186
+ def test_get_experiment_run (self , get_experiment_run_mock ):
1187
+ aiplatform .init (
1188
+ project = _TEST_PROJECT ,
1189
+ location = _TEST_LOCATION ,
1190
+ )
1191
+
1192
+ run = aiplatform .ExperimentRun .get (_TEST_RUN , experiment = _TEST_EXPERIMENT )
1193
+
1194
+ assert run .name == _TEST_RUN
1195
+ get_experiment_run_mock .assert_called_with (
1196
+ name = f"{ _TEST_CONTEXT_NAME } -{ _TEST_RUN } " , retry = base ._DEFAULT_RETRY
1197
+ )
1198
+
1199
+ @pytest .mark .usefixtures (
1200
+ "get_metadata_store_mock" ,
1201
+ "get_tensorboard_run_artifact_not_found_mock" ,
1202
+ "get_execution_not_found_mock" ,
1203
+ )
1204
+ def test_get_experiment_run_not_found (self , get_experiment_run_not_found_mock ):
1205
+ aiplatform .init (
1206
+ project = _TEST_PROJECT ,
1207
+ location = _TEST_LOCATION ,
1208
+ )
1209
+
1210
+ run = aiplatform .ExperimentRun .get (_TEST_RUN , experiment = _TEST_EXPERIMENT )
1211
+
1212
+ assert run is None
1213
+ get_experiment_run_not_found_mock .assert_called_with (
1214
+ name = f"{ _TEST_CONTEXT_NAME } -{ _TEST_RUN } " , retry = base ._DEFAULT_RETRY
1215
+ )
1216
+
1128
1217
@pytest .mark .usefixtures ("get_metadata_store_mock" )
1129
1218
def test_start_run (
1130
1219
self ,
0 commit comments