Skip to content

feat: Support for Model Versioning #1438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 128 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
fb3555d
Initial changes for first ModelRegistry design proposal
Apr 11, 2022
594a0d8
More changes for design doc
Apr 13, 2022
6a91ee2
Merge branch 'main' of https://github.com/googleapis/python-aiplatfor…
Apr 26, 2022
74d8a71
ModelRegistry class implementation
Apr 26, 2022
9c256a5
Added method docs
Apr 27, 2022
ef50cba
Changes from book doc
Apr 29, 2022
974e171
training_jobs versioning changes
May 2, 2022
6425d33
More models.py changes for versioning
May 2, 2022
79bda9b
More version arg plumbing
May 3, 2022
13fe946
Merge branch 'main' of https://github.com/googleapis/python-aiplatfor…
May 3, 2022
df13f71
Tests, implementation changes, and assorted tweaks to make GAPIC stuf…
May 9, 2022
68d8ee2
Merge remote-tracking branch 'origin/main' into goodmansam/design1
May 9, 2022
f7ec0b9
Batch predict versioning, variable name cleanup
May 10, 2022
dbb6157
Reset training_jobs changes to limit scope
May 10, 2022
2fb11f9
Prediction test fixes
May 10, 2022
66f1a2c
Blackend and lint changes
May 11, 2022
07feedc
Merge remote-tracking branch 'origin/main' into goodmansam/modelversi…
May 11, 2022
34750bc
Training jobs versioning support
May 13, 2022
aad969f
Blackend and lint changes
May 13, 2022
e743f52
Added TODO for async support
May 16, 2022
4f9729c
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] May 16, 2022
638b257
Override for _construct_sdk_resource_from_gapic
May 16, 2022
20face1
Merge branch 'goodmansam/modelversioning' of https://github.com/googl…
May 16, 2022
c9dfcda
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] May 16, 2022
c4d8663
Fixed _construct_sdk_resource_from_gapic for Model class and gave doc…
May 16, 2022
ded0b4f
Merge branch 'goodmansam/modelversioning' of https://github.com/googl…
May 16, 2022
f398723
Removed errant futuremanager init
May 16, 2022
7a815e9
Start of new versioning system test
May 17, 2022
d87af04
Pass model version on upload cls init
May 18, 2022
1f3614f
Merge branch 'goodmansam/modelversioning' into goodmansam/versionings…
May 18, 2022
bcf14f4
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] May 18, 2022
aab2488
Fully-fleshed system test for model versioning
May 19, 2022
276fa18
Improvements based on system testing
May 19, 2022
557ad88
Merge branch 'goodmansam/modelversioning' of https://github.com/googl…
May 19, 2022
a14d3f8
Nox fixes
May 19, 2022
2c74f14
Merge branch 'goodmansam/modelversioning' into goodmansam/versionings…
May 19, 2022
68e18b9
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] May 19, 2022
54dcc13
Nox fixes
May 19, 2022
a761cde
Merge branch 'goodmansam/modelversioning' of https://github.com/googl…
gcf-owl-bot[bot] May 19, 2022
0efc1df
Initial commit for vertexpreviews, with compat changes
May 24, 2022
ddb9bea
Merge remote-tracking branch 'origin/main' into vertexpreviews
May 24, 2022
a33999f
Use compat type for DeploymentResourcesType
May 24, 2022
27c2917
Merge branch 'vertexpreviews' into goodmansam/modelversioning
May 24, 2022
14b9e02
Merge branch 'goodmansam/modelversioning' of https://github.com/googl…
May 24, 2022
31c3e02
PR Feedback
May 24, 2022
26d484a
Test compat fixes
May 24, 2022
e4e5fde
More v1->v1beta1 shifts
May 24, 2022
0758bc5
Merge branch 'vertexpreviews' into goodmansam/modelversioning
May 24, 2022
f204288
Plumbing model changes through more tests
May 24, 2022
664f8ec
Merge branch 'vertexpreviews' into goodmansam/trainingjobversioningre…
May 24, 2022
e6a5fb4
chore: release 1.13.1
May 26, 2022
8231a9c
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
May 27, 2022
64d4db4
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
May 31, 2022
cf225da
Merge branch 'main' into vertexpreviews
Jun 1, 2022
acd76de
Merge branch 'vertexpreviews' into goodmansam/versioningfull
Jun 1, 2022
0068762
Merge branch 'goodmansam/versioningsystemtest' into goodmansam/versio…
Jun 1, 2022
d14d626
Test fixes
Jun 1, 2022
544c6d8
Nox run
Jun 1, 2022
91dcc00
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
Jun 8, 2022
59b0708
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
Jun 15, 2022
029f35f
Merge branch 'main' into goodmansam/versioningfull
Jun 15, 2022
466e84a
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 15, 2022
0032a80
Training jobs test fixes
Jun 15, 2022
ef31250
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 15, 2022
b884751
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 15, 2022
da92960
Reverted v1beta1 changes
Jun 15, 2022
dd481fc
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 15, 2022
6564bb7
Blacken and lint changes
Jun 15, 2022
74a7fdf
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 15, 2022
3779e9c
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
gcf-owl-bot[bot] Jun 15, 2022
d08d15d
Merge remote-tracking branch 'origin/main' into goodmansam/versioning…
Jun 16, 2022
41918fa
PR feedback changes
Jun 16, 2022
0bff90b
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 16, 2022
ddeabf9
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 16, 2022
9f7e9ed
Revert "🦉 Updates from OwlBot post-processor"
Jun 17, 2022
c57f305
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 17, 2022
e401532
Revert "🦉 Updates from OwlBot post-processor"
Jun 17, 2022
ee5ff17
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 17, 2022
6774ef2
Revert "🦉 Updates from OwlBot post-processor"
Jun 21, 2022
92dd7eb
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 21, 2022
968ee1b
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 21, 2022
4322fe2
fix: Prevent owlbot from re-adding 3.6 dependencies
Jun 21, 2022
7d8baa8
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 21, 2022
4fdb245
Test fixes
Jun 22, 2022
a0bd987
Merge remote-tracking branch 'origin/main' into goodmansam/versioning…
Jun 22, 2022
82292a5
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 22, 2022
cd2a7e9
nox blacken
Jun 22, 2022
8ebfe43
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 23, 2022
84aab1b
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 23, 2022
da2c401
Test fixes for 3.6 compat
Jun 23, 2022
c19274e
Test fix for 3.6
Jun 23, 2022
f190b79
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 23, 2022
1ba4499
Quitting the fight against nox
Jun 23, 2022
ca3986c
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 23, 2022
9d4b2ee
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 24, 2022
071797b
Test fixes for python 3.7
Jun 24, 2022
3e5533f
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 24, 2022
7fe8d2e
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 27, 2022
732c1cc
System test cleanup
Jun 27, 2022
67fec89
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 27, 2022
5e1b0e8
retrigger checks
Jun 28, 2022
43d14ba
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 28, 2022
a1a3009
Update google/cloud/aiplatform/models.py
samgoodman Jun 28, 2022
a4a0573
Update google/cloud/aiplatform/models.py
samgoodman Jun 28, 2022
75e1a65
Update google/cloud/aiplatform/models.py
samgoodman Jun 28, 2022
697bb4f
Update google/cloud/aiplatform/models.py
samgoodman Jun 28, 2022
275420f
PR feedback changes
Jun 28, 2022
2abf1aa
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 28, 2022
d485ce6
Pass location, project, creds to Model Registry
Jun 28, 2022
056d0ee
Credential fix when getting model from registry
Jun 28, 2022
b04b931
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 28, 2022
0c61cc6
Copyright update
Jun 28, 2022
cfe8d66
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 28, 2022
65d32b1
Fixed issue with Model update trying to update a version, rather than…
Jun 28, 2022
f1fa8b9
Merge branch 'main' into goodmansam/versioningfull
Jun 28, 2022
89d79c2
Merge branch 'main' into goodmansam/versioningfull
samgoodman Jun 28, 2022
3e36ae1
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 28, 2022
2f33b19
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
gcf-owl-bot[bot] Jun 28, 2022
a9dabf0
Revert "fix: Prevent owlbot from re-adding 3.6 dependencies"
Jun 28, 2022
fe5a27b
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 28, 2022
b842b9b
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
gcf-owl-bot[bot] Jun 28, 2022
265b2e0
Merge remote-tracking branch 'origin/main' into goodmansam/versioning…
Jun 28, 2022
b9db115
Nox blacken
Jun 28, 2022
256d44c
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
Jun 28, 2022
8a882b3
Revert "🦉 Updates from OwlBot post-processor"
Jun 28, 2022
1ef4646
Fighting with owlbot
Jun 28, 2022
2ade8c6
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Jun 28, 2022
1a65fdf
Merge branch 'goodmansam/versioningfull' of https://github.com/google…
gcf-owl-bot[bot] Jun 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,9 @@ def _create(
Required. BatchPredictionJob without _gca_resource populated.
model_or_model_name (Union[str, aiplatform.Model]):
Required. Required. A fully-qualified model resource name or
an instance of aiplatform.Model.
an instance of aiplatform.Model. If a resource name, it may
optionally contain a version ID or alias in
{model_name}@{version} form.
gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
Required. a batch prediction job proto for creating a batch prediction job on Vertex AI.
generate_explanation (bool):
Expand Down Expand Up @@ -742,7 +744,7 @@ def _create(
model_resource_name = (
model_or_model_name
if isinstance(model_or_model_name, str)
else model_or_model_name.resource_name
else model_or_model_name.versioned_resource_name
)

gca_batch_prediction_job.model = model_resource_name
Expand Down
645 changes: 636 additions & 9 deletions google/cloud/aiplatform/models.py

Large diffs are not rendered by default.

635 changes: 634 additions & 1 deletion google/cloud/aiplatform/training_jobs.py

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions tests/system/aiplatform/test_model_version_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tempfile
import uuid

import pytest

from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform.models import ModelRegistry

from tests.system.aiplatform import e2e_base
from tests.system.aiplatform import test_model_upload


@pytest.mark.usefixtures("delete_staging_bucket")
class TestVersionManagement(e2e_base.TestEndToEnd):

_temp_prefix = "temp_vertex_sdk_e2e_model_upload_test"

def test_upload_deploy_manage_versioned_model(self, shared_state):
"""Upload XGBoost model from local file and deploy it for prediction. Additionally, update model name, description and labels"""

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
)

storage_client = storage.Client(project=e2e_base._PROJECT)
model_blob = storage.Blob.from_string(
uri=test_model_upload._XGBOOST_MODEL_URI, client=storage_client
)
model_path = tempfile.mktemp() + ".my_model.xgb"
model_blob.download_to_filename(filename=model_path)

model_id = "my_model_id" + uuid.uuid4().hex
version_description = "My description"
version_aliases = ["system-test-model", "testing"]

model = aiplatform.Model.upload_xgboost_model_file(
model_file_path=model_path,
version_aliases=version_aliases,
model_id=model_id,
version_description=version_description,
)
shared_state["resources"] = [model]

staging_bucket = storage.Blob.from_string(
uri=model.uri, client=storage_client
).bucket
# Checking that the bucket is auto-generated
assert "-vertex-staging-" in staging_bucket.name

shared_state["bucket"] = staging_bucket

assert model.version_description == version_description
assert model.version_aliases == version_aliases
assert "default" in model.version_aliases

model2 = aiplatform.Model.upload_xgboost_model_file(
model_file_path=model_path, parent_model=model_id, is_default_version=False
)
shared_state["resources"].append(model2)

assert model2.version_id == "2"
assert model2.resource_name == model.resource_name
assert model2.version_aliases == []
"""
# Test predictions use right version
endpoint = model2.deploy(machine_type="n1-standard-2")
shared_state["resources"].append(endpoint)
predict_response = endpoint.predict(instances=[[0, 0, 0]])

assert len(predict_response.predictions) == 1
assert predict_response.model_version_id == '2'
"""

# Test that VersionInfo properties are correct.
# Currently, get_version_info and list_versions don't yield identical
# resource names at this time due to a Model Registry bug (b/233118690)
model_info = model2.versioning_registry.get_version_info("testing")
version_list = model2.versioning_registry.list_versions()
assert len(version_list) == 2
list_info = version_list[0]
assert model_info.version_id == list_info.version_id == model.version_id
assert (
model_info.version_aliases
== list_info.version_aliases
== model.version_aliases
)
assert (
model_info.version_description
== list_info.version_description
== model.version_description
)
assert (
model_info.model_display_name
== list_info.model_display_name
== model.display_name
)
assert (
model_info.version_update_time
== list_info.version_update_time
== model.version_update_time
)

# Test that get_model yields a new instance of `model`
model_clone = model2.versioning_registry.get_model()
assert model.resource_name == model_clone.resource_name
assert model.version_id == model_clone.version_id
assert model.name == model_clone.name

# Test add and removal of aliases
registry = ModelRegistry(model)
registry.add_version_aliases(["new-alias"], "default")
registry.remove_version_aliases(["testing"], "new-alias")
model = registry.get_model("new-alias")
assert "testing" not in model.version_aliases

# Test deletion of a model version
registry.delete_version("2")
versions = registry.list_versions()
assert "2" not in [version.version_id for version in versions]
20 changes: 16 additions & 4 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def mock_model_service_get():
with mock.patch.object(
model_service_client.ModelServiceClient, "get_model"
) as mock_get_model:
mock_get_model.return_value = gca_model.Model()
mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME)
yield mock_get_model


Expand Down Expand Up @@ -341,7 +341,9 @@ def test_run_call_pipeline_service_create(
model_from_job.wait()

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
display_name=_TEST_MODEL_DISPLAY_NAME,
labels=_TEST_MODEL_LABELS,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -447,7 +449,9 @@ def test_run_call_pipeline_service_create_with_timeout(
model_from_job.wait()

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
display_name=_TEST_MODEL_DISPLAY_NAME,
labels=_TEST_MODEL_LABELS,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -538,6 +542,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -623,7 +628,10 @@ def test_run_call_pipeline_if_set_additional_experiments(
model_from_job.wait()

# Test that if defaults to the job display name
true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
dataset_id=mock_dataset_time_series.name,
Expand Down Expand Up @@ -910,6 +918,7 @@ def test_splits_fraction(
true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -1017,6 +1026,7 @@ def test_splits_timestamp(
true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -1116,6 +1126,7 @@ def test_splits_predefined(
true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -1211,6 +1222,7 @@ def test_splits_default(
true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/aiplatform/test_automl_image_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def mock_model_service_get():
with mock.patch.object(
model_service_client.ModelServiceClient, "get_model"
) as mock_get_model:
mock_get_model.return_value = gca_model.Model()
mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME)
yield mock_get_model


Expand Down Expand Up @@ -318,6 +318,7 @@ def test_run_call_pipeline_service_create(
labels=mock_model._gca_resource.labels,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -402,6 +403,7 @@ def test_run_call_pipeline_service_create_with_timeout(
labels=mock_model._gca_resource.labels,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -460,6 +462,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
display_name=_TEST_DISPLAY_NAME,
labels=_TEST_LABELS,
encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -645,6 +648,7 @@ def test_splits_fraction(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -718,6 +722,7 @@ def test_splits_filter(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down Expand Up @@ -782,6 +787,7 @@ def test_splits_default(
display_name=_TEST_MODEL_DISPLAY_NAME,
description=mock_model._gca_resource.description,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
Expand Down
Loading