85
85
struct_pb2 .Value (),
86
86
)
87
87
88
+ _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL = json_format .ParseDict (
89
+ {
90
+ "modelType" : "CLOUD" ,
91
+ "budgetMilliNodeHours" : _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS ,
92
+ "multiLabel" : False ,
93
+ "disableEarlyStopping" : _TEST_TRAINING_DISABLE_EARLY_STOPPING ,
94
+ "uptrainBaseModelId" : _TEST_MODEL_ID ,
95
+ },
96
+ struct_pb2 .Value (),
97
+ )
98
+
88
99
_TEST_FRACTION_SPLIT_TRAINING = 0.6
89
100
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
90
101
_TEST_FRACTION_SPLIT_TEST = 0.2
@@ -213,6 +224,20 @@ def mock_model():
213
224
yield model
214
225
215
226
227
+ @pytest .fixture
228
+ def mock_uptrain_base_model ():
229
+ model = mock .MagicMock (models .Model )
230
+ model .name = _TEST_MODEL_ID
231
+ model ._latest_future = None
232
+ model ._exception = None
233
+ model ._gca_resource = gca_model .Model (
234
+ display_name = _TEST_MODEL_DISPLAY_NAME ,
235
+ description = "This is the mock uptrain base Model's description" ,
236
+ name = _TEST_MODEL_NAME ,
237
+ )
238
+ yield model
239
+
240
+
216
241
@pytest .mark .usefixtures ("google_auth_mock" )
217
242
class TestAutoMLImageTrainingJob :
218
243
def setup_method (self ):
@@ -223,7 +248,7 @@ def teardown_method(self):
223
248
initializer .global_pool .shutdown (wait = True )
224
249
225
250
def test_init_all_parameters (self , mock_model ):
226
- """Ensure all private members are set correctly at initialization"""
251
+ """Ensure all private members are set correctly at initialization. """
227
252
228
253
aiplatform .init (project = _TEST_PROJECT )
229
254
@@ -275,7 +300,7 @@ def test_run_call_pipeline_service_create(
275
300
mock_pipeline_service_get ,
276
301
mock_dataset_image ,
277
302
mock_model_service_get ,
278
- mock_model ,
303
+ mock_uptrain_base_model ,
279
304
sync ,
280
305
):
281
306
"""Create and run an AutoML ICN training job, verify calls and return value"""
@@ -287,7 +312,7 @@ def test_run_call_pipeline_service_create(
287
312
288
313
job = training_jobs .AutoMLImageTrainingJob (
289
314
display_name = _TEST_DISPLAY_NAME ,
290
- base_model = mock_model ,
315
+ incremental_train_base_model = mock_uptrain_base_model ,
291
316
labels = _TEST_LABELS ,
292
317
)
293
318
@@ -315,8 +340,7 @@ def test_run_call_pipeline_service_create(
315
340
316
341
true_managed_model = gca_model .Model (
317
342
display_name = _TEST_MODEL_DISPLAY_NAME ,
318
- labels = mock_model ._gca_resource .labels ,
319
- description = mock_model ._gca_resource .description ,
343
+ labels = _TEST_MODEL_LABELS ,
320
344
encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
321
345
version_aliases = ["default" ],
322
346
)
@@ -330,7 +354,7 @@ def test_run_call_pipeline_service_create(
330
354
display_name = _TEST_DISPLAY_NAME ,
331
355
labels = _TEST_LABELS ,
332
356
training_task_definition = schema .training_job .definition .automl_image_classification ,
333
- training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL ,
357
+ training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL ,
334
358
model_to_upload = true_managed_model ,
335
359
input_data_config = true_input_data_config ,
336
360
encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
@@ -754,7 +778,7 @@ def test_splits_default(
754
778
mock_pipeline_service_get ,
755
779
mock_dataset_image ,
756
780
mock_model_service_get ,
757
- mock_model ,
781
+ mock_uptrain_base_model ,
758
782
sync ,
759
783
):
760
784
"""
@@ -768,7 +792,8 @@ def test_splits_default(
768
792
)
769
793
770
794
job = training_jobs .AutoMLImageTrainingJob (
771
- display_name = _TEST_DISPLAY_NAME , base_model = mock_model
795
+ display_name = _TEST_DISPLAY_NAME ,
796
+ incremental_train_base_model = mock_uptrain_base_model ,
772
797
)
773
798
774
799
model_from_job = job .run (
@@ -785,7 +810,6 @@ def test_splits_default(
785
810
786
811
true_managed_model = gca_model .Model (
787
812
display_name = _TEST_MODEL_DISPLAY_NAME ,
788
- description = mock_model ._gca_resource .description ,
789
813
encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
790
814
version_aliases = ["default" ],
791
815
)
@@ -797,7 +821,7 @@ def test_splits_default(
797
821
true_training_pipeline = gca_training_pipeline .TrainingPipeline (
798
822
display_name = _TEST_DISPLAY_NAME ,
799
823
training_task_definition = schema .training_job .definition .automl_image_classification ,
800
- training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL ,
824
+ training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL ,
801
825
model_to_upload = true_managed_model ,
802
826
input_data_config = true_input_data_config ,
803
827
encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
0 commit comments