Skip to content

Commit 4e4bff5

Browse files
authored
feat: Added forecasting snippets and fixed bugs with existing snippets (#1210)
* Added dataset snippets * Fixed typehint and missing parameter bugs as well as added new samples * Fixed lint issues * Added bq batch_prediction bq snippets * Removed unneeded fixture * Renamed bq_source to bigquery_source * Added back explain_tabular_sample.py for now * Fixed tests * Fixed lint issues
1 parent 0036ab0 commit 4e4bff5

24 files changed

+727
-31
lines changed

samples/model-builder/conftest.py

+41
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def mock_tabular_dataset():
4545
yield mock
4646

4747

48+
@pytest.fixture
49+
def mock_time_series_dataset():
50+
mock = MagicMock(aiplatform.datasets.TimeSeriesDataset)
51+
yield mock
52+
53+
4854
@pytest.fixture
4955
def mock_text_dataset():
5056
mock = MagicMock(aiplatform.datasets.TextDataset)
@@ -74,6 +80,13 @@ def mock_get_tabular_dataset(mock_tabular_dataset):
7480
yield mock_get_tabular_dataset
7581

7682

83+
@pytest.fixture
84+
def mock_get_time_series_dataset(mock_time_series_dataset):
85+
with patch.object(aiplatform, "TimeSeriesDataset") as mock_get_time_series_dataset:
86+
mock_get_time_series_dataset.return_value = mock_time_series_dataset
87+
yield mock_get_time_series_dataset
88+
89+
7790
@pytest.fixture
7891
def mock_get_text_dataset(mock_text_dataset):
7992
with patch.object(aiplatform, "TextDataset") as mock_get_text_dataset:
@@ -107,6 +120,15 @@ def mock_create_tabular_dataset(mock_tabular_dataset):
107120
yield mock_create_tabular_dataset
108121

109122

123+
@pytest.fixture
124+
def mock_create_time_series_dataset(mock_time_series_dataset):
125+
with patch.object(
126+
aiplatform.TimeSeriesDataset, "create"
127+
) as mock_create_time_series_dataset:
128+
mock_create_time_series_dataset.return_value = mock_time_series_dataset
129+
yield mock_create_time_series_dataset
130+
131+
110132
@pytest.fixture
111133
def mock_create_text_dataset(mock_text_dataset):
112134
with patch.object(aiplatform.TextDataset, "create") as mock_create_text_dataset:
@@ -183,6 +205,12 @@ def mock_tabular_training_job():
183205
yield mock
184206

185207

208+
@pytest.fixture
209+
def mock_forecasting_training_job():
210+
mock = MagicMock(aiplatform.training_jobs.AutoMLForecastingTrainingJob)
211+
yield mock
212+
213+
186214
@pytest.fixture
187215
def mock_text_training_job():
188216
mock = MagicMock(aiplatform.training_jobs.AutoMLTextTrainingJob)
@@ -208,6 +236,19 @@ def mock_run_automl_tabular_training_job(mock_tabular_training_job):
208236
yield mock
209237

210238

239+
@pytest.fixture
240+
def mock_get_automl_forecasting_training_job(mock_forecasting_training_job):
241+
with patch.object(aiplatform, "AutoMLForecastingTrainingJob") as mock:
242+
mock.return_value = mock_forecasting_training_job
243+
yield mock
244+
245+
246+
@pytest.fixture
247+
def mock_run_automl_forecasting_training_job(mock_forecasting_training_job):
248+
with patch.object(mock_forecasting_training_job, "run") as mock:
249+
yield mock
250+
251+
211252
@pytest.fixture
212253
def mock_get_automl_image_training_job(mock_image_training_job):
213254
with patch.object(aiplatform, "AutoMLImageTrainingJob") as mock:

samples/model-builder/create_and_import_dataset_tabular_bigquery_sample.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818

1919
# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
2020
def create_and_import_dataset_tabular_bigquery_sample(
21-
display_name: str, project: str, location: str, bq_source: str,
21+
display_name: str,
22+
project: str,
23+
location: str,
24+
bigquery_source: str,
2225
):
2326

2427
aiplatform.init(project=project, location=location)
2528

2629
dataset = aiplatform.TabularDataset.create(
27-
display_name=display_name, bq_source=bq_source,
30+
display_name=display_name,
31+
bigquery_source=bigquery_source,
2832
)
2933

3034
dataset.wait()

samples/model-builder/create_and_import_dataset_tabular_bigquery_sample_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ def test_create_and_import_dataset_tabular_bigquery_sample(
2424
create_and_import_dataset_tabular_bigquery_sample.create_and_import_dataset_tabular_bigquery_sample(
2525
project=constants.PROJECT,
2626
location=constants.LOCATION,
27-
bq_source=constants.BIGQUERY_SOURCE,
27+
bigquery_source=constants.BIGQUERY_SOURCE,
2828
display_name=constants.DISPLAY_NAME,
2929
)
3030

3131
mock_sdk_init.assert_called_once_with(
3232
project=constants.PROJECT, location=constants.LOCATION
3333
)
3434
mock_create_tabular_dataset.assert_called_once_with(
35-
display_name=constants.DISPLAY_NAME, bq_source=constants.BIGQUERY_SOURCE,
35+
display_name=constants.DISPLAY_NAME,
36+
bigquery_source=constants.BIGQUERY_SOURCE,
3637
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.cloud import aiplatform
17+
18+
19+
# [START aiplatform_sdk_create_and_import_dataset_time_series_bigquery_sample]
20+
def create_and_import_dataset_time_series_bigquery_sample(
21+
display_name: str,
22+
project: str,
23+
location: str,
24+
bigquery_source: str,
25+
):
26+
27+
aiplatform.init(project=project, location=location)
28+
29+
dataset = aiplatform.TimeSeriesDataset.create(
30+
display_name=display_name,
31+
bigquery_source=bigquery_source,
32+
)
33+
34+
dataset.wait()
35+
36+
print(f'\tDataset: "{dataset.display_name}"')
37+
print(f'\tname: "{dataset.resource_name}"')
38+
39+
40+
# [END aiplatform_sdk_create_and_import_dataset_time_series_bigquery_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_time_series_bigquery_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_time_series_bigquery_sample(
21+
mock_sdk_init, mock_create_time_series_dataset
22+
):
23+
24+
create_and_import_dataset_time_series_bigquery_sample.create_and_import_dataset_time_series_bigquery_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
bigquery_source=constants.BIGQUERY_SOURCE,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_time_series_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME,
36+
bigquery_source=constants.BIGQUERY_SOURCE,
37+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_create_and_import_dataset_time_series_gcs_sample]
21+
def create_and_import_dataset_time_series_gcs_sample(
22+
display_name: str,
23+
project: str,
24+
location: str,
25+
gcs_source: Union[str, List[str]],
26+
):
27+
28+
aiplatform.init(project=project, location=location)
29+
30+
dataset = aiplatform.TimeSeriesDataset.create(
31+
display_name=display_name,
32+
gcs_source=gcs_source,
33+
)
34+
35+
dataset.wait()
36+
37+
print(f'\tDataset: "{dataset.display_name}"')
38+
print(f'\tname: "{dataset.resource_name}"')
39+
40+
41+
# [END aiplatform_sdk_create_and_import_dataset_time_series_gcs_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_time_series_gcs_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_time_series_gcs_sample(
21+
mock_sdk_init, mock_create_time_series_dataset
22+
):
23+
24+
create_and_import_dataset_time_series_gcs_sample.create_and_import_dataset_time_series_gcs_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
gcs_source=constants.GCS_SOURCES,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_time_series_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME,
36+
gcs_source=constants.GCS_SOURCES,
37+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_create_batch_prediction_job_bigquery_sample]
19+
def create_batch_prediction_job_bigquery_sample(
20+
project: str,
21+
location: str,
22+
model_resource_name: str,
23+
job_display_name: str,
24+
bigquery_source: str,
25+
bigquery_destination_prefix: str,
26+
sync: bool = True,
27+
):
28+
aiplatform.init(project=project, location=location)
29+
30+
my_model = aiplatform.Model(model_resource_name)
31+
32+
batch_prediction_job = my_model.batch_predict(
33+
job_display_name=job_display_name,
34+
bigquery_source=bigquery_source,
35+
bigquery_destination_prefix=bigquery_destination_prefix,
36+
sync=sync,
37+
)
38+
39+
batch_prediction_job.wait()
40+
41+
print(batch_prediction_job.display_name)
42+
print(batch_prediction_job.resource_name)
43+
print(batch_prediction_job.state)
44+
return batch_prediction_job
45+
46+
47+
# [END aiplatform_sdk_create_batch_prediction_job_bigquery_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_batch_prediction_job_bigquery_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_batch_prediction_job_bigquery_sample(
21+
mock_sdk_init, mock_model, mock_init_model, mock_batch_predict_model
22+
):
23+
24+
create_batch_prediction_job_bigquery_sample.create_batch_prediction_job_bigquery_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
model_resource_name=constants.MODEL_NAME,
28+
job_display_name=constants.DISPLAY_NAME,
29+
bigquery_source=constants.BIGQUERY_SOURCE,
30+
bigquery_destination_prefix=constants.BIGQUERY_DESTINATION_PREFIX,
31+
)
32+
33+
mock_sdk_init.assert_called_once_with(
34+
project=constants.PROJECT, location=constants.LOCATION
35+
)
36+
mock_init_model.assert_called_once_with(constants.MODEL_NAME)
37+
mock_batch_predict_model.assert_called_once_with(
38+
job_display_name=constants.DISPLAY_NAME,
39+
bigquery_source=constants.BIGQUERY_SOURCE,
40+
bigquery_destination_prefix=constants.BIGQUERY_DESTINATION_PREFIX,
41+
sync=True,
42+
)

0 commit comments

Comments
 (0)