Skip to content

feat: bigframes.options and bigframes.option_context now uses thread-local variables to prevent context managers in separate threads from affecting each other #652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 62 additions & 12 deletions bigframes/_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
DataFrames from this package.
"""

import copy
import threading

import bigframes_vendored.pandas._config.config as pandas_config

import bigframes._config.bigquery_options as bigquery_options
Expand All @@ -29,44 +32,91 @@ class Options:
"""Global options affecting BigQuery DataFrames behavior."""

def __init__(self):
self._local = threading.local()

# Initialize these in the property getters to make sure we do have a
# separate instance per thread.
self._local.bigquery_options = None
self._local.display_options = None
self._local.sampling_options = None
self._local.compute_options = None

# BigQuery options are special because they can only be set once per
# session, so we need an indicator as to whether we are using the
# thread-local session or the global session.
self._bigquery_options = bigquery_options.BigQueryOptions()
self._display_options = display_options.DisplayOptions()
self._sampling_options = sampling_options.SamplingOptions()
self._compute_options = compute_options.ComputeOptions()

def _init_bigquery_thread_local(self):
"""Initialize thread-local options, based on current global options."""

# Already thread-local, so don't reset any options that have been set
# already. No locks needed since this only modifies thread-local
# variables.
if self._local.bigquery_options is not None:
return

self._local.bigquery_options = copy.deepcopy(self._bigquery_options)
self._local.bigquery_options._session_started = False

@property
def bigquery(self) -> bigquery_options.BigQueryOptions:
"""Options to use with the BigQuery engine."""
if self._local.bigquery_options is not None:
# The only way we can get here is if someone called
# _init_bigquery_thread_local.
return self._local.bigquery_options

return self._bigquery_options

@property
def display(self) -> display_options.DisplayOptions:
"""Options controlling object representation."""
return self._display_options
if self._local.display_options is None:
self._local.display_options = display_options.DisplayOptions()

return self._local.display_options

@property
def sampling(self) -> sampling_options.SamplingOptions:
"""Options controlling downsampling when downloading data
to memory. The data will be downloaded into memory explicitly
to memory.

The data can be downloaded into memory explicitly
(e.g., to_pandas, to_numpy, values) or implicitly (e.g.,
matplotlib plotting). This option can be overriden by
parameters in specific functions."""
return self._sampling_options
parameters in specific functions.
"""
if self._local.sampling_options is None:
self._local.sampling_options = sampling_options.SamplingOptions()

return self._local.sampling_options

@property
def compute(self) -> compute_options.ComputeOptions:
"""Options controlling object computation."""
return self._compute_options
"""Thread-local options controlling object computation."""
if self._local.compute_options is None:
self._local.compute_options = compute_options.ComputeOptions()

return self._local.compute_options

@property
def is_bigquery_thread_local(self) -> bool:
"""Indicator that we're using a thread-local session.

A thread-local session can be started by using
`with bigframes.option_context("bigquery.some_option", "some-value"):`.
"""
return self._local.bigquery_options is not None


options = Options()
"""Global options for default session."""

option_context = pandas_config.option_context


__all__ = (
"Options",
"options",
"option_context",
)


option_context = pandas_config.option_context
60 changes: 46 additions & 14 deletions bigframes/core/global_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@

_global_session: Optional[bigframes.session.Session] = None
_global_session_lock = threading.Lock()
_global_session_state = threading.local()
_global_session_state.thread_local_session = None


def _try_close_session(session):
"""Try to close the session and warn if couldn't."""
try:
session.close()
except google.auth.exceptions.RefreshError as e:
session_id = session.session_id
location = session._location
project_id = session._project
warnings.warn(
f"Session cleanup failed for session with id: {session_id}, "
f"location: {location}, project: {project_id}",
category=bigframes.exceptions.CleanupFailedWarning,
)
traceback.print_tb(e.__traceback__)


def close_session() -> None:
Expand All @@ -37,24 +55,30 @@ def close_session() -> None:
Returns:
None
"""
global _global_session
global _global_session, _global_session_lock, _global_session_state

if bigframes._config.options.is_bigquery_thread_local:
if _global_session_state.thread_local_session is not None:
_try_close_session(_global_session_state.thread_local_session)
_global_session_state.thread_local_session = None

# Currently using thread-local options, so no global lock needed.
# Don't reset options.bigquery, as that's the responsibility
# of the context manager that started it in the first place. The user
# might have explicitly closed the session in the context manager and
# the thread-locality property needs to be retained.
bigframes._config.options.bigquery._session_started = False

# Don't close the non-thread-local session.
return

with _global_session_lock:
if _global_session is not None:
try:
_global_session.close()
except google.auth.exceptions.RefreshError as e:
session_id = _global_session.session_id
location = _global_session._location
project_id = _global_session._project
warnings.warn(
f"Session cleanup failed for session with id: {session_id}, "
f"location: {location}, project: {project_id}",
category=bigframes.exceptions.CleanupFailedWarning,
)
traceback.print_tb(e.__traceback__)
_try_close_session(_global_session)
_global_session = None

# This should be global, not thread-local because of the if clause
# above.
bigframes._config.options.bigquery._session_started = False


Expand All @@ -63,7 +87,15 @@ def get_global_session():

Creates the global session if it does not exist.
"""
global _global_session, _global_session_lock
global _global_session, _global_session_lock, _global_session_state

if bigframes._config.options.is_bigquery_thread_local:
if _global_session_state.thread_local_session is None:
_global_session_state.thread_local_session = bigframes.session.connect(
bigframes._config.options.bigquery
)

return _global_session_state.thread_local_session

with _global_session_lock:
if _global_session is None:
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __repr__(self) -> str:
opts = bigframes.options.display
max_results = opts.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job(self.query_job)
return formatter.repr_query_job(self._block._compute_dry_run())

pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results)
self._query_job = query_job
Expand Down
6 changes: 3 additions & 3 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def __repr__(self) -> str:
opts = bigframes.options.display
max_results = opts.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job(self.query_job)
return formatter.repr_query_job(self._compute_dry_run())

self._cached()
# TODO(swast): pass max_columns and get the true column count back. Maybe
Expand Down Expand Up @@ -632,9 +632,9 @@ def _repr_html_(self) -> str:
many notebooks are not configured for large tables.
"""
opts = bigframes.options.display
max_results = bigframes.options.display.max_rows
max_results = opts.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job_html(self.query_job)
return formatter.repr_query_job(self._compute_dry_run())

self._cached()
# TODO(swast): pass max_columns and get the true column count back. Maybe
Expand Down
2 changes: 1 addition & 1 deletion bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __repr__(self) -> str:
opts = bigframes.options.display
max_results = opts.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job(self.query_job)
return formatter.repr_query_job(self._compute_dry_run())

self._cached()
pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results)
Expand Down
120 changes: 66 additions & 54 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,28 @@ def test_create_text_generator_model_default_session(
):
import bigframes.pandas as bpd

bpd.close_session()
bpd.options.bigquery.bq_connection = bq_connection
bpd.options.bigquery.location = "us"

model = llm.PaLM2TextGenerator()
assert model is not None
assert model._bqml_model is not None
assert (
model.connection_name.casefold()
== f"{bigquery_client.project}.us.bigframes-rf-conn"
)

llm_text_df = bpd.read_pandas(llm_text_pandas_df)

df = model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() > 20)
# Note: This starts a thread-local session.
with bpd.option_context(
"bigquery.bq_connection",
bq_connection,
"bigquery.location",
"US",
):
model = llm.PaLM2TextGenerator()
assert model is not None
assert model._bqml_model is not None
assert (
model.connection_name.casefold()
== f"{bigquery_client.project}.us.bigframes-rf-conn"
)

llm_text_df = bpd.read_pandas(llm_text_pandas_df)

df = model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2)
Expand All @@ -82,25 +85,28 @@ def test_create_text_generator_32k_model_default_session(
):
import bigframes.pandas as bpd

bpd.close_session()
bpd.options.bigquery.bq_connection = bq_connection
bpd.options.bigquery.location = "us"

model = llm.PaLM2TextGenerator(model_name="text-bison-32k")
assert model is not None
assert model._bqml_model is not None
assert (
model.connection_name.casefold()
== f"{bigquery_client.project}.us.bigframes-rf-conn"
)

llm_text_df = bpd.read_pandas(llm_text_pandas_df)

df = model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() > 20)
# Note: This starts a thread-local session.
with bpd.option_context(
"bigquery.bq_connection",
bq_connection,
"bigquery.location",
"US",
):
model = llm.PaLM2TextGenerator(model_name="text-bison-32k")
assert model is not None
assert model._bqml_model is not None
assert (
model.connection_name.casefold()
== f"{bigquery_client.project}.us.bigframes-rf-conn"
)

llm_text_df = bpd.read_pandas(llm_text_pandas_df)

df = model.predict(llm_text_df).to_pandas()
assert df.shape == (3, 4)
assert "ml_generate_text_llm_result" in df.columns
series = df["ml_generate_text_llm_result"]
assert all(series.str.len() > 20)


@pytest.mark.flaky(retries=2)
Expand Down Expand Up @@ -232,27 +238,33 @@ def test_create_embedding_generator_multilingual_model(
def test_create_text_embedding_generator_model_defaults(bq_connection):
import bigframes.pandas as bpd

bpd.close_session()
bpd.options.bigquery.bq_connection = bq_connection
bpd.options.bigquery.location = "us"

model = llm.PaLM2TextEmbeddingGenerator()
assert model is not None
assert model._bqml_model is not None
# Note: This starts a thread-local session.
with bpd.option_context(
"bigquery.bq_connection",
bq_connection,
"bigquery.location",
"US",
):
model = llm.PaLM2TextEmbeddingGenerator()
assert model is not None
assert model._bqml_model is not None


def test_create_text_embedding_generator_multilingual_model_defaults(bq_connection):
import bigframes.pandas as bpd

bpd.close_session()
bpd.options.bigquery.bq_connection = bq_connection
bpd.options.bigquery.location = "us"

model = llm.PaLM2TextEmbeddingGenerator(
model_name="textembedding-gecko-multilingual"
)
assert model is not None
assert model._bqml_model is not None
# Note: This starts a thread-local session.
with bpd.option_context(
"bigquery.bq_connection",
bq_connection,
"bigquery.location",
"US",
):
model = llm.PaLM2TextEmbeddingGenerator(
model_name="textembedding-gecko-multilingual"
)
assert model is not None
assert model._bqml_model is not None


@pytest.mark.flaky(retries=2)
Expand Down
Loading