|
63 | 63 | _TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
|
64 | 64 | _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
|
65 | 65 | _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"
|
| 66 | +_TEST_RESERVED_IP_RANGES = ["vertex-ai-ip-range"] |
66 | 67 |
|
67 | 68 | _TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
|
68 | 69 | _TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
|
@@ -231,6 +232,7 @@ def mock_pipeline_service_create():
|
231 | 232 | create_time=_TEST_PIPELINE_CREATE_TIME,
|
232 | 233 | service_account=_TEST_SERVICE_ACCOUNT,
|
233 | 234 | network=_TEST_NETWORK,
|
| 235 | + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, |
234 | 236 | )
|
235 | 237 | yield mock_create_pipeline_job
|
236 | 238 |
|
@@ -267,6 +269,7 @@ def make_pipeline_job(state):
|
267 | 269 | create_time=_TEST_PIPELINE_CREATE_TIME,
|
268 | 270 | service_account=_TEST_SERVICE_ACCOUNT,
|
269 | 271 | network=_TEST_NETWORK,
|
| 272 | + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, |
270 | 273 | job_detail=gca_pipeline_job.PipelineJobDetail(
|
271 | 274 | pipeline_run_context=gca_context.Context(
|
272 | 275 | name=_TEST_PIPELINE_JOB_NAME,
|
@@ -548,6 +551,90 @@ def test_run_call_pipeline_service_create(
|
548 | 551 | gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
|
549 | 552 | )
|
550 | 553 |
|
| 554 | + @pytest.mark.parametrize( |
| 555 | + "job_spec", |
| 556 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 557 | + ) |
| 558 | + @pytest.mark.parametrize("sync", [True, False]) |
| 559 | + def test_run_call_pipeline_service_run_with_reserved_ip_ranges( |
| 560 | + self, |
| 561 | + mock_pipeline_service_create, |
| 562 | + mock_pipeline_service_get, |
| 563 | + mock_pipeline_bucket_exists, |
| 564 | + job_spec, |
| 565 | + mock_load_yaml_and_json, |
| 566 | + sync, |
| 567 | + ): |
| 568 | + import yaml |
| 569 | + |
| 570 | + aiplatform.init( |
| 571 | + project=_TEST_PROJECT, |
| 572 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 573 | + location=_TEST_LOCATION, |
| 574 | + credentials=_TEST_CREDENTIALS, |
| 575 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 576 | + network=_TEST_NETWORK, |
| 577 | + ) |
| 578 | + |
| 579 | + job = pipeline_jobs.PipelineJob( |
| 580 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 581 | + template_path=_TEST_TEMPLATE_PATH, |
| 582 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 583 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 584 | + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, |
| 585 | + enable_caching=True, |
| 586 | + ) |
| 587 | + |
| 588 | + job.run( |
| 589 | + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, |
| 590 | + sync=sync, |
| 591 | + create_request_timeout=None, |
| 592 | + ) |
| 593 | + |
| 594 | + if not sync: |
| 595 | + job.wait() |
| 596 | + |
| 597 | + expected_runtime_config_dict = { |
| 598 | + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, |
| 599 | + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, |
| 600 | + "inputArtifacts": {"vertex_model": {"artifactId": "456"}}, |
| 601 | + } |
| 602 | + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb |
| 603 | + json_format.ParseDict(expected_runtime_config_dict, runtime_config) |
| 604 | + |
| 605 | + job_spec = yaml.safe_load(job_spec) |
| 606 | + pipeline_spec = job_spec.get("pipelineSpec") or job_spec |
| 607 | + |
| 608 | + # Construct expected request |
| 609 | + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( |
| 610 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 611 | + pipeline_spec={ |
| 612 | + "components": {}, |
| 613 | + "pipelineInfo": pipeline_spec["pipelineInfo"], |
| 614 | + "root": pipeline_spec["root"], |
| 615 | + "schemaVersion": "2.1.0", |
| 616 | + }, |
| 617 | + runtime_config=runtime_config, |
| 618 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 619 | + network=_TEST_NETWORK, |
| 620 | + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, |
| 621 | + ) |
| 622 | + |
| 623 | + mock_pipeline_service_create.assert_called_once_with( |
| 624 | + parent=_TEST_PARENT, |
| 625 | + pipeline_job=expected_gapic_pipeline_job, |
| 626 | + pipeline_job_id=_TEST_PIPELINE_JOB_ID, |
| 627 | + timeout=None, |
| 628 | + ) |
| 629 | + |
| 630 | + mock_pipeline_service_get.assert_called_with( |
| 631 | + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY |
| 632 | + ) |
| 633 | + |
| 634 | + assert job._gca_resource == make_pipeline_job( |
| 635 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 636 | + ) |
| 637 | + |
551 | 638 | @pytest.mark.parametrize(
|
552 | 639 | "job_spec",
|
553 | 640 | [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
|
|
0 commit comments