|
29 | 29 | from google.cloud import aiplatform
|
30 | 30 | from google.cloud.aiplatform import base
|
31 | 31 | from google.cloud.aiplatform import initializer
|
| 32 | +from google.cloud.aiplatform_v1 import Context as GapicContext |
| 33 | +from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore |
| 34 | +from google.cloud.aiplatform.metadata import constants |
| 35 | +from google.cloud.aiplatform_v1 import MetadataServiceClient |
32 | 36 | from google.cloud.aiplatform import pipeline_jobs
|
33 | 37 | from google.cloud.aiplatform.compat.types import pipeline_failure_policy
|
34 | 38 | from google.cloud import storage
|
|
188 | 192 | )
|
189 | 193 | _TEST_PIPELINE_CREATE_TIME = datetime.now()
|
190 | 194 |
|
| 195 | +# experiments |
| 196 | +_TEST_EXPERIMENT = "test-experiment" |
| 197 | + |
| 198 | +_TEST_METADATASTORE = ( |
| 199 | + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" |
| 200 | +) |
| 201 | +_TEST_CONTEXT_ID = _TEST_EXPERIMENT |
| 202 | +_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_CONTEXT_ID}" |
| 203 | + |
| 204 | +_EXPERIMENT_MOCK = GapicContext( |
| 205 | + name=_TEST_CONTEXT_NAME, |
| 206 | + schema_title=constants.SYSTEM_EXPERIMENT, |
| 207 | + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], |
| 208 | + metadata={**constants.EXPERIMENT_METADATA}, |
| 209 | +) |
| 210 | + |
191 | 211 |
|
192 | 212 | @pytest.fixture
|
193 | 213 | def mock_pipeline_service_create():
|
@@ -303,6 +323,90 @@ def mock_request_urlopen(job_spec):
|
303 | 323 | yield mock_urlopen
|
304 | 324 |
|
305 | 325 |
|
| 326 | +# experiment mocks |
| 327 | +@pytest.fixture |
| 328 | +def get_metadata_store_mock(): |
| 329 | + with patch.object( |
| 330 | + MetadataServiceClient, "get_metadata_store" |
| 331 | + ) as get_metadata_store_mock: |
| 332 | + get_metadata_store_mock.return_value = GapicMetadataStore( |
| 333 | + name=_TEST_METADATASTORE, |
| 334 | + ) |
| 335 | + yield get_metadata_store_mock |
| 336 | + |
| 337 | + |
| 338 | +@pytest.fixture |
| 339 | +def get_experiment_mock(): |
| 340 | + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: |
| 341 | + get_context_mock.return_value = _EXPERIMENT_MOCK |
| 342 | + yield get_context_mock |
| 343 | + |
| 344 | + |
| 345 | +@pytest.fixture |
| 346 | +def add_context_children_mock(): |
| 347 | + with patch.object( |
| 348 | + MetadataServiceClient, "add_context_children" |
| 349 | + ) as add_context_children_mock: |
| 350 | + yield add_context_children_mock |
| 351 | + |
| 352 | + |
| 353 | +@pytest.fixture |
| 354 | +def list_contexts_mock(): |
| 355 | + with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock: |
| 356 | + list_contexts_mock.return_value = [_EXPERIMENT_MOCK] |
| 357 | + yield list_contexts_mock |
| 358 | + |
| 359 | + |
| 360 | +@pytest.fixture |
| 361 | +def create_experiment_run_context_mock(): |
| 362 | + with patch.object(MetadataServiceClient, "create_context") as create_context_mock: |
| 363 | + create_context_mock.side_effect = [_EXPERIMENT_MOCK] |
| 364 | + yield create_context_mock |
| 365 | + |
| 366 | + |
| 367 | +def make_pipeline_job_with_experiment(state): |
| 368 | + return gca_pipeline_job.PipelineJob( |
| 369 | + name=_TEST_PIPELINE_JOB_NAME, |
| 370 | + state=state, |
| 371 | + create_time=_TEST_PIPELINE_CREATE_TIME, |
| 372 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 373 | + network=_TEST_NETWORK, |
| 374 | + job_detail=gca_pipeline_job.PipelineJobDetail( |
| 375 | + pipeline_run_context=gca_context.Context( |
| 376 | + name=_TEST_PIPELINE_JOB_NAME, |
| 377 | + parent_contexts=[_TEST_CONTEXT_NAME], |
| 378 | + ), |
| 379 | + ), |
| 380 | + ) |
| 381 | + |
| 382 | + |
| 383 | +@pytest.fixture |
| 384 | +def mock_create_pipeline_job_with_experiment(): |
| 385 | + with mock.patch.object( |
| 386 | + pipeline_service_client.PipelineServiceClient, "create_pipeline_job" |
| 387 | + ) as mock_pipeline_with_experiment: |
| 388 | + mock_pipeline_with_experiment.return_value = make_pipeline_job_with_experiment( |
| 389 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 390 | + ) |
| 391 | + yield mock_pipeline_with_experiment |
| 392 | + |
| 393 | + |
| 394 | +@pytest.fixture |
| 395 | +def mock_get_pipeline_job_with_experiment(): |
| 396 | + with mock.patch.object( |
| 397 | + pipeline_service_client.PipelineServiceClient, "get_pipeline_job" |
| 398 | + ) as mock_pipeline_with_experiment: |
| 399 | + mock_pipeline_with_experiment.side_effect = [ |
| 400 | + make_pipeline_job_with_experiment( |
| 401 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING |
| 402 | + ), |
| 403 | + make_pipeline_job_with_experiment( |
| 404 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 405 | + ), |
| 406 | + ] |
| 407 | + yield mock_pipeline_with_experiment |
| 408 | + |
| 409 | + |
306 | 410 | @pytest.mark.usefixtures("google_auth_mock")
|
307 | 411 | class TestPipelineJob:
|
308 | 412 | def setup_method(self):
|
@@ -1384,3 +1488,90 @@ def test_clone_pipeline_job_with_all_args(
|
1384 | 1488 | assert cloned._gca_resource == make_pipeline_job(
|
1385 | 1489 | gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
|
1386 | 1490 | )
|
| 1491 | + |
| 1492 | + @pytest.mark.parametrize( |
| 1493 | + "job_spec", |
| 1494 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 1495 | + ) |
| 1496 | + def test_get_associated_experiment_from_pipeline_returns_none_without_experiment( |
| 1497 | + self, |
| 1498 | + mock_pipeline_service_create, |
| 1499 | + mock_pipeline_service_get, |
| 1500 | + job_spec, |
| 1501 | + mock_load_yaml_and_json, |
| 1502 | + ): |
| 1503 | + aiplatform.init( |
| 1504 | + project=_TEST_PROJECT, |
| 1505 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 1506 | + location=_TEST_LOCATION, |
| 1507 | + credentials=_TEST_CREDENTIALS, |
| 1508 | + ) |
| 1509 | + |
| 1510 | + job = pipeline_jobs.PipelineJob( |
| 1511 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 1512 | + template_path=_TEST_TEMPLATE_PATH, |
| 1513 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1514 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 1515 | + enable_caching=True, |
| 1516 | + ) |
| 1517 | + |
| 1518 | + job.submit( |
| 1519 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1520 | + network=_TEST_NETWORK, |
| 1521 | + create_request_timeout=None, |
| 1522 | + ) |
| 1523 | + |
| 1524 | + job.wait() |
| 1525 | + |
| 1526 | + test_experiment = job.get_associated_experiment() |
| 1527 | + |
| 1528 | + assert test_experiment is None |
| 1529 | + |
| 1530 | + @pytest.mark.parametrize( |
| 1531 | + "job_spec", |
| 1532 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 1533 | + ) |
| 1534 | + def test_get_associated_experiment_from_pipeline_returns_experiment( |
| 1535 | + self, |
| 1536 | + job_spec, |
| 1537 | + mock_load_yaml_and_json, |
| 1538 | + add_context_children_mock, |
| 1539 | + get_experiment_mock, |
| 1540 | + create_experiment_run_context_mock, |
| 1541 | + get_metadata_store_mock, |
| 1542 | + mock_create_pipeline_job_with_experiment, |
| 1543 | + mock_get_pipeline_job_with_experiment, |
| 1544 | + ): |
| 1545 | + aiplatform.init( |
| 1546 | + project=_TEST_PROJECT, |
| 1547 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 1548 | + location=_TEST_LOCATION, |
| 1549 | + credentials=_TEST_CREDENTIALS, |
| 1550 | + ) |
| 1551 | + |
| 1552 | + test_experiment = aiplatform.Experiment(_TEST_EXPERIMENT) |
| 1553 | + |
| 1554 | + job = pipeline_jobs.PipelineJob( |
| 1555 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 1556 | + template_path=_TEST_TEMPLATE_PATH, |
| 1557 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1558 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 1559 | + enable_caching=True, |
| 1560 | + ) |
| 1561 | + |
| 1562 | + assert get_experiment_mock.call_count == 1 |
| 1563 | + |
| 1564 | + job.submit( |
| 1565 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1566 | + network=_TEST_NETWORK, |
| 1567 | + create_request_timeout=None, |
| 1568 | + experiment=test_experiment, |
| 1569 | + ) |
| 1570 | + |
| 1571 | + job.wait() |
| 1572 | + |
| 1573 | + associated_experiment = job.get_associated_experiment() |
| 1574 | + |
| 1575 | + assert associated_experiment.resource_name == _TEST_CONTEXT_NAME |
| 1576 | + |
| 1577 | + assert add_context_children_mock.call_count == 1 |
0 commit comments