|
17 | 17 | import copy
|
18 | 18 | from typing import Optional, Dict, List
|
19 | 19 |
|
| 20 | +from google.auth import credentials as auth_credentials |
20 | 21 | from google.cloud.aiplatform.compat.types import artifact as gca_artifact
|
21 | 22 | from google.cloud.aiplatform.metadata.schema import base_artifact
|
22 | 23 | from google.cloud.aiplatform.metadata.schema import utils
|
@@ -359,7 +360,6 @@ def __init__(
|
359 | 360 | extended_metadata = copy.deepcopy(metadata) if metadata else {}
|
360 | 361 | if aggregation_type:
|
361 | 362 | if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
|
362 |
| - ## Todo: add negative test case for this |
363 | 363 | raise ValueError(
|
364 | 364 | "aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
|
365 | 365 | )
|
@@ -583,3 +583,158 @@ def __init__(
|
583 | 583 | metadata=extended_metadata,
|
584 | 584 | state=state,
|
585 | 585 | )
|
| 586 | + |
| 587 | + |
| 588 | +class ExperimentModel(base_artifact.BaseArtifactSchema): |
| 589 | + """An artifact representing a Vertex Experiment Model.""" |
| 590 | + |
| 591 | + schema_title = "google.ExperimentModel" |
| 592 | + |
| 593 | + RESERVED_METADATA_KEYS = [ |
| 594 | + "frameworkName", |
| 595 | + "frameworkVersion", |
| 596 | + "modelFile", |
| 597 | + "modelClass", |
| 598 | + "predictSchemata", |
| 599 | + ] |
| 600 | + |
| 601 | + def __init__( |
| 602 | + self, |
| 603 | + *, |
| 604 | + framework_name: str, |
| 605 | + framework_version: str, |
| 606 | + model_file: str, |
| 607 | + uri: str, |
| 608 | + model_class: Optional[str] = None, |
| 609 | + predict_schemata: Optional[utils.PredictSchemata] = None, |
| 610 | + artifact_id: Optional[str] = None, |
| 611 | + display_name: Optional[str] = None, |
| 612 | + schema_version: Optional[str] = None, |
| 613 | + description: Optional[str] = None, |
| 614 | + metadata: Optional[Dict] = None, |
| 615 | + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, |
| 616 | + ): |
| 617 | + """Args: |
| 618 | + framework_name (str): |
| 619 | + Required. The name of the model's framework. E.g., 'sklearn' |
| 620 | + framework_version (str): |
| 621 | + Required. The version of the model's framework. E.g., '1.1.0' |
| 622 | + model_file (str): |
| 623 | + Required. The file name of the model. E.g., 'model.pkl' |
| 624 | + uri (str): |
| 625 | + Required. The uniform resource identifier of the model artifact directory. |
| 626 | + model_class (str): |
| 627 | + Optional. The class name of the model. E.g., 'sklearn.linear_model._base.LinearRegression' |
| 628 | + predict_schemata (PredictSchemata): |
| 629 | + Optional. An instance of PredictSchemata which holds instance, parameter and prediction schema uris. |
| 630 | + artifact_id (str): |
| 631 | + Optional. The <resource_id> portion of the Artifact name with |
| 632 | + the format. This is globally unique in a metadataStore: |
| 633 | + projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>. |
| 634 | + display_name (str): |
| 635 | + Optional. The user-defined name of the Artifact. |
| 636 | + schema_version (str): |
| 637 | + Optional. schema_version specifies the version used by the Artifact. |
| 638 | + If not set, defaults to use the latest version. |
| 639 | + description (str): |
| 640 | + Optional. Describes the purpose of the Artifact to be created. |
| 641 | + metadata (Dict): |
| 642 | + Optional. Contains the metadata information that will be stored in the Artifact. |
| 643 | + state (google.cloud.gapic.types.Artifact.State): |
| 644 | + Optional. The state of this Artifact. This is a |
| 645 | + property of the Artifact, and does not imply or |
| 646 | + apture any ongoing process. This property is |
| 647 | + managed by clients (such as Vertex AI |
| 648 | + Pipelines), and the system does not prescribe or |
| 649 | + check the validity of state transitions. |
| 650 | + """ |
| 651 | + if metadata: |
| 652 | + for k in metadata: |
| 653 | + if k in self.RESERVED_METADATA_KEYS: |
| 654 | + raise ValueError(f"'{k}' is a system reserved key in metadata.") |
| 655 | + extended_metadata = copy.deepcopy(metadata) |
| 656 | + else: |
| 657 | + extended_metadata = {} |
| 658 | + extended_metadata["frameworkName"] = framework_name |
| 659 | + extended_metadata["frameworkVersion"] = framework_version |
| 660 | + extended_metadata["modelFile"] = model_file |
| 661 | + if model_class is not None: |
| 662 | + extended_metadata["modelClass"] = model_class |
| 663 | + if predict_schemata is not None: |
| 664 | + extended_metadata["predictSchemata"] = predict_schemata.to_dict() |
| 665 | + |
| 666 | + super().__init__( |
| 667 | + uri=uri, |
| 668 | + artifact_id=artifact_id, |
| 669 | + display_name=display_name, |
| 670 | + schema_version=schema_version, |
| 671 | + description=description, |
| 672 | + metadata=extended_metadata, |
| 673 | + state=state, |
| 674 | + ) |
| 675 | + |
| 676 | + @classmethod |
| 677 | + def get( |
| 678 | + cls, |
| 679 | + artifact_id: str, |
| 680 | + *, |
| 681 | + metadata_store_id: str = "default", |
| 682 | + project: Optional[str] = None, |
| 683 | + location: Optional[str] = None, |
| 684 | + credentials: Optional[auth_credentials.Credentials] = None, |
| 685 | + ) -> "ExperimentModel": |
| 686 | + """Retrieves an existing ExperimentModel artifact given an artifact id. |
| 687 | +
|
| 688 | + Args: |
| 689 | + artifact_id (str): |
| 690 | + Required. An artifact id of the ExperimentModel artifact. |
| 691 | + metadata_store_id (str): |
| 692 | + Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default". |
| 693 | + If artifact_id is a fully-qualified resource name, its metadata_store_id overrides this one. |
| 694 | + project (str): |
| 695 | + Optional. Project to retrieve the artifact from. If not set, project |
| 696 | + set in aiplatform.init will be used. |
| 697 | + location (str): |
| 698 | + Optional. Location to retrieve the Artifact from. If not set, location |
| 699 | + set in aiplatform.init will be used. |
| 700 | + credentials (auth_credentials.Credentials): |
| 701 | + Optional. Custom credentials to use to retrieve this Artifact. Overrides |
| 702 | + credentials set in aiplatform.init. |
| 703 | +
|
| 704 | + Returns: |
| 705 | + An ExperimentModel class that represents an Artifact resource. |
| 706 | +
|
| 707 | + Raises: |
| 708 | + ValueError: if artifact's schema title is not 'google.ExperimentModel'. |
| 709 | + """ |
| 710 | + experiment_model = ExperimentModel( |
| 711 | + framework_name="", |
| 712 | + framework_version="", |
| 713 | + model_file="", |
| 714 | + uri="", |
| 715 | + ) |
| 716 | + experiment_model._init_with_resource_name( |
| 717 | + artifact_name=artifact_id, |
| 718 | + metadata_store_id=metadata_store_id, |
| 719 | + project=project, |
| 720 | + location=location, |
| 721 | + credentials=credentials, |
| 722 | + ) |
| 723 | + if experiment_model.schema_title != cls.schema_title: |
| 724 | + raise ValueError( |
| 725 | + f"The schema title of the artifact must be {cls.schema_title}." |
| 726 | + f"Got {experiment_model.schema_title}." |
| 727 | + ) |
| 728 | + return experiment_model |
| 729 | + |
| 730 | + @property |
| 731 | + def framework_name(self) -> Optional[str]: |
| 732 | + return self.metadata.get("frameworkName") |
| 733 | + |
| 734 | + @property |
| 735 | + def framework_version(self) -> Optional[str]: |
| 736 | + return self.metadata.get("frameworkVersion") |
| 737 | + |
| 738 | + @property |
| 739 | + def model_class(self) -> Optional[str]: |
| 740 | + return self.metadata.get("modelClass") |
0 commit comments