Skip to content

Commit 7997094

Browse files
jaycee-licopybara-github
authored andcommitted
docs: samples for model serialization
PiperOrigin-RevId: 504033183
1 parent f38ddc2 commit 7997094

12 files changed

+459
-0
lines changed

samples/model-builder/conftest.py

+68
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,30 @@ def mock_artifacts():
681681
yield mock
682682

683683

684+
@pytest.fixture
685+
def mock_experiment_models():
686+
mock = MagicMock()
687+
yield mock
688+
689+
690+
@pytest.fixture
691+
def mock_model_info():
692+
mock = MagicMock()
693+
yield mock
694+
695+
696+
@pytest.fixture
697+
def mock_ml_model():
698+
mock = MagicMock()
699+
yield mock
700+
701+
702+
@pytest.fixture
703+
def mock_experiment_model():
704+
mock = MagicMock(aiplatform.metadata.schema.google.artifact_schema.ExperimentModel)
705+
yield mock
706+
707+
684708
@pytest.fixture
685709
def mock_get_execution(mock_execution):
686710
with patch.object(aiplatform, "Execution") as mock_get_execution:
@@ -870,6 +894,13 @@ def mock_log_classification_metrics():
870894
yield mock_log_metrics
871895

872896

897+
@pytest.fixture
898+
def mock_log_model():
899+
with patch.object(aiplatform, "log_model") as mock_log_metrics:
900+
mock_log_metrics.return_value = None
901+
yield mock_log_metrics
902+
903+
873904
@pytest.fixture
874905
def mock_log_pipeline_job():
875906
with patch.object(aiplatform, "log") as mock_log_pipeline_job:
@@ -944,6 +975,43 @@ def mock_get_artifacts(mock_artifacts, mock_experiment_run):
944975
yield mock_get_artifacts
945976

946977

978+
@pytest.fixture
979+
def mock_get_experiment_models(mock_experiment_models, mock_experiment_run):
980+
with patch.object(
981+
mock_experiment_run, "get_experiment_models"
982+
) as mock_get_experiment_models:
983+
mock_get_experiment_models.return_value = mock_experiment_models
984+
yield mock_get_experiment_models
985+
986+
987+
@pytest.fixture
988+
def mock_get_experiment_model(mock_experiment_model):
989+
with patch.object(aiplatform, "get_experiment_model") as mock_get_experiment_model:
990+
mock_get_experiment_model.return_value = mock_experiment_model
991+
yield mock_get_experiment_model
992+
993+
994+
@pytest.fixture
995+
def mock_get_model_info(mock_experiment_model, mock_model_info):
996+
with patch.object(mock_experiment_model, "get_model_info") as mock_get_model_info:
997+
mock_get_model_info.return_value = mock_model_info
998+
yield mock_get_model_info
999+
1000+
1001+
@pytest.fixture
1002+
def mock_load_model(mock_experiment_model, mock_ml_model):
1003+
with patch.object(mock_experiment_model, "load_model") as mock_load_model:
1004+
mock_load_model.return_value = mock_ml_model
1005+
yield mock_load_model
1006+
1007+
1008+
@pytest.fixture
1009+
def mock_register_model(mock_experiment_model, mock_model):
1010+
with patch.object(mock_experiment_model, "register_model") as mock_register_model:
1011+
mock_register_model.return_value = mock_model
1012+
yield mock_register_model
1013+
1014+
9471015
"""
9481016
----------------------------------------------------------------------------
9491017
Model Versioning Fixtures
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_get_experiment_run_models_sample]
21+
def get_experiment_run_models_sample(
22+
run_name: str,
23+
experiment: Union[str, aiplatform.Experiment],
24+
project: str,
25+
location: str,
26+
) -> List["ExperimentModel"]: # noqa: F821
27+
experiment_run = aiplatform.ExperimentRun(
28+
run_name=run_name, experiment=experiment, project=project, location=location
29+
)
30+
31+
return experiment_run.get_experiment_models()
32+
33+
34+
# [END aiplatform_sdk_get_experiment_run_models_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_experiment_run_models_sample
16+
17+
import pytest
18+
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_get_run")
23+
def test_get_experiment_run_models_sample(
24+
mock_get_experiment_models, mock_experiment_models
25+
):
26+
27+
experiment_models = (
28+
get_experiment_run_models_sample.get_experiment_run_models_sample(
29+
run_name=constants.EXPERIMENT_RUN_NAME,
30+
experiment=constants.EXPERIMENT_NAME,
31+
project=constants.PROJECT,
32+
location=constants.LOCATION,
33+
)
34+
)
35+
36+
mock_get_experiment_models.assert_called_once()
37+
38+
assert experiment_models is mock_experiment_models
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_get_model_info_sample]
21+
def get_model_info_sample(
22+
artifact_id: str,
23+
project: str,
24+
location: str,
25+
) -> Dict[str, Any]:
26+
experiment_model = aiplatform.get_experiment_model(
27+
artifact_id=artifact_id, project=project, location=location
28+
)
29+
30+
return experiment_model.get_model_info()
31+
32+
33+
# [END aiplatform_sdk_get_model_info_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import get_model_info_sample
16+
17+
import pytest
18+
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_get_run")
23+
def test_get_model_info_sample(
24+
mock_get_experiment_model, mock_get_model_info, mock_model_info
25+
):
26+
27+
model_info = get_model_info_sample.get_model_info_sample(
28+
artifact_id=constants.EXPERIMENT_MODEL_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
31+
)
32+
33+
mock_get_experiment_model.assert_called_once_with(
34+
artifact_id=constants.EXPERIMENT_MODEL_ID,
35+
project=constants.PROJECT,
36+
location=constants.LOCATION,
37+
)
38+
mock_get_model_info.assert_called_once()
39+
40+
assert model_info is mock_model_info
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_load_experiment_model_sample]
21+
def load_experiment_model_sample(
22+
artifact_id: str,
23+
project: str,
24+
location: str,
25+
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]: # noqa: F821:
26+
experiment_model = aiplatform.get_experiment_model(
27+
artifact_id=artifact_id, project=project, location=location
28+
)
29+
30+
return experiment_model.load_model()
31+
32+
33+
# [END aiplatform_sdk_load_experiment_model_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import load_experiment_model_sample
16+
17+
import pytest
18+
19+
import test_constants as constants
20+
21+
22+
@pytest.mark.usefixtures("mock_get_run")
23+
def test_load_experiment_model_sample(
24+
mock_get_experiment_model, mock_load_model, mock_ml_model
25+
):
26+
27+
ml_model = load_experiment_model_sample.load_experiment_model_sample(
28+
artifact_id=constants.EXPERIMENT_MODEL_ID,
29+
project=constants.PROJECT,
30+
location=constants.LOCATION,
31+
)
32+
33+
mock_get_experiment_model.assert_called_once_with(
34+
artifact_id=constants.EXPERIMENT_MODEL_ID,
35+
project=constants.PROJECT,
36+
location=constants.LOCATION,
37+
)
38+
mock_load_model.assert_called_once()
39+
40+
assert ml_model is mock_ml_model
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_log_model_sample]
21+
def log_model_sample(
22+
experiment_name: str,
23+
run_name: str,
24+
project: str,
25+
location: str,
26+
model: Union[
27+
"sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module" # noqa: F821
28+
],
29+
artifact_id: Optional[str] = None,
30+
uri: Optional[str] = None,
31+
input_example: Optional[
32+
Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821
33+
] = None, # noqa: F821
34+
display_name: Optional[str] = None,
35+
) -> None:
36+
aiplatform.init(experiment=experiment_name, project=project, location=location)
37+
38+
aiplatform.start_run(run=run_name, resume=True)
39+
40+
aiplatform.log_model(
41+
model=model,
42+
artifact_id=artifact_id,
43+
uri=uri,
44+
input_example=input_example,
45+
display_name=display_name,
46+
)
47+
48+
49+
# [END aiplatform_sdk_log_model_sample]

0 commit comments

Comments
 (0)