Skip to content

Commit e9f2c3c

Browse files
authored
feat: add get_associated_experiment method to pipeline_jobs (#1476)
* feat: add get_associated_experiment method to pipeline_jobs
* updates from reviewer feedback
* clean up system test
* re-add check for experiment schema title
1 parent 23a8a27 commit e9f2c3c

File tree

3 files changed

+237
-0
lines changed

3 files changed

+237
-0
lines changed

google/cloud/aiplatform/pipeline_jobs.py

+42
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, Dict, List, Optional, Union
2323

2424
from google.auth import credentials as auth_credentials
25+
from google.cloud import aiplatform
2526
from google.cloud.aiplatform import base
2627
from google.cloud.aiplatform import initializer
2728
from google.cloud.aiplatform import utils
@@ -770,3 +771,44 @@ def clone(
770771
)
771772

772773
return cloned
774+
775+
def get_associated_experiment(self) -> Optional["aiplatform.Experiment"]:
    """Gets the aiplatform.Experiment associated with this PipelineJob.

    Returns:
        An aiplatform.Experiment resource, or None if this PipelineJob is
        not associated with an experiment.
    """
    pipeline_parent_contexts = (
        self._gca_resource.job_detail.pipeline_run_context.parent_contexts
    )

    # Resolve every parent context except the pipeline's own context; the
    # remaining contexts are candidates for the associated experiment.
    pipeline_experiment_resources = [
        context._Context(resource_name=c)._gca_resource
        for c in pipeline_parent_contexts
        if c != self._gca_resource.job_detail.pipeline_context.name
    ]

    # Keep only contexts whose schema marks them as an Experiment.
    pipeline_experiment_resource_names = [
        c.name
        for c in pipeline_experiment_resources
        if c.schema_title == metadata_constants.SYSTEM_EXPERIMENT
    ]

    if len(pipeline_experiment_resource_names) > 1:
        # Fix: ', '.join(list) — lists have no .join() method, so the
        # original f-string raised AttributeError on this path.
        _LOGGER.warning(
            "There is more than one Experiment associated with this pipeline. "
            f"The following experiments were found: {', '.join(pipeline_experiment_resource_names)}\n"
            f"Returning only the following experiment: {pipeline_experiment_resource_names[0]}"
        )

    if len(pipeline_experiment_resource_names) >= 1:
        return experiment_resources.Experiment(
            pipeline_experiment_resource_names[0],
            project=self.project,
            location=self.location,
            credentials=self.credentials,
        )
    # Implicitly returns None when no experiment context was found.

tests/system/aiplatform/test_experiments.py

+4
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ def pipeline(learning_rate: float, dropout_rate: float):
298298

299299
job.wait()
300300

301+
test_experiment = job.get_associated_experiment()
302+
303+
assert test_experiment.name == self._experiment_name
304+
301305
def test_get_experiments_df(self):
302306
aiplatform.init(
303307
project=e2e_base._PROJECT,

tests/unit/aiplatform/test_pipeline_jobs.py

+191
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from google.cloud import aiplatform
3030
from google.cloud.aiplatform import base
3131
from google.cloud.aiplatform import initializer
32+
from google.cloud.aiplatform_v1 import Context as GapicContext
33+
from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore
34+
from google.cloud.aiplatform.metadata import constants
35+
from google.cloud.aiplatform_v1 import MetadataServiceClient
3236
from google.cloud.aiplatform import pipeline_jobs
3337
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
3438
from google.cloud import storage
@@ -188,6 +192,22 @@
188192
)
189193
_TEST_PIPELINE_CREATE_TIME = datetime.now()
190194

195+
# Experiment-related test constants.
_TEST_EXPERIMENT = "test-experiment"

# The default metadata store that owns the experiment context.
_TEST_METADATASTORE = (
    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default"
)

_TEST_CONTEXT_ID = _TEST_EXPERIMENT
_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_CONTEXT_ID}"

# GAPIC Context resource representing the mocked experiment.
_EXPERIMENT_MOCK = GapicContext(
    name=_TEST_CONTEXT_NAME,
    schema_title=constants.SYSTEM_EXPERIMENT,
    schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT],
    metadata=dict(constants.EXPERIMENT_METADATA),
)
210+
191211

192212
@pytest.fixture
193213
def mock_pipeline_service_create():
@@ -303,6 +323,90 @@ def mock_request_urlopen(job_spec):
303323
yield mock_urlopen
304324

305325

326+
# --- experiment mocks ---
@pytest.fixture
def get_metadata_store_mock():
    """Patches get_metadata_store to return the default metadata store."""
    with patch.object(MetadataServiceClient, "get_metadata_store") as mock_get_store:
        mock_get_store.return_value = GapicMetadataStore(name=_TEST_METADATASTORE)
        yield mock_get_store
336+
337+
338+
@pytest.fixture
def get_experiment_mock():
    """Patches get_context so experiment lookups return the mocked context."""
    with patch.object(MetadataServiceClient, "get_context") as mock_get_context:
        mock_get_context.return_value = _EXPERIMENT_MOCK
        yield mock_get_context
343+
344+
345+
@pytest.fixture
def add_context_children_mock():
    """Patches add_context_children; tests only inspect the call count."""
    with patch.object(
        MetadataServiceClient, "add_context_children"
    ) as mock_add_children:
        yield mock_add_children
351+
352+
353+
@pytest.fixture
def list_contexts_mock():
    """Patches list_contexts to yield only the mocked experiment context."""
    with patch.object(MetadataServiceClient, "list_contexts") as mock_list_contexts:
        mock_list_contexts.return_value = [_EXPERIMENT_MOCK]
        yield mock_list_contexts
358+
359+
360+
@pytest.fixture
def create_experiment_run_context_mock():
    """Patches create_context; side_effect so only the first call succeeds."""
    with patch.object(MetadataServiceClient, "create_context") as mock_create_context:
        mock_create_context.side_effect = [_EXPERIMENT_MOCK]
        yield mock_create_context
365+
366+
367+
def make_pipeline_job_with_experiment(state):
    """Builds a GAPIC PipelineJob whose run context lists the experiment as a parent."""
    run_context = gca_context.Context(
        name=_TEST_PIPELINE_JOB_NAME,
        parent_contexts=[_TEST_CONTEXT_NAME],
    )
    return gca_pipeline_job.PipelineJob(
        name=_TEST_PIPELINE_JOB_NAME,
        state=state,
        create_time=_TEST_PIPELINE_CREATE_TIME,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        job_detail=gca_pipeline_job.PipelineJobDetail(
            pipeline_run_context=run_context,
        ),
    )
381+
382+
383+
@pytest.fixture
def mock_create_pipeline_job_with_experiment():
    """Patches create_pipeline_job to return a succeeded job linked to the experiment."""
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_pipeline_job"
    ) as create_mock:
        create_mock.return_value = make_pipeline_job_with_experiment(
            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
        )
        yield create_mock
392+
393+
394+
@pytest.fixture
def mock_get_pipeline_job_with_experiment():
    """Patches get_pipeline_job to report RUNNING once, then SUCCEEDED."""
    polled_states = [
        gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
    ]
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
    ) as get_mock:
        get_mock.side_effect = [
            make_pipeline_job_with_experiment(s) for s in polled_states
        ]
        yield get_mock
408+
409+
306410
@pytest.mark.usefixtures("google_auth_mock")
307411
class TestPipelineJob:
308412
def setup_method(self):
@@ -1384,3 +1488,90 @@ def test_clone_pipeline_job_with_all_args(
13841488
assert cloned._gca_resource == make_pipeline_job(
13851489
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
13861490
)
1491+
1492+
@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_get_associated_experiment_from_pipeline_returns_none_without_experiment(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec,
    mock_load_yaml_and_json,
):
    """A pipeline submitted without an experiment has no associated experiment."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    # Submit without an experiment argument, then wait for completion.
    pipeline_job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        create_request_timeout=None,
    )
    pipeline_job.wait()

    assert pipeline_job.get_associated_experiment() is None
1529+
1530+
@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_get_associated_experiment_from_pipeline_returns_experiment(
    self,
    job_spec,
    mock_load_yaml_and_json,
    add_context_children_mock,
    get_experiment_mock,
    create_experiment_run_context_mock,
    get_metadata_store_mock,
    mock_create_pipeline_job_with_experiment,
    mock_get_pipeline_job_with_experiment,
):
    """A pipeline submitted with an experiment reports that experiment back."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    experiment = aiplatform.Experiment(_TEST_EXPERIMENT)

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    # Constructing the Experiment above should have looked it up exactly once.
    assert get_experiment_mock.call_count == 1

    pipeline_job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        create_request_timeout=None,
        experiment=experiment,
    )
    pipeline_job.wait()

    associated_experiment = pipeline_job.get_associated_experiment()
    assert associated_experiment.resource_name == _TEST_CONTEXT_NAME

    # Submitting with an experiment links the run context to it once.
    assert add_context_children_mock.call_count == 1

0 commit comments

Comments
 (0)