feat: add to_gbq() method for LLM models #299

Merged (8 commits, Jan 9, 2024)
74 changes: 74 additions & 0 deletions bigframes/ml/llm.py
@@ -19,6 +19,8 @@
from typing import cast, Literal, Optional, Union
import warnings

from google.cloud import bigquery

import bigframes
from bigframes import clients, constants
from bigframes.core import blocks, log_adapter
@@ -113,6 +115,26 @@ def _create_bqml_model(self):
session=self.session, connection_name=self.connection_name, options=options
)

@classmethod
def _from_bq(
cls, session: bigframes.Session, model: bigquery.Model
) -> PaLM2TextGenerator:
assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
assert "remoteModelInfo" in model._properties
assert "endpoint" in model._properties["remoteModelInfo"]
assert "connection" in model._properties["remoteModelInfo"]

# Parse the remote model endpoint
bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"]
model_connection = model._properties["remoteModelInfo"]["connection"]
model_endpoint = bqml_endpoint.split("/")[-1]

text_generator_model = cls(
session=session, model_name=model_endpoint, connection_name=model_connection
)
text_generator_model._bqml_model = core.BqmlModel(session, model)
return text_generator_model
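
For context, a minimal sketch of the endpoint parsing that _from_bq performs; the full endpoint URL below is illustrative rather than copied from a real model, and PaLM2TextEmbeddingGenerator._from_bq further down uses the same parsing.

# Illustrative only: the remote model metadata exposes a Vertex AI endpoint,
# and the model id is simply the last path segment.
bqml_endpoint = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project"
    "/locations/us-central1/publishers/google/models/text-bison"
)
model_endpoint = bqml_endpoint.split("/")[-1]
assert model_endpoint == "text-bison"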

def predict(
self,
X: Union[bpd.DataFrame, bpd.Series],
@@ -200,6 +222,21 @@ def predict(

return df

def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
"""Save the model to BigQuery.

Args:
model_name (str):
The name of the model.
replace (bool, default False):
Whether to replace the model if it already exists. Defaults to False.

Returns:
PaLM2TextGenerator: Saved model."""

new_model = self._bqml_model.copy(model_name, replace)
return new_model.session.read_gbq_model(model_name)
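
A usage sketch of the new method, with placeholder dataset and connection names that are not part of this PR: to_gbq() copies the underlying BQML model and returns the reloaded wrapper, so the saved configuration can be checked directly on the return value.

from bigframes.ml import llm

# Placeholder names, shown only to illustrate the save/reload round trip.
model = llm.PaLM2TextGenerator(
    model_name="text-bison", connection_name="my-project.us.my-connection"
)

reloaded = model.to_gbq("my_dataset.my_text_model", replace=True)
assert reloaded.model_name == "text-bison"
assert reloaded.connection_name == "my-project.us.my-connection"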


@log_adapter.class_logger
class PaLM2TextEmbeddingGenerator(base.Predictor):
@@ -271,6 +308,26 @@ def _create_bqml_model(self):
session=self.session, connection_name=self.connection_name, options=options
)

@classmethod
def _from_bq(
cls, session: bigframes.Session, model: bigquery.Model
) -> PaLM2TextEmbeddingGenerator:
assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
assert "remoteModelInfo" in model._properties
assert "endpoint" in model._properties["remoteModelInfo"]
assert "connection" in model._properties["remoteModelInfo"]

# Parse the remote model endpoint
bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"]
model_connection = model._properties["remoteModelInfo"]["connection"]
model_endpoint = bqml_endpoint.split("/")[-1]

embedding_generator_model = cls(
session=session, model_name=model_endpoint, connection_name=model_connection
)
embedding_generator_model._bqml_model = core.BqmlModel(session, model)
return embedding_generator_model

def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
"""Predict the result from input DataFrame.

@@ -307,3 +364,20 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
)

return df

def to_gbq(
self, model_name: str, replace: bool = False
) -> PaLM2TextEmbeddingGenerator:
"""Save the model to BigQuery.

Args:
model_name (str):
The name of the model.
replace (bool, default False):
Whether to replace the model if it already exists. Defaults to False.

Returns:
PaLM2TextEmbeddingGenerator: Saved model."""

new_model = self._bqml_model.copy(model_name, replace)
return new_model.session.read_gbq_model(model_name)
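
The embedding generator gets the same round trip. A short sketch, again with placeholder names, that also exercises predict on the reloaded model:

import bigframes.pandas as bpd
from bigframes.ml import llm

# Placeholder names; illustrates save, reload, and predict in one flow.
embedder = llm.PaLM2TextEmbeddingGenerator(
    model_name="textembedding-gecko",
    connection_name="my-project.us.my-connection",
)
reloaded = embedder.to_gbq("my_dataset.my_embedding_model", replace=True)

# The reloaded model is immediately usable for prediction.
df = bpd.DataFrame({"prompt": ["BigQuery DataFrames", "BigQuery ML"]})
embeddings = reloaded.predict(df)
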
23 changes: 23 additions & 0 deletions bigframes/ml/loader.py
@@ -28,6 +28,7 @@
forecasting,
imported,
linear_model,
llm,
pipeline,
)

@@ -47,6 +48,15 @@
}
)

_BQML_ENDPOINT_TYPE_MAPPING = MappingProxyType(
{
llm._TEXT_GENERATOR_BISON_ENDPOINT: llm.PaLM2TextGenerator,
llm._TEXT_GENERATOR_BISON_32K_ENDPOINT: llm.PaLM2TextGenerator,
llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
}
)


def from_bq(
session: bigframes.Session, bq_model: bigquery.Model
@@ -62,6 +72,8 @@ def from_bq(
ensemble.RandomForestClassifier,
imported.TensorFlowModel,
imported.ONNXModel,
llm.PaLM2TextGenerator,
llm.PaLM2TextEmbeddingGenerator,
pipeline.Pipeline,
]:
"""Load a BQML model to BigQuery DataFrames ML.
@@ -84,6 +96,17 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore
session=session, model=bq_model
)
if (
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
and "remoteModelInfo" in bq_model._properties
and "endpoint" in bq_model._properties["remoteModelInfo"]
):
# Parse the remote model endpoint
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
endpoint_model = bqml_endpoint.split("/")[-1]
return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore
session=session, model=bq_model
)

raise NotImplementedError(
f"Model type {bq_model.model_type} is not yet supported by BigQuery DataFrames. {constants.FEEDBACK_LINK}"
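With the endpoint mapping in place, read_gbq_model can hand back the LLM wrappers for remote models. A small sketch, assuming a remote model named my_dataset.my_text_model with endpoint text-bison already exists and using bigframes.connect() to obtain a session (any existing Session works):

import bigframes
from bigframes.ml import llm

session = bigframes.connect()
model = session.read_gbq_model("my_dataset.my_text_model")

# The loader inspects the remote model's endpoint and returns the matching
# wrapper class instead of raising NotImplementedError.
assert isinstance(model, llm.PaLM2TextGenerator)
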
50 changes: 48 additions & 2 deletions tests/system/small/ml/test_llm.py
@@ -17,11 +17,37 @@
from bigframes.ml import llm


def test_create_text_generator_model(palm2_text_generator_model):
def test_create_text_generator_model(
palm2_text_generator_model, dataset_id, bq_connection
):
# Model creation doesn't raise an error
assert palm2_text_generator_model is not None
assert palm2_text_generator_model._bqml_model is not None

# Save and reload to ensure the configuration was kept
reloaded_model = palm2_text_generator_model.to_gbq(
f"{dataset_id}.temp_text_model", replace=True
)
assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
assert reloaded_model.model_name == "text-bison"
assert reloaded_model.connection_name == bq_connection


def test_create_text_generator_32k_model(
palm2_text_generator_32k_model, dataset_id, bq_connection
):
# Model creation doesn't raise an error
assert palm2_text_generator_32k_model is not None
assert palm2_text_generator_32k_model._bqml_model is not None

# Save and reload to ensure the configuration was kept
reloaded_model = palm2_text_generator_32k_model.to_gbq(
f"{dataset_id}.temp_text_model", replace=True
)
assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
assert reloaded_model.model_name == "text-bison-32k"
assert reloaded_model.connection_name == bq_connection


@pytest.mark.flaky(retries=2, delay=120)
def test_create_text_generator_model_default_session(
@@ -152,19 +178,39 @@ def test_text_generator_predict_with_params_success(
assert all(series.str.len() > 20)


def test_create_embedding_generator_model(palm2_embedding_generator_model):
def test_create_embedding_generator_model(
palm2_embedding_generator_model, dataset_id, bq_connection
):
# Model creation doesn't raise an error
assert palm2_embedding_generator_model is not None
assert palm2_embedding_generator_model._bqml_model is not None

# Save and reload to ensure the configuration was kept
reloaded_model = palm2_embedding_generator_model.to_gbq(
f"{dataset_id}.temp_embedding_model", replace=True
)
assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name
assert reloaded_model.model_name == "textembedding-gecko"
assert reloaded_model.connection_name == bq_connection


def test_create_embedding_generator_multilingual_model(
palm2_embedding_generator_multilingual_model,
dataset_id,
bq_connection,
):
# Model creation doesn't raise an error
assert palm2_embedding_generator_multilingual_model is not None
assert palm2_embedding_generator_multilingual_model._bqml_model is not None

# Save and reload to ensure the configuration was kept
reloaded_model = palm2_embedding_generator_multilingual_model.to_gbq(
f"{dataset_id}.temp_embedding_model", replace=True
)
assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name
assert reloaded_model.model_name == "textembedding-gecko-multilingual"
assert reloaded_model.connection_name == bq_connection


def test_create_text_embedding_generator_model_defaults(bq_connection):
import bigframes.pandas as bpd