Skip to content

feat: Support CMEK for BQ tables #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
144190e
feat: Support CMEK for BQ tables
shobsi Feb 29, 2024
eda994e
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Feb 29, 2024
6b41ec0
add more tests
shobsi Mar 1, 2024
d2f3f3d
add unit tests
shobsi Mar 1, 2024
425cf12
add more tests, fix broken tests
shobsi Mar 1, 2024
2443d40
separate bqml client to send kms_key_name via OPTIONS instead of job
shobsi Mar 4, 2024
ca50e6c
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 4, 2024
07d97f4
fix unit tests
shobsi Mar 4, 2024
950dd27
fix mypy
shobsi Mar 4, 2024
ee717fe
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 4, 2024
927d8a2
skip cmek test for empty cmek
shobsi Mar 4, 2024
827bef2
move staticmethods to helper module
shobsi Mar 5, 2024
e1b3258
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 6, 2024
f1c8b00
revert bqmlclient, pass cmek through call time job config
shobsi Mar 7, 2024
28064f8
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 7, 2024
073f989
revert bqmlclient unit test
shobsi Mar 7, 2024
79a4b73
fix mypy failure
shobsi Mar 7, 2024
313bb12
use better named key, disable use_query_cache in test
shobsi Mar 7, 2024
9c8d064
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 7, 2024
4d4bb10
rename bqml create model internal method
shobsi Mar 8, 2024
0eff3a6
fix renamed methods's reference in unit tests
shobsi Mar 8, 2024
185edd9
remove stray bqmlclient variable
shobsi Mar 8, 2024
07dca56
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
86ebba7
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
2ae5361
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
separate bqml client to send kms_key_name via OPTIONS instead of job
config
  • Loading branch information
shobsi committed Mar 4, 2024
commit 2443d4029a467fb98a5d3cdb5e2a567b16fcf69a
13 changes: 10 additions & 3 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ def principal_component_info(self) -> bpd.DataFrame:
return self._session.read_gbq(sql)

def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
job_config = bigquery.job.CopyJobConfig()
job_config = bigquery.job.CopyJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._session._bq_kms_key_name
)
)
if replace:
job_config.write_disposition = "WRITE_TRUNCATE"

Expand All @@ -236,7 +240,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait it to finish
self._session._start_query(sql)
self._session._start_query_bqml(sql)

self._model = self._session.bqclient.get_model(self.model_name)
return self
Expand All @@ -255,7 +259,7 @@ def _create_model_ref(

def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel:
# fit the model, synchronously
_, job = session._start_query(sql)
_, job = session._start_query_bqml(sql)

# real model path in the session specific hidden dataset and table prefix
model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}"
Expand Down Expand Up @@ -298,6 +302,9 @@ def create_model(
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session
if session._bq_kms_key_name:
options.update({"kms_key_name": session._bq_kms_key_name})

model_ref = self._create_model_ref(session._anonymous_dataset)

sql = self._model_creation_sql_generator.create_model(
Expand Down
43 changes: 36 additions & 7 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ def __init__(
def bqclient(self):
return self._clients_provider.bqclient

@property
def bqmlclient(self):
return self._clients_provider.bqmlclient

@property
def bqconnectionclient(self):
return self._clients_provider.bqconnectionclient
Expand Down Expand Up @@ -1509,23 +1513,23 @@ def read_gbq_function(
session=self,
)

def _start_query(
self,
def _start_query_with_client(
bq_client: bigquery.Client,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts query job and waits for results.
"""
job_config = self._prepare_job_config(job_config)
job_config = Session._prepare_job_config(job_config)
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
)

try:
query_job = self.bqclient.query(sql, job_config=job_config)
query_job = bq_client.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
Expand All @@ -1540,6 +1544,32 @@ def _start_query(
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job

def _start_query(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts BigQuery query job and waits for results.
"""
return Session._start_query_with_client(
self.bqclient, sql, job_config, max_results
)

def _start_query_bqml(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts BigQuery ML query job and waits for results.
"""
return Session._start_query_with_client(
self.bqmlclient, sql, job_config, max_results
)

def _cache_with_cluster_cols(
self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str]
) -> core.ArrayValue:
Expand Down Expand Up @@ -1681,11 +1711,10 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
else:
job.result()

@staticmethod
def _prepare_job_config(
self, job_config: Optional[bigquery.QueryJobConfig] = None
job_config: Optional[bigquery.QueryJobConfig] = None,
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = self.bqclient.default_query_job_config
if job_config is None:
job_config = bigquery.QueryJobConfig()
if bigframes.options.compute.maximum_bytes_billed is not None:
Expand Down
60 changes: 35 additions & 25 deletions bigframes/session/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,51 +103,61 @@ def __init__(

# cloud clients initialized for lazy load
self._bqclient = None
self._bqmlclient = None
self._bqconnectionclient = None
self._bqstoragereadclient = None
self._cloudfunctionsclient = None
self._resourcemanagerclient = None

def _create_bigquery_client(self):
bq_options = None
if self._use_regional_endpoints:
bq_options = google.api_core.client_options.ClientOptions(
api_endpoint=(
_BIGQUERY_REGIONAL_ENDPOINT
if self._location.lower() in _REP_SUPPORTED_REGIONS
else _BIGQUERY_LOCATIONAL_ENDPOINT
).format(location=self._location),
)
bq_info = google.api_core.client_info.ClientInfo(
user_agent=self._application_name
)

bq_client = bigquery.Client(
client_info=bq_info,
client_options=bq_options,
credentials=self._credentials,
project=self._project,
location=self._location,
)

return bq_client

@property
def bqclient(self):
if not self._bqclient:
bq_options = None
if self._use_regional_endpoints:
bq_options = google.api_core.client_options.ClientOptions(
api_endpoint=(
_BIGQUERY_REGIONAL_ENDPOINT
if self._location.lower() in _REP_SUPPORTED_REGIONS
else _BIGQUERY_LOCATIONAL_ENDPOINT
).format(location=self._location),
)
bq_info = google.api_core.client_info.ClientInfo(
user_agent=self._application_name
)
default_query_job_config = None
default_load_job_config = None
self._bqclient = self._create_bigquery_client()
if self._bq_kms_key_name:
default_query_job_config = bigquery.QueryJobConfig(
self._bqclient.default_query_job_config = bigquery.QueryJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._bq_kms_key_name
)
)
default_load_job_config = bigquery.LoadJobConfig(
self._bqclient.default_load_job_config = bigquery.LoadJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._bq_kms_key_name
)
)
self._bqclient = bigquery.Client(
client_info=bq_info,
client_options=bq_options,
credentials=self._credentials,
project=self._project,
location=self._location,
default_query_job_config=default_query_job_config,
default_load_job_config=default_load_job_config,
)

return self._bqclient

@property
def bqmlclient(self):
if not self._bqmlclient:
self._bqmlclient = self._create_bigquery_client()

return self._bqmlclient

@property
def bqconnectionclient(self):
if not self._bqconnectionclient:
Expand Down
43 changes: 42 additions & 1 deletion tests/system/small/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

import bigframes
import bigframes.ml.linear_model


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -116,7 +117,7 @@ def test_session_load_job(bq_cmek, session_with_bq_cmek):
bq_cmek
)

# The result table should exist with the intended encryption
# The load destination table should be created with the intended encryption
table = session_with_bq_cmek.bqclient.get_table(load_job.destination)
assert table.encryption_configuration.kms_key_name == bq_cmek

Expand Down Expand Up @@ -221,3 +222,43 @@ def test_read_pandas_large(bq_cmek, session_with_bq_cmek):

# Assert encryption
_assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek)


def test_bqml(bq_cmek, session_with_bq_cmek, penguins_table_id):
model = bigframes.ml.linear_model.LinearRegression()

df = session_with_bq_cmek.read_gbq(penguins_table_id).dropna()
X_train = df[
[
"species",
"island",
"culmen_length_mm",
"culmen_depth_mm",
"flipper_length_mm",
"sex",
]
]
y_train = df[["body_mass_g"]]
model.fit(X_train, y_train)

assert model is not None
assert model._bqml_model.model.encryption_configuration is not None
assert model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek

# Assert that model exists in BQ with intended encryption
model_bq = session_with_bq_cmek.bqclient.get_model(model._bqml_model.model_name)
assert model_bq.encryption_configuration.kms_key_name == bq_cmek

# Explicitly save the model to a destination and assert that encryption holds
model_ref = model._bqml_model_factory._create_model_ref(
session_with_bq_cmek._anonymous_dataset
)
model_ref_full_name = (
f"{model_ref.project}.{model_ref.dataset_id}.{model_ref.model_id}"
)
new_model = model.to_gbq(model_ref_full_name)
assert new_model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek

# Assert that model exists in BQ with intended encryption
model_bq = session_with_bq_cmek.bqclient.get_model(new_model._bqml_model.model_name)
assert model_bq.encryption_configuration.kms_key_name == bq_cmek
1 change: 1 addition & 0 deletions tests/unit/session/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def assert_clients_w_user_agent(
provider: clients.ClientsProvider, expected_user_agent: str
):
assert_constructed_w_user_agent(provider.bqclient, expected_user_agent)
assert_constructed_w_user_agent(provider.bqmlclient, expected_user_agent)
assert_constructed_w_user_agent(provider.bqconnectionclient, expected_user_agent)
assert_constructed_w_user_agent(provider.bqstoragereadclient, expected_user_agent)
assert_constructed_w_user_agent(provider.cloudfunctionsclient, expected_user_agent)
Expand Down