|
28 | 28 | from google.cloud.aiplatform.metadata import artifact
|
29 | 29 | from google.cloud.aiplatform.metadata import context
|
30 | 30 | from google.cloud.aiplatform.metadata import execution
|
| 31 | +from google.cloud.aiplatform.metadata import utils as metadata_utils |
31 | 32 | from google.cloud.aiplatform_v1 import (
|
32 | 33 | MetadataServiceClient,
|
33 | 34 | AddExecutionEventsResponse,
|
|
39 | 40 | AddContextArtifactsAndExecutionsResponse,
|
40 | 41 | )
|
41 | 42 |
|
| 43 | +import test_models |
| 44 | + |
42 | 45 | # project
|
43 | 46 | _TEST_PROJECT = "test-project"
|
44 | 47 | _TEST_LOCATION = "us-central1"
|
@@ -543,6 +546,34 @@ def test_add_executions_only(self, add_context_artifacts_and_executions_mock):
|
543 | 546 | )
|
544 | 547 |
|
545 | 548 |
|
| 549 | +get_model_with_version_mock = test_models.get_model_with_version |
| 550 | +_VERTEX_MODEL_ARTIFACT_URI = f"https://{_TEST_LOCATION}-aiplatform.googleapis.com/v1/{test_models._TEST_MODEL_OBJ_WITH_VERSION.name}" |
| 551 | + |
| 552 | + |
| 553 | +@pytest.fixture |
| 554 | +def list_vertex_model_artifact_mock(): |
| 555 | + with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock: |
| 556 | + list_artifacts_mock.return_value = [ |
| 557 | + GapicArtifact( |
| 558 | + name=_TEST_ARTIFACT_NAME, |
| 559 | + uri=_VERTEX_MODEL_ARTIFACT_URI, |
| 560 | + display_name=_TEST_DISPLAY_NAME, |
| 561 | + schema_title=_TEST_SCHEMA_TITLE, |
| 562 | + schema_version=_TEST_SCHEMA_VERSION, |
| 563 | + description=_TEST_DESCRIPTION, |
| 564 | + metadata=_TEST_METADATA, |
| 565 | + ) |
| 566 | + ] |
| 567 | + yield list_artifacts_mock |
| 568 | + |
| 569 | + |
| 570 | +@pytest.fixture |
| 571 | +def list_artifact_empty_mock(): |
| 572 | + with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock: |
| 573 | + list_artifacts_mock.return_value = [] |
| 574 | + yield list_artifacts_mock |
| 575 | + |
| 576 | + |
546 | 577 | class TestExecution:
|
547 | 578 | def setup_method(self):
|
548 | 579 | reload(initializer)
|
@@ -680,6 +711,86 @@ def test_add_artifact(self, add_execution_events_mock):
|
680 | 711 | events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)],
|
681 | 712 | )
|
682 | 713 |
|
| 714 | + @pytest.mark.usefixtures("get_execution_mock", "get_model_with_version_mock") |
| 715 | + def test_add_vertex_model( |
| 716 | + self, add_execution_events_mock, list_vertex_model_artifact_mock |
| 717 | + ): |
| 718 | + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 719 | + |
| 720 | + my_execution = execution.Execution.get_or_create( |
| 721 | + resource_id=_TEST_EXECUTION_ID, |
| 722 | + schema_title=_TEST_SCHEMA_TITLE, |
| 723 | + display_name=_TEST_DISPLAY_NAME, |
| 724 | + schema_version=_TEST_SCHEMA_VERSION, |
| 725 | + description=_TEST_DESCRIPTION, |
| 726 | + metadata=_TEST_METADATA, |
| 727 | + metadata_store_id=_TEST_METADATA_STORE, |
| 728 | + ) |
| 729 | + |
| 730 | + my_model = aiplatform.Model(test_models._TEST_MODEL_NAME) |
| 731 | + my_execution.assign_output_artifacts(artifacts=[my_model]) |
| 732 | + |
| 733 | + list_vertex_model_artifact_mock.assert_called_once_with( |
| 734 | + request=dict( |
| 735 | + parent="projects/test-project/locations/us-central1/metadataStores/default", |
| 736 | + filter=metadata_utils._make_filter_string( |
| 737 | + schema_title="google.VertexModel", uri=_VERTEX_MODEL_ARTIFACT_URI |
| 738 | + ), |
| 739 | + ) |
| 740 | + ) |
| 741 | + |
| 742 | + add_execution_events_mock.assert_called_once_with( |
| 743 | + execution=_TEST_EXECUTION_NAME, |
| 744 | + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], |
| 745 | + ) |
| 746 | + |
| 747 | + @pytest.mark.usefixtures("get_execution_mock", "get_model_with_version_mock") |
| 748 | + def test_add_vertex_model_not_resolved( |
| 749 | + self, add_execution_events_mock, list_artifact_empty_mock, create_artifact_mock |
| 750 | + ): |
| 751 | + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 752 | + |
| 753 | + my_execution = execution.Execution.get_or_create( |
| 754 | + resource_id=_TEST_EXECUTION_ID, |
| 755 | + schema_title=_TEST_SCHEMA_TITLE, |
| 756 | + display_name=_TEST_DISPLAY_NAME, |
| 757 | + schema_version=_TEST_SCHEMA_VERSION, |
| 758 | + description=_TEST_DESCRIPTION, |
| 759 | + metadata=_TEST_METADATA, |
| 760 | + metadata_store_id=_TEST_METADATA_STORE, |
| 761 | + ) |
| 762 | + |
| 763 | + my_model = aiplatform.Model(test_models._TEST_MODEL_NAME) |
| 764 | + my_execution.assign_output_artifacts(artifacts=[my_model]) |
| 765 | + |
| 766 | + list_artifact_empty_mock.assert_called_once_with( |
| 767 | + request=dict( |
| 768 | + parent="projects/test-project/locations/us-central1/metadataStores/default", |
| 769 | + filter=metadata_utils._make_filter_string( |
| 770 | + schema_title="google.VertexModel", uri=_VERTEX_MODEL_ARTIFACT_URI |
| 771 | + ), |
| 772 | + ) |
| 773 | + ) |
| 774 | + |
| 775 | + expected_artifact = GapicArtifact( |
| 776 | + schema_title="google.VertexModel", |
| 777 | + display_name=test_models._TEST_MODEL_OBJ_WITH_VERSION.display_name, |
| 778 | + uri=_VERTEX_MODEL_ARTIFACT_URI, |
| 779 | + metadata={"resourceName": test_models._TEST_MODEL_OBJ_WITH_VERSION.name}, |
| 780 | + state=GapicArtifact.State.LIVE, |
| 781 | + ) |
| 782 | + |
| 783 | + create_artifact_mock.assert_called_once_with( |
| 784 | + parent="projects/test-project/locations/us-central1/metadataStores/default", |
| 785 | + artifact=expected_artifact, |
| 786 | + artifact_id=None, |
| 787 | + ) |
| 788 | + |
| 789 | + add_execution_events_mock.assert_called_once_with( |
| 790 | + execution=_TEST_EXECUTION_NAME, |
| 791 | + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], |
| 792 | + ) |
| 793 | + |
683 | 794 | @pytest.mark.usefixtures("get_execution_mock")
|
684 | 795 | def test_query_input_and_output_artifacts(
|
685 | 796 | self, query_execution_inputs_and_outputs_mock
|
|
0 commit comments