|
174 | 174 |
|
175 | 175 | _TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
|
176 | 176 | _TEST_SERVICE_ACCOUNT = "[email protected]"
|
| 177 | +_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME = "publishers/meta/models/llama3_1" |
177 | 178 |
|
178 | 179 |
|
179 | 180 | _TEST_EXPLANATION_METADATA = explain.ExplanationMetadata(
|
@@ -1900,6 +1901,53 @@ def test_upload_uploads_and_gets_model_with_custom_location(
|
1900 | 1901 | name=test_model_resource_name, retry=base._DEFAULT_RETRY
|
1901 | 1902 | )
|
1902 | 1903 |
|
| 1904 | + @pytest.mark.parametrize("sync", [True, False]) |
| 1905 | + def test_upload_with_model_garden_source( |
| 1906 | + self, upload_model_mock, get_model_mock, sync |
| 1907 | + ): |
| 1908 | + |
| 1909 | + my_model = models.Model.upload( |
| 1910 | + display_name=_TEST_MODEL_NAME, |
| 1911 | + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, |
| 1912 | + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, |
| 1913 | + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, |
| 1914 | + sync=sync, |
| 1915 | + upload_request_timeout=None, |
| 1916 | + model_garden_source_model_name=_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME, |
| 1917 | + ) |
| 1918 | + |
| 1919 | + if not sync: |
| 1920 | + my_model.wait() |
| 1921 | + |
| 1922 | + container_spec = gca_model.ModelContainerSpec( |
| 1923 | + image_uri=_TEST_SERVING_CONTAINER_IMAGE, |
| 1924 | + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, |
| 1925 | + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, |
| 1926 | + ) |
| 1927 | + |
| 1928 | + managed_model = gca_model.Model( |
| 1929 | + display_name=_TEST_MODEL_NAME, |
| 1930 | + container_spec=container_spec, |
| 1931 | + version_aliases=["default"], |
| 1932 | + base_model_source=gca_model.Model.BaseModelSource( |
| 1933 | + model_garden_source=gca_model.ModelGardenSource( |
| 1934 | + public_model_name=_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME |
| 1935 | + ) |
| 1936 | + ), |
| 1937 | + ) |
| 1938 | + |
| 1939 | + upload_model_mock.assert_called_once_with( |
| 1940 | + request=gca_model_service.UploadModelRequest( |
| 1941 | + parent=initializer.global_config.common_location_path(), |
| 1942 | + model=managed_model, |
| 1943 | + ), |
| 1944 | + timeout=None, |
| 1945 | + ) |
| 1946 | + |
| 1947 | + get_model_mock.assert_called_once_with( |
| 1948 | + name=_TEST_MODEL_RESOURCE_NAME, retry=base._DEFAULT_RETRY |
| 1949 | + ) |
| 1950 | + |
1903 | 1951 | @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
|
1904 | 1952 | @pytest.mark.parametrize("sync", [True, False])
|
1905 | 1953 | def test_deploy(self, deploy_model_mock, sync):
|
|
0 commit comments