Skip to content

Commit 94dd82f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Implement Model.copy functionality.
PiperOrigin-RevId: 516678124
1 parent 8cb4377 commit 94dd82f

File tree

2 files changed

+198
-0
lines changed

2 files changed

+198
-0
lines changed

google/cloud/aiplatform/models.py

+119
Original file line numberDiff line numberDiff line change
@@ -4656,6 +4656,125 @@ def upload_tensorflow_saved_model(
46564656
upload_request_timeout=upload_request_timeout,
46574657
)
46584658

4659+
# TODO(b/273499620): Add async support.
4660+
def copy(
4661+
self,
4662+
destination_location: str,
4663+
destination_model_id: Optional[str] = None,
4664+
destination_parent_model: Optional[str] = None,
4665+
encryption_spec_key_name: Optional[str] = None,
4666+
copy_request_timeout: Optional[float] = None,
4667+
) -> "Model":
4668+
"""Copys a model and returns a Model representing the copied Model
4669+
resource. This method is a blocking call.
4670+
4671+
Example usage:
4672+
copied_model = my_model.copy(
4673+
destination_location="us-central1"
4674+
)
4675+
4676+
Args:
4677+
destination_location (str):
4678+
The destination location to copy the model to.
4679+
destination_model_id (str):
4680+
Optional. The ID to use for the copied Model, which will
4681+
become the final component of the model resource name.
4682+
This value may be up to 63 characters, and valid characters
4683+
are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
4684+
4685+
Only set this field when copying as a new model. If this field is not set,
4686+
a numeric model id will be generated.
4687+
destination_parent_model (str):
4688+
Optional. The resource name or model ID of an existing model that the
4689+
newly-copied model will be a version of.
4690+
4691+
Only set this field when copying as a new version of an existing model.
4692+
encryption_spec_key_name (Optional[str]):
4693+
Optional. The Cloud KMS resource identifier of the customer
4694+
managed encryption key used to protect the model. Has the
4695+
form:
4696+
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
4697+
The key needs to be in the same region as where the compute
4698+
resource is created.
4699+
4700+
If set, this Model and all sub-resources of this Model will be secured by this key.
4701+
4702+
Overrides encryption_spec_key_name set in aiplatform.init.
4703+
copy_request_timeout (float):
4704+
Optional. The timeout for the copy request in seconds.
4705+
4706+
Returns:
4707+
model (aiplatform.Model):
4708+
Instantiated representation of the copied model resource.
4709+
4710+
Raises:
4711+
ValueError: If both `destination_model_id` and `destination_parent_model` are set.
4712+
"""
4713+
if destination_model_id is not None and destination_parent_model is not None:
4714+
raise ValueError(
4715+
"`destination_model_id` and `destination_parent_model` can not be set together."
4716+
)
4717+
4718+
parent = initializer.global_config.common_location_path(
4719+
initializer.global_config.project, destination_location
4720+
)
4721+
4722+
source_model = self.versioned_resource_name
4723+
4724+
destination_parent_model = ModelRegistry._get_true_version_parent(
4725+
parent_model=destination_parent_model,
4726+
project=initializer.global_config.project,
4727+
location=destination_location,
4728+
)
4729+
4730+
encryption_spec = initializer.global_config.get_encryption_spec(
4731+
encryption_spec_key_name=encryption_spec_key_name,
4732+
)
4733+
4734+
if destination_model_id is not None:
4735+
request = gca_model_service_compat.CopyModelRequest(
4736+
parent=parent,
4737+
source_model=source_model,
4738+
model_id=destination_model_id,
4739+
encryption_spec=encryption_spec,
4740+
)
4741+
else:
4742+
request = gca_model_service_compat.CopyModelRequest(
4743+
parent=parent,
4744+
source_model=source_model,
4745+
parent_model=destination_parent_model,
4746+
encryption_spec=encryption_spec,
4747+
)
4748+
4749+
api_client = initializer.global_config.create_client(
4750+
client_class=utils.ModelClientWithOverride,
4751+
location_override=destination_location,
4752+
credentials=initializer.global_config.credentials,
4753+
)
4754+
4755+
_LOGGER.log_action_start_against_resource("Copying", "", self)
4756+
4757+
lro = api_client.copy_model(
4758+
request=request,
4759+
timeout=copy_request_timeout,
4760+
)
4761+
4762+
_LOGGER.log_action_started_against_resource_with_lro(
4763+
"Copy", "", self.__class__, lro
4764+
)
4765+
4766+
model_copy_response = lro.result(timeout=None)
4767+
4768+
this_model = models.Model(
4769+
model_copy_response.model,
4770+
version=model_copy_response.model_version_id,
4771+
location=destination_location,
4772+
)
4773+
4774+
_LOGGER.log_action_completed_against_resource("", "copied", this_model)
4775+
4776+
return this_model
4777+
46594778
def list_model_evaluations(
46604779
self,
46614780
) -> List["model_evaluation.ModelEvaluation"]:

tests/unit/aiplatform/test_models.py

+79
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
7373
_TEST_MODEL_NAME = "123"
7474
_TEST_MODEL_NAME_ALT = "456"
75+
_TEST_MODEL_ID = "my-model"
7576
_TEST_MODEL_PARENT = (
7677
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_NAME}"
7778
)
@@ -581,6 +582,19 @@ def delete_model_mock():
581582
yield delete_model_mock
582583

583584

585+
@pytest.fixture
586+
def copy_model_mock():
587+
with mock.patch.object(
588+
model_service_client.ModelServiceClient, "copy_model"
589+
) as copy_model_mock:
590+
mock_lro = mock.Mock(ga_operation.Operation)
591+
mock_lro.result.return_value = gca_model_service.CopyModelResponse(
592+
model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION
593+
)
594+
copy_model_mock.return_value = mock_lro
595+
yield copy_model_mock
596+
597+
584598
@pytest.fixture
585599
def deploy_model_mock():
586600
with mock.patch.object(
@@ -2419,6 +2433,71 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model(
24192433
staged_model_file_name = staged_model_file_path.split("/")[-1]
24202434
assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"]
24212435

2436+
def test_copy_as_new_model(self, copy_model_mock, get_model_mock):
2437+
2438+
test_model = models.Model(_TEST_ID)
2439+
test_model.copy(destination_location=_TEST_LOCATION_2)
2440+
2441+
copy_model_mock.assert_called_once_with(
2442+
request=gca_model_service.CopyModelRequest(
2443+
parent=initializer.global_config.common_location_path(
2444+
location=_TEST_LOCATION_2
2445+
),
2446+
source_model=_TEST_MODEL_RESOURCE_NAME,
2447+
),
2448+
timeout=None,
2449+
)
2450+
2451+
def test_copy_as_new_version(self, copy_model_mock, get_model_mock):
2452+
test_model = models.Model(_TEST_ID)
2453+
test_model.copy(
2454+
destination_location=_TEST_LOCATION_2,
2455+
destination_parent_model=_TEST_MODEL_NAME_ALT,
2456+
)
2457+
2458+
copy_model_mock.assert_called_once_with(
2459+
request=gca_model_service.CopyModelRequest(
2460+
parent=initializer.global_config.common_location_path(
2461+
location=_TEST_LOCATION_2
2462+
),
2463+
source_model=_TEST_MODEL_RESOURCE_NAME,
2464+
parent_model=model_service_client.ModelServiceClient.model_path(
2465+
_TEST_PROJECT, _TEST_LOCATION_2, _TEST_MODEL_NAME_ALT
2466+
),
2467+
),
2468+
timeout=None,
2469+
)
2470+
2471+
def test_copy_as_new_model_custom_id(self, copy_model_mock, get_model_mock):
2472+
test_model = models.Model(_TEST_ID)
2473+
test_model.copy(
2474+
destination_location=_TEST_LOCATION_2, destination_model_id=_TEST_MODEL_ID
2475+
)
2476+
2477+
copy_model_mock.assert_called_once_with(
2478+
request=gca_model_service.CopyModelRequest(
2479+
parent=initializer.global_config.common_location_path(
2480+
location=_TEST_LOCATION_2
2481+
),
2482+
source_model=_TEST_MODEL_RESOURCE_NAME,
2483+
model_id=_TEST_MODEL_ID,
2484+
),
2485+
timeout=None,
2486+
)
2487+
2488+
def test_copy_with_invalid_params(self, copy_model_mock, get_model_mock):
2489+
with pytest.raises(ValueError) as e:
2490+
test_model = models.Model(_TEST_ID)
2491+
test_model.copy(
2492+
destination_location=_TEST_LOCATION,
2493+
destination_model_id=_TEST_MODEL_ID,
2494+
destination_parent_model=_TEST_MODEL_RESOURCE_NAME,
2495+
)
2496+
2497+
assert e.match(
2498+
regexp=r"`destination_model_id` and `destination_parent_model` can not be set together."
2499+
)
2500+
24222501
@pytest.mark.usefixtures("get_model_mock")
24232502
def test_update(self, update_model_mock, get_model_mock):
24242503

0 commit comments

Comments
 (0)