diff --git a/google/cloud/aiplatform/metadata/schema/base_execution.py b/google/cloud/aiplatform/metadata/schema/base_execution.py index 811b7d9791..c8d55a0637 100644 --- a/google/cloud/aiplatform/metadata/schema/base_execution.py +++ b/google/cloud/aiplatform/metadata/schema/base_execution.py @@ -24,6 +24,7 @@ from google.cloud.aiplatform.compat.types import execution as gca_execution from google.cloud.aiplatform.metadata import constants from google.cloud.aiplatform.metadata import execution +from google.cloud.aiplatform.metadata import metadata class BaseExecutionSchema(metaclass=abc.ABCMeta): @@ -112,3 +113,75 @@ def create( credentials=credentials, ) return self.execution + + def start_execution( + self, + *, + metadata_store_id: Optional[str] = "default", + resume: bool = False, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "execution.Execution": + """Create and starts a new Metadata Execution or resumes a previously created Execution. + + This method is similar to create_execution with additional support for Experiments. + If an Experiment is set prior to running this command, the Experiment will be + associtaed with the created execution, otherwise this method behaves the same + as create_execution. + + To start a new execution: + ``` + instance_of_execution_schema = execution_schema.ContainerExecution(...) + with instance_of_execution_schema.start_execution() as exc: + exc.assign_input_artifacts([my_artifact]) + model = aiplatform.Artifact.create(uri='gs://my-uri', schema_title='system.Model') + exc.assign_output_artifacts([model]) + ``` + + To continue a previously created execution: + ``` + with execution_schema.ContainerExecution(resource_id='my-exc', resume=True) as exc: + ... + ``` + Args: + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//executions/ + If not provided, the MetadataStore's ID will be set to "default". Currently only the 'default' + MetadataStore ID is supported. + resume (bool): + Resume an existing execution. + project (str): + Optional. Project used to create this Execution. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Execution. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Execution. Overrides + credentials set in aiplatform.init. + Returns: + Execution: Instantiated representation of the managed Metadata Execution. + Raises: + ValueError: If metadata_store_id other than 'default' is provided. + """ + if metadata_store_id != "default": + raise ValueError( + f"metadata_store_id {metadata_store_id} is not supported. Only the default MetadataStore ID is supported." + ) + + return metadata._ExperimentTracker().start_execution( + schema_title=self.schema_title, + display_name=self.display_name, + resource_id=self.execution_id, + metadata=self.metadata, + schema_version=self.schema_version, + description=self.description, + # TODO: Add support for metadata_store_id once it is supported in experiment. + resume=resume, + project=project, + location=location, + credentials=credentials, + ) diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index cbf7d38609..800a490cc8 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -561,3 +561,30 @@ def test_container_spec_to_dict_method_returns_correct_schema(self): } assert json.dumps(container_spec.to_dict()) == json.dumps(expected_results) + + @pytest.mark.usefixtures("create_execution_mock") + def test_start_execution_method_calls_gapic_library_with_correct_parameters( + self, create_execution_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestExecution(base_execution.BaseExecutionSchema): + schema_title = _TEST_SCHEMA_TITLE + + execution = TestExecution( + state=_TEST_EXECUTION_STATE, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + execution.start_execution() + create_execution_mock.assert_called_once_with( + parent=f"{_TEST_PARENT}/metadataStores/default", + execution=mock.ANY, + execution_id=None, + ) + _, _, kwargs = create_execution_mock.mock_calls[0] + assert kwargs["execution"].schema_title == _TEST_SCHEMA_TITLE + assert kwargs["execution"].display_name == _TEST_DISPLAY_NAME + assert kwargs["execution"].description == _TEST_DESCRIPTION + assert kwargs["execution"].metadata == _TEST_UPDATED_METADATA