1
1
# -*- coding: utf-8 -*-
2
2
3
- # Copyright 2022 Google LLC
3
+ # Copyright 2023 Google LLC
4
4
#
5
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
6
# you may not use this file except in compliance with the License.
@@ -617,39 +617,41 @@ def __init__(
617
617
metadata : Optional [Dict ] = None ,
618
618
state : Optional [gca_artifact .Artifact .State ] = gca_artifact .Artifact .State .LIVE ,
619
619
):
620
- """Args:
621
- framework_name (str):
622
- Required. The name of the model's framework. E.g., 'sklearn'
623
- framework_version (str):
624
- Required. The version of the model's framework. E.g., '1.1.0'
625
- model_file (str):
626
- Required. The file name of the model. E.g., 'model.pkl'
627
- uri (str):
628
- Required. The uniform resource identifier of the model artifact directory.
629
- model_class (str):
630
- Optional. The class name of the model. E.g., 'sklearn.linear_model._base.LinearRegression'
631
- predict_schemata (PredictSchemata):
632
- Optional. An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
633
- artifact_id (str):
634
- Optional. The <resource_id> portion of the Artifact name with
635
- the format. This is globally unique in a metadataStore:
636
- projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
637
- display_name (str):
638
- Optional. The user-defined name of the Artifact.
639
- schema_version (str):
640
- Optional. schema_version specifies the version used by the Artifact.
641
- If not set, defaults to use the latest version.
642
- description (str):
643
- Optional. Describes the purpose of the Artifact to be created.
644
- metadata (Dict):
645
- Optional. Contains the metadata information that will be stored in the Artifact.
646
- state (google.cloud.gapic.types.Artifact.State):
647
- Optional. The state of this Artifact. This is a
648
- property of the Artifact, and does not imply or
649
- apture any ongoing process. This property is
650
- managed by clients (such as Vertex AI
651
- Pipelines), and the system does not prescribe or
652
- check the validity of state transitions.
620
+ """Instantiates an ExperimentModel that represents a saved ML model.
621
+
622
+ Args:
623
+ framework_name (str):
624
+ Required. The name of the model's framework. E.g., 'sklearn'
625
+ framework_version (str):
626
+ Required. The version of the model's framework. E.g., '1.1.0'
627
+ model_file (str):
628
+ Required. The file name of the model. E.g., 'model.pkl'
629
+ uri (str):
630
+ Required. The uniform resource identifier of the model artifact directory.
631
+ model_class (str):
632
+ Optional. The class name of the model. E.g., 'sklearn.linear_model._base.LinearRegression'
633
+ predict_schemata (PredictSchemata):
634
+ Optional. An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
635
+ artifact_id (str):
636
+ Optional. The <resource_id> portion of the Artifact name with
637
+ the format. This is globally unique in a metadataStore:
638
+ projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
639
+ display_name (str):
640
+ Optional. The user-defined name of the Artifact.
641
+ schema_version (str):
642
+ Optional. schema_version specifies the version used by the Artifact.
643
+ If not set, defaults to use the latest version.
644
+ description (str):
645
+ Optional. Describes the purpose of the Artifact to be created.
646
+ metadata (Dict):
647
+ Optional. Contains the metadata information that will be stored in the Artifact.
648
+ state (google.cloud.gapic.types.Artifact.State):
649
+ Optional. The state of this Artifact. This is a
650
+ property of the Artifact, and does not imply or
651
+ apture any ongoing process. This property is
652
+ managed by clients (such as Vertex AI
653
+ Pipelines), and the system does not prescribe or
654
+ check the validity of state transitions.
653
655
"""
654
656
if metadata :
655
657
for k in metadata :
@@ -732,14 +734,17 @@ def get(
732
734
733
735
@property
734
736
def framework_name (self ) -> Optional [str ]:
737
+ """The framework name of the saved ML model."""
735
738
return self .metadata .get ("frameworkName" )
736
739
737
740
@property
738
741
def framework_version (self ) -> Optional [str ]:
742
+ """The framework version of the saved ML model."""
739
743
return self .metadata .get ("frameworkVersion" )
740
744
741
745
@property
742
746
def model_class (self ) -> Optional [str ]:
747
+ "The class name of the saved ML model."
743
748
return self .metadata .get ("modelClass" )
744
749
745
750
def get_model_info (self ) -> Dict [str , Any ]:
@@ -756,10 +761,12 @@ def load_model(
756
761
) -> Union ["sklearn.base.BaseEstimator" , "xgb.Booster" , "tf.Module" ]: # noqa: F821
757
762
"""Retrieves the original ML model from an ExperimentModel.
758
763
759
- Example usage:
760
- experiment_model = aiplatform.get_experiment_model("my-sklearn-model")
761
- sk_model = experiment_model.load_model()
762
- pred_y = model.predict(test_X)
764
+ Example Usage:
765
+ ```
766
+ experiment_model = aiplatform.get_experiment_model("my-sklearn-model")
767
+ sk_model = experiment_model.load_model()
768
+ pred_y = model.predict(test_X)
769
+ ```
763
770
764
771
Returns:
765
772
The original ML model.
@@ -803,10 +810,12 @@ def register_model(
803
810
) -> Model :
804
811
"""Register an ExperimentModel to Model Registry and returns a Model representing the registered Model resource.
805
812
806
- Example usage:
807
- experiment_model = aiplatform.get_experiment_model("my-sklearn-model")
808
- registered_model = experiment_model.register_model()
809
- registered_model.deploy(endpoint=my_endpoint)
813
+ Example Usage:
814
+ ```
815
+ experiment_model = aiplatform.get_experiment_model("my-sklearn-model")
816
+ registered_model = experiment_model.register_model()
817
+ registered_model.deploy(endpoint=my_endpoint)
818
+ ```
810
819
811
820
Args:
812
821
model_id (str):
0 commit comments