Skip to content

Commit bb92380

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add incremental training to AutoMLImageTrainingJob.
PiperOrigin-RevId: 517272484
1 parent 091d74f commit bb92380

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

google/cloud/aiplatform/training_jobs.py

+21
Original file line numberDiff line numberDiff line change
@@ -5270,6 +5270,7 @@ def __init__(
52705270
multi_label: bool = False,
52715271
model_type: str = "CLOUD",
52725272
base_model: Optional[models.Model] = None,
5273+
incremental_train_base_model: Optional[models.Model] = None,
52735274
project: Optional[str] = None,
52745275
location: Optional[str] = None,
52755276
credentials: Optional[auth_credentials.Credentials] = None,
@@ -5335,6 +5336,12 @@ def __init__(
53355336
Otherwise, the new model will be trained from scratch. The `base` model
53365337
must be in the same Project and Location as the new Model to train,
53375338
and have the same model_type.
5339+
incremental_train_base_model: Optional[models.Model] = None
5340+
Optional for both Image Classification and Object detection models, to
5341+
incrementally train a new model using an existing model as the starting point, with
5342+
a reduced training time. If not specified, the new model will be trained from scratch.
5343+
The `base` model must be in the same Project and Location as the new Model to train,
5344+
and have the same prediction_type and model_type.
53385345
project (str):
53395346
Optional. Project to run training in. Overrides project set in aiplatform.init.
53405347
location (str):
@@ -5423,6 +5430,7 @@ def __init__(
54235430
self._prediction_type = prediction_type
54245431
self._multi_label = multi_label
54255432
self._base_model = base_model
5433+
self._incremental_train_base_model = incremental_train_base_model
54265434

54275435
def run(
54285436
self,
@@ -5603,6 +5611,7 @@ def run(
56035611
return self._run(
56045612
dataset=dataset,
56055613
base_model=self._base_model,
5614+
incremental_train_base_model=self._incremental_train_base_model,
56065615
training_fraction_split=training_fraction_split,
56075616
validation_fraction_split=validation_fraction_split,
56085617
test_fraction_split=test_fraction_split,
@@ -5627,6 +5636,7 @@ def _run(
56275636
self,
56285637
dataset: datasets.ImageDataset,
56295638
base_model: Optional[models.Model] = None,
5639+
incremental_train_base_model: Optional[models.Model] = None,
56305640
training_fraction_split: Optional[float] = None,
56315641
validation_fraction_split: Optional[float] = None,
56325642
test_fraction_split: Optional[float] = None,
@@ -5681,6 +5691,12 @@ def _run(
56815691
Otherwise, the new model will be trained from scratch. The `base` model
56825692
must be in the same Project and Location as the new Model to train,
56835693
and have the same model_type.
5694+
incremental_train_base_model: Optional[models.Model] = None
5695+
Optional for both Image Classification and Object detection models, to
5696+
incrementally train a new model using an existing model as the starting point, with
5697+
a reduced training time. If not specified, the new model will be trained from scratch.
5698+
The `base` model must be in the same Project and Location as the new Model to train,
5699+
and have the same prediction_type and model_type.
56845700
model_id (str):
56855701
Optional. The ID to use for the Model produced by this job,
56865702
which will become the final component of the model resource name.
@@ -5818,6 +5834,11 @@ def _run(
58185834
# Set ID of Vertex AI Model to base this training job off of
58195835
training_task_inputs_dict["baseModelId"] = base_model.name
58205836

5837+
if incremental_train_base_model:
5838+
training_task_inputs_dict[
5839+
"uptrainBaseModelId"
5840+
] = incremental_train_base_model.name
5841+
58215842
return self._run_job(
58225843
training_task_definition=training_task_definition,
58235844
training_task_inputs=training_task_inputs_dict,

tests/unit/aiplatform/test_automl_image_training_jobs.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,17 @@
8585
struct_pb2.Value(),
8686
)
8787

88+
_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL = json_format.ParseDict(
89+
{
90+
"modelType": "CLOUD",
91+
"budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
92+
"multiLabel": False,
93+
"disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING,
94+
"uptrainBaseModelId": _TEST_MODEL_ID,
95+
},
96+
struct_pb2.Value(),
97+
)
98+
8899
_TEST_FRACTION_SPLIT_TRAINING = 0.6
89100
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
90101
_TEST_FRACTION_SPLIT_TEST = 0.2
@@ -213,6 +224,20 @@ def mock_model():
213224
yield model
214225

215226

227+
@pytest.fixture
228+
def mock_uptrain_base_model():
229+
model = mock.MagicMock(models.Model)
230+
model.name = _TEST_MODEL_ID
231+
model._latest_future = None
232+
model._exception = None
233+
model._gca_resource = gca_model.Model(
234+
display_name=_TEST_MODEL_DISPLAY_NAME,
235+
description="This is the mock uptrain base Model's description",
236+
name=_TEST_MODEL_NAME,
237+
)
238+
yield model
239+
240+
216241
@pytest.mark.usefixtures("google_auth_mock")
217242
class TestAutoMLImageTrainingJob:
218243
def setup_method(self):
@@ -223,7 +248,7 @@ def teardown_method(self):
223248
initializer.global_pool.shutdown(wait=True)
224249

225250
def test_init_all_parameters(self, mock_model):
226-
"""Ensure all private members are set correctly at initialization"""
251+
"""Ensure all private members are set correctly at initialization."""
227252

228253
aiplatform.init(project=_TEST_PROJECT)
229254

@@ -275,7 +300,7 @@ def test_run_call_pipeline_service_create(
275300
mock_pipeline_service_get,
276301
mock_dataset_image,
277302
mock_model_service_get,
278-
mock_model,
303+
mock_uptrain_base_model,
279304
sync,
280305
):
281306
"""Create and run an AutoML ICN training job, verify calls and return value"""
@@ -287,7 +312,7 @@ def test_run_call_pipeline_service_create(
287312

288313
job = training_jobs.AutoMLImageTrainingJob(
289314
display_name=_TEST_DISPLAY_NAME,
290-
base_model=mock_model,
315+
incremental_train_base_model=mock_uptrain_base_model,
291316
labels=_TEST_LABELS,
292317
)
293318

@@ -315,8 +340,7 @@ def test_run_call_pipeline_service_create(
315340

316341
true_managed_model = gca_model.Model(
317342
display_name=_TEST_MODEL_DISPLAY_NAME,
318-
labels=mock_model._gca_resource.labels,
319-
description=mock_model._gca_resource.description,
343+
labels=_TEST_MODEL_LABELS,
320344
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
321345
version_aliases=["default"],
322346
)
@@ -330,7 +354,7 @@ def test_run_call_pipeline_service_create(
330354
display_name=_TEST_DISPLAY_NAME,
331355
labels=_TEST_LABELS,
332356
training_task_definition=schema.training_job.definition.automl_image_classification,
333-
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL,
357+
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL,
334358
model_to_upload=true_managed_model,
335359
input_data_config=true_input_data_config,
336360
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
@@ -754,7 +778,7 @@ def test_splits_default(
754778
mock_pipeline_service_get,
755779
mock_dataset_image,
756780
mock_model_service_get,
757-
mock_model,
781+
mock_uptrain_base_model,
758782
sync,
759783
):
760784
"""
@@ -768,7 +792,8 @@ def test_splits_default(
768792
)
769793

770794
job = training_jobs.AutoMLImageTrainingJob(
771-
display_name=_TEST_DISPLAY_NAME, base_model=mock_model
795+
display_name=_TEST_DISPLAY_NAME,
796+
incremental_train_base_model=mock_uptrain_base_model,
772797
)
773798

774799
model_from_job = job.run(
@@ -785,7 +810,6 @@ def test_splits_default(
785810

786811
true_managed_model = gca_model.Model(
787812
display_name=_TEST_MODEL_DISPLAY_NAME,
788-
description=mock_model._gca_resource.description,
789813
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
790814
version_aliases=["default"],
791815
)
@@ -797,7 +821,7 @@ def test_splits_default(
797821
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
798822
display_name=_TEST_DISPLAY_NAME,
799823
training_task_definition=schema.training_job.definition.automl_image_classification,
800-
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL,
824+
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL,
801825
model_to_upload=true_managed_model,
802826
input_data_config=true_input_data_config,
803827
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,

0 commit comments

Comments
 (0)