|
25 | 25 | from google.api_core import operation as ga_operation
|
26 | 26 | from google.auth import credentials as auth_credentials
|
27 | 27 |
|
| 28 | +from google.protobuf import field_mask_pb2 |
| 29 | + |
28 | 30 | from google.cloud import aiplatform
|
29 | 31 | from google.cloud.aiplatform import base
|
30 | 32 | from google.cloud.aiplatform import initializer
|
|
58 | 60 | _TEST_ID_2 = "4366591682456584192"
|
59 | 61 | _TEST_ID_3 = "5820582938582924817"
|
60 | 62 | _TEST_DESCRIPTION = "test-description"
|
| 63 | +_TEST_REQUEST_METADATA = () |
| 64 | +_TEST_TIMEOUT = None |
61 | 65 |
|
62 | 66 | _TEST_ENDPOINT_NAME = (
|
63 | 67 | f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
|
@@ -270,6 +274,16 @@ def create_endpoint_mock():
|
270 | 274 | yield create_endpoint_mock
|
271 | 275 |
|
272 | 276 |
|
| 277 | +@pytest.fixture |
| 278 | +def update_endpoint_mock(): |
| 279 | + with mock.patch.object( |
| 280 | + endpoint_service_client.EndpointServiceClient, "update_endpoint" |
| 281 | + ) as update_endpoint_mock: |
| 282 | + update_endpoint_lro_mock = mock.Mock(ga_operation.Operation) |
| 283 | + update_endpoint_mock.return_value = update_endpoint_lro_mock |
| 284 | + yield update_endpoint_mock |
| 285 | + |
| 286 | + |
273 | 287 | @pytest.fixture
|
274 | 288 | def deploy_model_mock():
|
275 | 289 | with mock.patch.object(
|
@@ -726,6 +740,54 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
|
726 | 740 | timeout=None,
|
727 | 741 | )
|
728 | 742 |
|
| 743 | + @pytest.mark.usefixtures("get_endpoint_mock") |
| 744 | + def test_update_endpoint(self, update_endpoint_mock): |
| 745 | + endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 746 | + endpoint.update( |
| 747 | + display_name=_TEST_DISPLAY_NAME, |
| 748 | + description=_TEST_DESCRIPTION, |
| 749 | + labels=_TEST_LABELS, |
| 750 | + ) |
| 751 | + |
| 752 | + expected_endpoint = gca_endpoint.Endpoint( |
| 753 | + name=_TEST_ENDPOINT_NAME, |
| 754 | + display_name=_TEST_DISPLAY_NAME, |
| 755 | + description=_TEST_DESCRIPTION, |
| 756 | + labels=_TEST_LABELS, |
| 757 | + encryption_spec=_TEST_ENCRYPTION_SPEC, |
| 758 | + ) |
| 759 | + |
| 760 | + expected_update_mask = field_mask_pb2.FieldMask( |
| 761 | + paths=["display_name", "description", "labels"] |
| 762 | + ) |
| 763 | + |
| 764 | + update_endpoint_mock.assert_called_once_with( |
| 765 | + endpoint=expected_endpoint, |
| 766 | + update_mask=expected_update_mask, |
| 767 | + metadata=_TEST_REQUEST_METADATA, |
| 768 | + timeout=_TEST_TIMEOUT, |
| 769 | + ) |
| 770 | + |
| 771 | + @pytest.mark.usefixtures("get_endpoint_with_models_mock") |
| 772 | + def test_update_traffic_split(self, update_endpoint_mock): |
| 773 | + endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 774 | + endpoint.update(traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10}) |
| 775 | + |
| 776 | + expected_endpoint = gca_endpoint.Endpoint( |
| 777 | + name=_TEST_ENDPOINT_NAME, |
| 778 | + display_name=_TEST_DISPLAY_NAME, |
| 779 | + deployed_models=_TEST_DEPLOYED_MODELS, |
| 780 | + traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10}, |
| 781 | + ) |
| 782 | + expected_update_mask = field_mask_pb2.FieldMask(paths=["traffic_split"]) |
| 783 | + |
| 784 | + update_endpoint_mock.assert_called_once_with( |
| 785 | + endpoint=expected_endpoint, |
| 786 | + update_mask=expected_update_mask, |
| 787 | + metadata=_TEST_REQUEST_METADATA, |
| 788 | + timeout=_TEST_TIMEOUT, |
| 789 | + ) |
| 790 | + |
729 | 791 | @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
|
730 | 792 | @pytest.mark.parametrize("sync", [True, False])
|
731 | 793 | def test_deploy(self, deploy_model_mock, sync):
|
@@ -920,7 +982,7 @@ def test_deploy_raise_error_max_replica(self, sync):
|
920 | 982 | )
|
921 | 983 | test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync)
|
922 | 984 |
|
923 |
| - @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") |
| 985 | + @pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock") |
924 | 986 | @pytest.mark.parametrize("sync", [True, False])
|
925 | 987 | def test_deploy_raise_error_traffic_split(self, sync):
|
926 | 988 | with pytest.raises(ValueError):
|
@@ -973,48 +1035,39 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
|
973 | 1035 | timeout=None,
|
974 | 1036 | )
|
975 | 1037 |
|
976 |
| - @pytest.mark.usefixtures("get_model_mock") |
| 1038 | + @pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock") |
977 | 1039 | @pytest.mark.parametrize("sync", [True, False])
|
978 | 1040 | def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
|
979 |
| - with mock.patch.object( |
980 |
| - endpoint_service_client.EndpointServiceClient, "get_endpoint" |
981 |
| - ) as get_endpoint_mock: |
982 |
| - get_endpoint_mock.return_value = gca_endpoint.Endpoint( |
983 |
| - display_name=_TEST_DISPLAY_NAME, |
984 |
| - name=_TEST_ENDPOINT_NAME, |
985 |
| - traffic_split={"model1": 100}, |
986 |
| - ) |
987 |
| - |
988 |
| - test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
989 |
| - test_model = models.Model(_TEST_ID) |
990 |
| - test_model._gca_resource.supported_deployment_resources_types.append( |
991 |
| - aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES |
992 |
| - ) |
993 |
| - test_endpoint.deploy( |
994 |
| - model=test_model, |
995 |
| - traffic_split={"model1": 30, "0": 70}, |
996 |
| - sync=sync, |
997 |
| - deploy_request_timeout=None, |
998 |
| - ) |
| 1041 | + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 1042 | + test_model = models.Model(_TEST_ID) |
| 1043 | + test_model._gca_resource.supported_deployment_resources_types.append( |
| 1044 | + aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES |
| 1045 | + ) |
| 1046 | + test_endpoint.deploy( |
| 1047 | + model=test_model, |
| 1048 | + traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40}, |
| 1049 | + sync=sync, |
| 1050 | + deploy_request_timeout=None, |
| 1051 | + ) |
999 | 1052 |
|
1000 |
| - if not sync: |
1001 |
| - test_endpoint.wait() |
1002 |
| - automatic_resources = gca_machine_resources.AutomaticResources( |
1003 |
| - min_replica_count=1, |
1004 |
| - max_replica_count=1, |
1005 |
| - ) |
1006 |
| - deployed_model = gca_endpoint.DeployedModel( |
1007 |
| - automatic_resources=automatic_resources, |
1008 |
| - model=test_model.resource_name, |
1009 |
| - display_name=None, |
1010 |
| - ) |
1011 |
| - deploy_model_mock.assert_called_once_with( |
1012 |
| - endpoint=test_endpoint.resource_name, |
1013 |
| - deployed_model=deployed_model, |
1014 |
| - traffic_split={"model1": 30, "0": 70}, |
1015 |
| - metadata=(), |
1016 |
| - timeout=None, |
1017 |
| - ) |
| 1053 | + if not sync: |
| 1054 | + test_endpoint.wait() |
| 1055 | + automatic_resources = gca_machine_resources.AutomaticResources( |
| 1056 | + min_replica_count=1, |
| 1057 | + max_replica_count=1, |
| 1058 | + ) |
| 1059 | + deployed_model = gca_endpoint.DeployedModel( |
| 1060 | + automatic_resources=automatic_resources, |
| 1061 | + model=test_model.resource_name, |
| 1062 | + display_name=None, |
| 1063 | + ) |
| 1064 | + deploy_model_mock.assert_called_once_with( |
| 1065 | + endpoint=test_endpoint.resource_name, |
| 1066 | + deployed_model=deployed_model, |
| 1067 | + traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40}, |
| 1068 | + metadata=(), |
| 1069 | + timeout=None, |
| 1070 | + ) |
1018 | 1071 |
|
1019 | 1072 | @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
|
1020 | 1073 | @pytest.mark.parametrize("sync", [True, False])
|
|
0 commit comments