Skip to content

fix!: exclude remote models for .register() #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bigframes/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __repr__(self):
return prettyprinter.pformat(self)


# TODO(garrettwu): refactor to reflect the actual property. Now the class contains .register() method.
class Predictor(BaseEstimator):
"""A BigQuery DataFrames ML Model base class that can be used to predict outputs."""

Expand Down
6 changes: 3 additions & 3 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@


@log_adapter.class_logger
class PaLM2TextGenerator(base.Predictor):
class PaLM2TextGenerator(base.BaseEstimator):
"""PaLM2 text generator LLM model.

Args:
Expand Down Expand Up @@ -258,7 +258,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:


@log_adapter.class_logger
class PaLM2TextEmbeddingGenerator(base.Predictor):
class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
"""PaLM2 text embedding generator LLM model.

Args:
Expand Down Expand Up @@ -418,7 +418,7 @@ def to_gbq(


@log_adapter.class_logger
class GeminiTextGenerator(base.Predictor):
class GeminiTextGenerator(base.BaseEstimator):
"""Gemini text generator LLM model.

Args:
Expand Down
17 changes: 4 additions & 13 deletions tests/system/small/ml/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import cast

import pytest

from bigframes.ml import core, imported, linear_model, llm


Expand Down Expand Up @@ -54,19 +56,8 @@ def test_linear_reg_register_with_params(
def test_palm2_text_generator_register(
ephemera_palm2_text_generator_model: llm.PaLM2TextGenerator,
):
model = ephemera_palm2_text_generator_model
model.register()

model_name = "bigframes_" + cast(
str, cast(core.BqmlModel, model._bqml_model).model.model_id
)
# Only registered model contains the field, and the field includes project/dataset. Here only check model_id.
assert (
model_name[:63] # truncated
in cast(core.BqmlModel, model._bqml_model).model.training_runs[-1][
"vertexAiModelId"
]
)
with pytest.raises(AttributeError):
ephemera_palm2_text_generator_model.register() # type: ignore


def test_imported_tensorflow_register(
Expand Down