|
19 | 19 |
|
20 | 20 | from typing import List, Literal, Optional, Union
|
21 | 21 |
|
| 22 | +import bigframes_vendored.sklearn.decomposition._mf |
22 | 23 | import bigframes_vendored.sklearn.decomposition._pca
|
23 | 24 | from google.cloud import bigquery
|
24 | 25 |
|
|
27 | 28 | import bigframes.pandas as bpd
|
28 | 29 | import bigframes.session
|
29 | 30 |
|
30 |
| -_BQML_PARAMS_MAPPING = {"svd_solver": "pcaSolver"} |
| 31 | +_BQML_PARAMS_MAPPING = { |
| 32 | + "svd_solver": "pcaSolver", |
| 33 | + "feedback_type": "feedbackType", |
| 34 | + "num_factors": "numFactors", |
| 35 | + "user_col": "userColumn", |
| 36 | + "item_col": "itemColumn", |
| 37 | + "_input_label_columns": "inputLabelColumns", |
| 38 | + "l2_reg": "l2Regularization", |
| 39 | +} |
31 | 40 |
|
32 | 41 |
|
33 | 42 | @log_adapter.class_logger
|
@@ -197,3 +206,159 @@ def score(
|
197 | 206 |
|
198 | 207 | # TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
|
199 | 208 | return self._bqml_model.evaluate()
|
| 209 | + |
| 210 | + |
@log_adapter.class_logger
class MatrixFactorization(
    base.UnsupervisedTrainablePredictor,
    bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization,
):
    # Reuse the docstring of the vendored class as this class's documentation.
    __doc__ = bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization.__doc__

    def __init__(
        self,
        *,
        feedback_type: Literal["explicit", "implicit"] = "explicit",
        num_factors: int,
        user_col: str,
        item_col: str,
        rating_col: str = "rating",
        # TODO: Add support for hyperparameter tuning.
        l2_reg: float = 1.0,
    ):
        # All arguments are keyword-only and validated eagerly, so a bad value
        # fails here with a TypeError/ValueError instead of later during fit.

        # Guard the type first: calling .lower() on a non-string would
        # otherwise surface as an unhelpful AttributeError.
        if not isinstance(feedback_type, str):
            raise TypeError(
                f"Expected feedback_type to be a str, but got {type(feedback_type)}."
            )

        # Matching is case-insensitive; the lower-cased value is what is stored.
        feedback_type = feedback_type.lower()  # type: ignore
        if feedback_type not in ("explicit", "implicit"):
            raise ValueError("Expected feedback_type to be `explicit` or `implicit`.")

        self.feedback_type = feedback_type

        if not isinstance(num_factors, int):
            raise TypeError(
                f"Expected num_factors to be an int, but got {type(num_factors)}."
            )

        # Zero is rejected as well: the message promises a *positive* integer.
        if num_factors <= 0:
            raise ValueError(
                f"Expected num_factors to be a positive integer, but got {num_factors}."
            )

        self.num_factors = num_factors

        if not isinstance(user_col, str):
            raise TypeError(f"Expected user_col to be a str, but got {type(user_col)}.")

        self.user_col = user_col

        if not isinstance(item_col, str):
            raise TypeError(f"Expected item_col to be STR, but got {type(item_col)}.")

        self.item_col = item_col

        if not isinstance(rating_col, str):
            raise TypeError(
                f"Expected rating_col to be a str, but got {type(rating_col)}."
            )

        # The rating column is kept as a one-element list under the private
        # name that _BQML_PARAMS_MAPPING uses; read it back via `rating_col`.
        self._input_label_columns = [rating_col]

        if not isinstance(l2_reg, (float, int)):
            raise TypeError(
                f"Expected l2_reg to be a float or int, but got {type(l2_reg)}."
            )

        self.l2_reg = l2_reg
        # Stays None until the estimator is fitted or loaded via _from_bq.
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()

    @property
    def rating_col(self) -> str:
        """str: The rating column name. Defaults to 'rating'."""
        return self._input_label_columns[0]

    @classmethod
    def _from_bq(
        cls, session: bigframes.session.Session, bq_model: bigquery.Model
    ) -> MatrixFactorization:
        """Build an estimator around an existing BigQuery model.

        Args:
            session (bigframes.session.Session):
                Session the wrapped ``core.BqmlModel`` is bound to.
            bq_model (bigquery.Model):
                The BigQuery model; its ``model_type`` must be
                ``"MATRIX_FACTORIZATION"``.

        Returns:
            MatrixFactorization: A new instance constructed from the parameters
            that ``utils.retrieve_params_from_bq_model`` returns, with
            ``_bqml_model`` already set.
        """
        assert bq_model.model_type == "MATRIX_FACTORIZATION"

        kwargs = utils.retrieve_params_from_bq_model(
            cls, bq_model, _BQML_PARAMS_MAPPING
        )

        model = cls(**kwargs)
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    @property
    def _bqml_options(self) -> dict:
        """The model options as they will be set for BQML"""
        options: dict = {
            "model_type": "matrix_factorization",
            "feedback_type": self.feedback_type,
            "user_col": self.user_col,
            "item_col": self.item_col,
            "rating_col": self.rating_col,
            "l2_reg": self.l2_reg,
        }

        # Only sent when set; __init__ always sets it, but the attribute may
        # have been reassigned to None afterwards.
        if self.num_factors is not None:
            options["num_factors"] = self.num_factors

        return options

    def _fit(
        self,
        X: utils.ArrayType,
        y=None,
        transforms: Optional[List[str]] = None,
    ) -> MatrixFactorization:
        """Create the BQML model from ``X`` and store it on this estimator.

        Args:
            X (utils.ArrayType):
                Training data; converted with
                ``utils.batch_convert_to_dataframe`` before use.
            y:
                Must be ``None``.
            transforms (Optional[List[str]]):
                Forwarded unchanged to the model factory's ``create_model``.

        Returns:
            MatrixFactorization: ``self``, with ``_bqml_model`` set.

        Raises:
            ValueError: If ``y`` is not ``None``.
        """
        if y is not None:
            raise ValueError(
                "Label column not supported for Matrix Factorization model but y was not `None`"
            )

        (X,) = utils.batch_convert_to_dataframe(X)

        self._bqml_model = self._bqml_model_factory.create_model(
            X_train=X,
            transforms=transforms,
            options=self._bqml_options,
        )
        return self

    def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
        # Requires a fitted (or loaded) model; X is converted to a DataFrame in
        # the model's session and passed to the model's recommend().
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before recommend")

        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

        return self._bqml_model.recommend(X)

    def to_gbq(self, model_name: str, replace: bool = False) -> MatrixFactorization:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Defaults to False.

        Returns:
            MatrixFactorization: Saved model.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before it can be saved")

        # Copy under the new name, then read it back through the session so the
        # returned object wraps the saved copy.
        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)

    def score(
        self,
        X=None,
        y=None,
    ) -> bpd.DataFrame:
        # Requires a fitted (or loaded) model. Neither X nor y is used: the
        # result is whatever the model's evaluate() returns.
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before score")

        # TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
        return self._bqml_model.evaluate()
0 commit comments