Skip to content

Commit 83457ca

Browse files
jaycee-licopybara-github
authored andcommitted
docs: new samples for model serialization
PiperOrigin-RevId: 504654663
1 parent 5d0bc1e commit 83457ca

File tree

5 files changed

+159
-3
lines changed

5 files changed

+159
-3
lines changed

samples/model-builder/conftest.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,9 +896,16 @@ def mock_log_classification_metrics():
896896

897897
@pytest.fixture
898898
def mock_log_model():
899-
with patch.object(aiplatform, "log_model") as mock_log_metrics:
900-
mock_log_metrics.return_value = None
901-
yield mock_log_metrics
899+
with patch.object(aiplatform, "log_model") as mock_log_model:
900+
mock_log_model.return_value = None
901+
yield mock_log_model
902+
903+
904+
@pytest.fixture
905+
def mock_save_model():
906+
with patch.object(aiplatform, "save_model") as mock_save_model:
907+
mock_save_model.return_value = None
908+
yield mock_save_model
902909

903910

904911
@pytest.fixture
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_get_experiment_model_sample]
19+
def get_experiment_model_sample(
20+
project: str,
21+
location: str,
22+
artifact_id: str,
23+
) -> "ExperimentModel": # noqa: F821
24+
aiplatform.init(project=project, location=location)
25+
experiment_model = aiplatform.get_experiment_model(artifact_id=artifact_id)
26+
27+
return experiment_model
28+
29+
30+
# [END aiplatform_sdk_get_experiment_model_sample]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_experiment_model_sample
16+
17+
import pytest
18+
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_sdk_init")
23+
def test_get_experiment_model_sample(mock_get_experiment_model):
24+
25+
get_experiment_model_sample.get_experiment_model_sample(
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
artifact_id=constants.EXPERIMENT_MODEL_ID,
29+
)
30+
31+
mock_get_experiment_model.assert_called_once_with(
32+
artifact_id=constants.EXPERIMENT_MODEL_ID,
33+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_save_model_sample]
21+
def save_model_sample(
22+
project: str,
23+
location: str,
24+
model: Union[
25+
"sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module" # noqa: F821
26+
],
27+
artifact_id: Optional[str] = None,
28+
uri: Optional[str] = None,
29+
input_example: Optional[
30+
Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821
31+
] = None,
32+
display_name: Optional[str] = None,
33+
) -> None:
34+
aiplatform.init(project=project, location=location)
35+
36+
aiplatform.save_model(
37+
model=model,
38+
artifact_id=artifact_id,
39+
uri=uri,
40+
input_example=input_example,
41+
display_name=display_name,
42+
)
43+
44+
45+
# [END aiplatform_sdk_save_model_sample]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import save_model_sample
18+
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_sdk_init")
23+
def test_save_model_sample(mock_save_model):
24+
25+
save_model_sample.save_model_sample(
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
model=constants.ML_MODEL,
29+
artifact_id=constants.EXPERIMENT_MODEL_ID,
30+
uri=constants.MODEL_ARTIFACT_URI,
31+
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
32+
display_name=constants.DISPLAY_NAME,
33+
)
34+
35+
mock_save_model.assert_called_once_with(
36+
model=constants.ML_MODEL,
37+
artifact_id=constants.EXPERIMENT_MODEL_ID,
38+
uri=constants.MODEL_ARTIFACT_URI,
39+
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
40+
display_name=constants.DISPLAY_NAME,
41+
)

0 commit comments

Comments
 (0)