Commit 50bdb01

Feat: add batch_size kwarg for batch prediction jobs (#1194)
* Add batch_size kwarg for batch prediction jobs
* Fix errors: update the copyright year, change the order of the argument, fix the syntax error
* fix: change description layout
1 parent 7c70484 commit 50bdb01

File tree: 4 files changed, +70 -36 lines

google/cloud/aiplatform/jobs.py (+18 -2)

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,6 +40,7 @@
     job_state as gca_job_state,
     hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
     machine_resources as gca_machine_resources_compat,
+    manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
     study as gca_study_compat,
 )
 from google.cloud.aiplatform.constants import base as constants
@@ -376,6 +377,7 @@ def create(
         encryption_spec_key_name: Optional[str] = None,
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
+        batch_size: Optional[int] = None,
     ) -> "BatchPredictionJob":
         """Create a batch prediction job.

@@ -534,6 +536,13 @@ def create(
                 be immediately returned and synced when the Future has completed.
             create_request_timeout (float):
                 Optional. The timeout for the create request in seconds.
+            batch_size (int):
+                Optional. The number of records (e.g. instances) handed to a machine
+                replica in each batch of the operation. Consider the machine type and
+                the size of a single record when setting this value: a higher value
+                speeds up the batch operation, but if it is too high a whole batch may
+                not fit in the machine's memory and the operation will fail.
+                The default value is 64.
         Returns:
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
@@ -647,7 +656,14 @@ def create(

             gapic_batch_prediction_job.dedicated_resources = dedicated_resources

-            gapic_batch_prediction_job.manual_batch_tuning_parameters = None
+            manual_batch_tuning_parameters = (
+                gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters()
+            )
+            manual_batch_tuning_parameters.batch_size = batch_size
+
+            gapic_batch_prediction_job.manual_batch_tuning_parameters = (
+                manual_batch_tuning_parameters
+            )

         # User Labels
         gapic_batch_prediction_job.labels = labels
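
For orientation (not part of the diff), a minimal sketch of the new kwarg from a caller's side, assuming the post-commit SDK; the project, model resource name, bucket URIs, and machine type are placeholder values.

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/region

# Placeholder resource names and URIs, for illustration only.
job = aiplatform.BatchPredictionJob.create(
    job_display_name="batch-predict-with-batch-size",
    model_name="projects/my-project/locations/us-central1/models/1234567890",
    instances_format="jsonl",
    gcs_source="gs://my-bucket/inputs/*.jsonl",
    gcs_destination_prefix="gs://my-bucket/outputs/",
    machine_type="n1-standard-4",
    batch_size=16,  # new kwarg: records sent per batch to each replica (service default is 64)
)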

google/cloud/aiplatform/models.py (+10 -1)

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -2284,6 +2284,7 @@ def batch_predict(
         encryption_spec_key_name: Optional[str] = None,
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
+        batch_size: Optional[int] = None,
     ) -> jobs.BatchPredictionJob:
         """Creates a batch prediction job using this Model and outputs
         prediction results to the provided destination prefix in the specified
@@ -2442,6 +2443,13 @@ def batch_predict(
                 Overrides encryption_spec_key_name set in aiplatform.init.
             create_request_timeout (float):
                 Optional. The timeout for the create request in seconds.
+            batch_size (int):
+                Optional. The number of records (e.g. instances) handed to a machine
+                replica in each batch of the operation. Consider the machine type and
+                the size of a single record when setting this value: a higher value
+                speeds up the batch operation, but if it is too high a whole batch may
+                not fit in the machine's memory and the operation will fail.
+                The default value is 64.
         Returns:
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
@@ -2462,6 +2470,7 @@ def batch_predict(
             accelerator_count=accelerator_count,
             starting_replica_count=starting_replica_count,
             max_replica_count=max_replica_count,
+            batch_size=batch_size,
             generate_explanation=generate_explanation,
             explanation_metadata=explanation_metadata,
             explanation_parameters=explanation_parameters,
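
Model.batch_predict forwards the same kwarg to BatchPredictionJob.create, so the model-level call can set it too. Another hedged sketch, again with placeholder resource names:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/region

model = aiplatform.Model("projects/my-project/locations/us-central1/models/1234567890")

# batch_size is passed through to the job and ends up in
# manual_batch_tuning_parameters on the GAPIC BatchPredictionJob.
job = model.batch_predict(
    job_display_name="batch-predict-with-batch-size",
    gcs_source="gs://my-bucket/inputs/*.jsonl",
    gcs_destination_prefix="gs://my-bucket/outputs/",
    machine_type="n1-standard-4",
    batch_size=16,
)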

tests/unit/aiplatform/test_jobs.py (+6 -0)

@@ -37,6 +37,7 @@
     io as gca_io_compat,
     job_state as gca_job_state_compat,
     machine_resources as gca_machine_resources_compat,
+    manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
 )

 from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
@@ -132,6 +133,7 @@
 _TEST_ACCELERATOR_COUNT = 2
 _TEST_STARTING_REPLICA_COUNT = 2
 _TEST_MAX_REPLICA_COUNT = 12
+_TEST_BATCH_SIZE = 16

 _TEST_LABEL = {"team": "experimentation", "trial_id": "x435"}

@@ -725,6 +727,7 @@ def test_batch_predict_with_all_args(
             credentials=creds,
             sync=sync,
             create_request_timeout=None,
+            batch_size=_TEST_BATCH_SIZE,
         )

         batch_prediction_job.wait_for_resource_creation()
@@ -756,6 +759,9 @@ def test_batch_predict_with_all_args(
                 starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
                 max_replica_count=_TEST_MAX_REPLICA_COUNT,
             ),
+            manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(
+                batch_size=_TEST_BATCH_SIZE
+            ),
             generate_explanation=True,
             explanation_spec=gca_explanation_compat.ExplanationSpec(
                 metadata=_TEST_EXPLANATION_METADATA,

tests/unit/aiplatform/test_models.py (+36 -33)

@@ -49,6 +49,7 @@
     env_var as gca_env_var,
     explanation as gca_explanation,
     machine_resources as gca_machine_resources,
+    manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
     model_service as gca_model_service,
     model_evaluation as gca_model_evaluation,
     endpoint_service as gca_endpoint_service,
@@ -86,6 +87,8 @@
 _TEST_STARTING_REPLICA_COUNT = 2
 _TEST_MAX_REPLICA_COUNT = 12

+_TEST_BATCH_SIZE = 16
+
 _TEST_PIPELINE_RESOURCE_NAME = (
     "projects/my-project/locations/us-central1/trainingPipeline/12345"
 )
@@ -1402,47 +1405,47 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
             encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
             sync=sync,
             create_request_timeout=None,
+            batch_size=_TEST_BATCH_SIZE,
         )

         if not sync:
             batch_prediction_job.wait()

         # Construct expected request
-        expected_gapic_batch_prediction_job = (
-            gca_batch_prediction_job.BatchPredictionJob(
-                display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
-                model=model_service_client.ModelServiceClient.model_path(
-                    _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
-                ),
-                input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
-                    instances_format="jsonl",
-                    gcs_source=gca_io.GcsSource(
-                        uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
-                    ),
-                ),
-                output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
-                    gcs_destination=gca_io.GcsDestination(
-                        output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
-                    ),
-                    predictions_format="csv",
-                ),
-                dedicated_resources=gca_machine_resources.BatchDedicatedResources(
-                    machine_spec=gca_machine_resources.MachineSpec(
-                        machine_type=_TEST_MACHINE_TYPE,
-                        accelerator_type=_TEST_ACCELERATOR_TYPE,
-                        accelerator_count=_TEST_ACCELERATOR_COUNT,
-                    ),
-                    starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
-                    max_replica_count=_TEST_MAX_REPLICA_COUNT,
-                ),
-                generate_explanation=True,
-                explanation_spec=gca_explanation.ExplanationSpec(
-                    metadata=_TEST_EXPLANATION_METADATA,
-                    parameters=_TEST_EXPLANATION_PARAMETERS,
-                ),
-                labels=_TEST_LABEL,
-                encryption_spec=_TEST_ENCRYPTION_SPEC,
-            )
+        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
+            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
+            model=model_service_client.ModelServiceClient.model_path(
+                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
+            ),
+            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
+                instances_format="jsonl",
+                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
+            ),
+            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
+                gcs_destination=gca_io.GcsDestination(
+                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
+                ),
+                predictions_format="csv",
+            ),
+            dedicated_resources=gca_machine_resources.BatchDedicatedResources(
+                machine_spec=gca_machine_resources.MachineSpec(
+                    machine_type=_TEST_MACHINE_TYPE,
+                    accelerator_type=_TEST_ACCELERATOR_TYPE,
+                    accelerator_count=_TEST_ACCELERATOR_COUNT,
+                ),
+                starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
+                max_replica_count=_TEST_MAX_REPLICA_COUNT,
+            ),
+            manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(
+                batch_size=_TEST_BATCH_SIZE
+            ),
+            generate_explanation=True,
+            explanation_spec=gca_explanation.ExplanationSpec(
+                metadata=_TEST_EXPLANATION_METADATA,
+                parameters=_TEST_EXPLANATION_PARAMETERS,
+            ),
+            labels=_TEST_LABEL,
+            encryption_spec=_TEST_ENCRYPTION_SPEC,
         )

         create_batch_prediction_job_mock.assert_called_once_with(

0 commit comments