15
15
# limitations under the License.
16
16
#
17
17
18
+ import importlib
18
19
import os
19
20
import pickle
20
21
import tempfile
45
46
"save_method" : "_save_sklearn_model" ,
46
47
"load_method" : "_load_sklearn_model" ,
47
48
"model_file" : "model.pkl" ,
48
- }
49
+ },
50
+ "xgboost" : {
51
+ "save_method" : "_save_xgboost_model" ,
52
+ "load_method" : "_load_xgboost_model" ,
53
+ "model_file" : "model.bst" ,
54
+ },
49
55
}
50
56
51
57
52
58
def save_model (
53
- model : "sklearn.base.BaseEstimator" , # noqa: F821
59
+ model : Union [ "sklearn.base.BaseEstimator" , "xgb.Booster" ] , # noqa: F821
54
60
artifact_id : Optional [str ] = None ,
55
61
* ,
56
62
uri : Optional [str ] = None ,
@@ -63,7 +69,7 @@ def save_model(
63
69
) -> google_artifact_schema .ExperimentModel :
64
70
"""Saves a ML model into a MLMD artifact.
65
71
66
- Supported model frameworks: sklearn.
72
+ Supported model frameworks: sklearn, xgboost .
67
73
68
74
Example usage:
69
75
aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
@@ -72,7 +78,7 @@ def save_model(
72
78
aiplatform.save_model(model, "my-sklearn-model")
73
79
74
80
Args:
75
- model (sklearn.base.BaseEstimator):
81
+ model (Union[" sklearn.base.BaseEstimator", "xgb.Booster"] ):
76
82
Required. A machine learning model.
77
83
artifact_id (str):
78
84
Optional. The resource id of the artifact. This id must be globally unique
@@ -116,10 +122,23 @@ def save_model(
116
122
except ImportError :
117
123
pass
118
124
else :
119
- if isinstance (model , sklearn .base .BaseEstimator ):
125
+ # An instance of sklearn.base.BaseEstimator might be a sklearn model
126
+ # or a xgboost/lightgbm model implemented on top of sklearn.
127
+ if isinstance (
128
+ model , sklearn .base .BaseEstimator
129
+ ) and model .__class__ .__module__ .startswith ("sklearn" ):
120
130
framework_name = "sklearn"
121
131
framework_version = sklearn .__version__
122
132
133
+ try :
134
+ import xgboost as xgb
135
+ except ImportError :
136
+ pass
137
+ else :
138
+ if isinstance (model , (xgb .Booster , xgb .XGBModel )):
139
+ framework_name = "xgboost"
140
+ framework_version = xgb .__version__
141
+
123
142
if framework_name not in _FRAMEWORK_SPECS :
124
143
raise ValueError (
125
144
f"Model type { model .__class__ .__module__ } .{ model .__class__ .__name__ } not supported."
@@ -305,9 +324,24 @@ def _save_sklearn_model(
305
324
pickle .dump (model , f , protocol = _PICKLE_PROTOCOL )
306
325
307
326
327
+ def _save_xgboost_model (
328
+ model : Union ["xgb.Booster" , "xgb.XGBModel" ], # noqa: F821
329
+ path : str ,
330
+ ):
331
+ """Saves a xgboost model.
332
+
333
+ Args:
334
+ model (Union[xgb.Booster, xgb.XGBModel]):
335
+ Requred. A xgboost model.
336
+ path (str):
337
+ Required. The local path to save the model.
338
+ """
339
+ model .save_model (path )
340
+
341
+
308
342
def load_model (
309
343
model : Union [str , google_artifact_schema .ExperimentModel ]
310
- ) -> "sklearn.base.BaseEstimator" : # noqa: F821
344
+ ) -> Union [ "sklearn.base.BaseEstimator" , "xgb.Booster" ] : # noqa: F821
311
345
"""Retrieves the original ML model from an ExperimentModel resource.
312
346
313
347
Args:
@@ -375,7 +409,44 @@ def _load_sklearn_model(
375
409
return sk_model
376
410
377
411
378
- # TODO(b/264893283)
412
+ def _load_xgboost_model (
413
+ model_file : str ,
414
+ model_artifact : google_artifact_schema .ExperimentModel ,
415
+ ) -> Union ["xgb.Booster" , "xgb.XGBModel" ]: # noqa: F821
416
+ """Loads a xgboost model from local path.
417
+
418
+ Args:
419
+ model_file (str):
420
+ Required. A local model file to load.
421
+ model_artifact (google_artifact_schema.ExperimentModel):
422
+ Required. The artifact that saved the model.
423
+ Returns:
424
+ The xgboost model instance.
425
+
426
+ Raises:
427
+ ImportError: if xgboost is not installed.
428
+ """
429
+ try :
430
+ import xgboost as xgb
431
+ except ImportError :
432
+ raise ImportError (
433
+ "xgboost is not installed and is required for loading models."
434
+ ) from None
435
+
436
+ if xgb .__version__ < model_artifact .framework_version :
437
+ _LOGGER .warning (
438
+ f"The original model was saved via xgboost { model_artifact .framework_version } . "
439
+ f"You are using xgboost { xgb .__version__ } ."
440
+ "Attempting to load model..."
441
+ )
442
+
443
+ module , class_name = model_artifact .model_class .rsplit ("." , maxsplit = 1 )
444
+ xgb_model = getattr (importlib .import_module (module ), class_name )()
445
+ xgb_model .load_model (model_file )
446
+
447
+ return xgb_model
448
+
449
+
379
450
def register_model (
380
451
model : Union [str , google_artifact_schema .ExperimentModel ],
381
452
* ,
0 commit comments