|
49 | 49 | env_var as gca_env_var,
|
50 | 50 | explanation as gca_explanation,
|
51 | 51 | machine_resources as gca_machine_resources,
|
| 52 | + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, |
52 | 53 | model_service as gca_model_service,
|
53 | 54 | model_evaluation as gca_model_evaluation,
|
54 | 55 | endpoint_service as gca_endpoint_service,
|
|
86 | 87 | _TEST_STARTING_REPLICA_COUNT = 2
|
87 | 88 | _TEST_MAX_REPLICA_COUNT = 12
|
88 | 89 |
|
| 90 | +_TEST_BATCH_SIZE = 16 |
| 91 | + |
89 | 92 | _TEST_PIPELINE_RESOURCE_NAME = (
|
90 | 93 | "projects/my-project/locations/us-central1/trainingPipeline/12345"
|
91 | 94 | )
|
@@ -1402,47 +1405,47 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
|
1402 | 1405 | encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
|
1403 | 1406 | sync=sync,
|
1404 | 1407 | create_request_timeout=None,
|
| 1408 | + batch_size=_TEST_BATCH_SIZE, |
1405 | 1409 | )
|
1406 | 1410 |
|
1407 | 1411 | if not sync:
|
1408 | 1412 | batch_prediction_job.wait()
|
1409 | 1413 |
|
1410 | 1414 | # Construct expected request
|
1411 |
| - expected_gapic_batch_prediction_job = ( |
1412 |
| - gca_batch_prediction_job.BatchPredictionJob( |
1413 |
| - display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, |
1414 |
| - model=model_service_client.ModelServiceClient.model_path( |
1415 |
| - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID |
1416 |
| - ), |
1417 |
| - input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( |
1418 |
| - instances_format="jsonl", |
1419 |
| - gcs_source=gca_io.GcsSource( |
1420 |
| - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] |
1421 |
| - ), |
1422 |
| - ), |
1423 |
| - output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( |
1424 |
| - gcs_destination=gca_io.GcsDestination( |
1425 |
| - output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX |
1426 |
| - ), |
1427 |
| - predictions_format="csv", |
1428 |
| - ), |
1429 |
| - dedicated_resources=gca_machine_resources.BatchDedicatedResources( |
1430 |
| - machine_spec=gca_machine_resources.MachineSpec( |
1431 |
| - machine_type=_TEST_MACHINE_TYPE, |
1432 |
| - accelerator_type=_TEST_ACCELERATOR_TYPE, |
1433 |
| - accelerator_count=_TEST_ACCELERATOR_COUNT, |
1434 |
| - ), |
1435 |
| - starting_replica_count=_TEST_STARTING_REPLICA_COUNT, |
1436 |
| - max_replica_count=_TEST_MAX_REPLICA_COUNT, |
| 1415 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( |
| 1416 | + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, |
| 1417 | + model=model_service_client.ModelServiceClient.model_path( |
| 1418 | + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID |
| 1419 | + ), |
| 1420 | + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( |
| 1421 | + instances_format="jsonl", |
| 1422 | + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), |
| 1423 | + ), |
| 1424 | + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( |
| 1425 | + gcs_destination=gca_io.GcsDestination( |
| 1426 | + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX |
1437 | 1427 | ),
|
1438 |
| - generate_explanation=True, |
1439 |
| - explanation_spec=gca_explanation.ExplanationSpec( |
1440 |
| - metadata=_TEST_EXPLANATION_METADATA, |
1441 |
| - parameters=_TEST_EXPLANATION_PARAMETERS, |
| 1428 | + predictions_format="csv", |
| 1429 | + ), |
| 1430 | + dedicated_resources=gca_machine_resources.BatchDedicatedResources( |
| 1431 | + machine_spec=gca_machine_resources.MachineSpec( |
| 1432 | + machine_type=_TEST_MACHINE_TYPE, |
| 1433 | + accelerator_type=_TEST_ACCELERATOR_TYPE, |
| 1434 | + accelerator_count=_TEST_ACCELERATOR_COUNT, |
1442 | 1435 | ),
|
1443 |
| - labels=_TEST_LABEL, |
1444 |
| - encryption_spec=_TEST_ENCRYPTION_SPEC, |
1445 |
| - ) |
| 1436 | + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, |
| 1437 | + max_replica_count=_TEST_MAX_REPLICA_COUNT, |
| 1438 | + ), |
| 1439 | + manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters( |
| 1440 | + batch_size=_TEST_BATCH_SIZE |
| 1441 | + ), |
| 1442 | + generate_explanation=True, |
| 1443 | + explanation_spec=gca_explanation.ExplanationSpec( |
| 1444 | + metadata=_TEST_EXPLANATION_METADATA, |
| 1445 | + parameters=_TEST_EXPLANATION_PARAMETERS, |
| 1446 | + ), |
| 1447 | + labels=_TEST_LABEL, |
| 1448 | + encryption_spec=_TEST_ENCRYPTION_SPEC, |
1446 | 1449 | )
|
1447 | 1450 |
|
1448 | 1451 | create_batch_prediction_job_mock.assert_called_once_with(
|
|
0 commit comments