diff --git a/bigframes/_config/__init__.py b/bigframes/_config/__init__.py index bdd7a8f2d6..bf33420e60 100644 --- a/bigframes/_config/__init__.py +++ b/bigframes/_config/__init__.py @@ -17,6 +17,9 @@ DataFrames from this package. """ +import copy +import threading + import bigframes_vendored.pandas._config.config as pandas_config import bigframes._config.bigquery_options as bigquery_options @@ -29,44 +32,91 @@ class Options: """Global options affecting BigQuery DataFrames behavior.""" def __init__(self): + self._local = threading.local() + + # Initialize these in the property getters to make sure we do have a + # separate instance per thread. + self._local.bigquery_options = None + self._local.display_options = None + self._local.sampling_options = None + self._local.compute_options = None + + # BigQuery options are special because they can only be set once per + # session, so we need an indicator as to whether we are using the + # thread-local session or the global session. self._bigquery_options = bigquery_options.BigQueryOptions() - self._display_options = display_options.DisplayOptions() - self._sampling_options = sampling_options.SamplingOptions() - self._compute_options = compute_options.ComputeOptions() + + def _init_bigquery_thread_local(self): + """Initialize thread-local options, based on current global options.""" + + # Already thread-local, so don't reset any options that have been set + # already. No locks needed since this only modifies thread-local + # variables. + if self._local.bigquery_options is not None: + return + + self._local.bigquery_options = copy.deepcopy(self._bigquery_options) + self._local.bigquery_options._session_started = False @property def bigquery(self) -> bigquery_options.BigQueryOptions: """Options to use with the BigQuery engine.""" + if self._local.bigquery_options is not None: + # The only way we can get here is if someone called + # _init_bigquery_thread_local. 
+ return self._local.bigquery_options + return self._bigquery_options @property def display(self) -> display_options.DisplayOptions: """Options controlling object representation.""" - return self._display_options + if self._local.display_options is None: + self._local.display_options = display_options.DisplayOptions() + + return self._local.display_options @property def sampling(self) -> sampling_options.SamplingOptions: """Options controlling downsampling when downloading data - to memory. The data will be downloaded into memory explicitly + to memory. + + The data can be downloaded into memory explicitly (e.g., to_pandas, to_numpy, values) or implicitly (e.g., matplotlib plotting). This option can be overriden by - parameters in specific functions.""" - return self._sampling_options + parameters in specific functions. + """ + if self._local.sampling_options is None: + self._local.sampling_options = sampling_options.SamplingOptions() + + return self._local.sampling_options @property def compute(self) -> compute_options.ComputeOptions: - """Options controlling object computation.""" - return self._compute_options + """Thread-local options controlling object computation.""" + if self._local.compute_options is None: + self._local.compute_options = compute_options.ComputeOptions() + + return self._local.compute_options + + @property + def is_bigquery_thread_local(self) -> bool: + """Indicator that we're using a thread-local session. + + A thread-local session can be started by using + `with bigframes.option_context("bigquery.some_option", "some-value"):`. 
+ """ + return self._local.bigquery_options is not None options = Options() """Global options for default session.""" +option_context = pandas_config.option_context + __all__ = ( "Options", "options", + "option_context", ) - - -option_context = pandas_config.option_context diff --git a/bigframes/core/global_session.py b/bigframes/core/global_session.py index 31dfc9bd17..3187c5c11b 100644 --- a/bigframes/core/global_session.py +++ b/bigframes/core/global_session.py @@ -26,6 +26,24 @@ _global_session: Optional[bigframes.session.Session] = None _global_session_lock = threading.Lock() +_global_session_state = threading.local() +_global_session_state.thread_local_session = None + + +def _try_close_session(session): + """Try to close the session and warn if couldn't.""" + try: + session.close() + except google.auth.exceptions.RefreshError as e: + session_id = session.session_id + location = session._location + project_id = session._project + warnings.warn( + f"Session cleanup failed for session with id: {session_id}, " + f"location: {location}, project: {project_id}", + category=bigframes.exceptions.CleanupFailedWarning, + ) + traceback.print_tb(e.__traceback__) def close_session() -> None: @@ -37,24 +55,30 @@ def close_session() -> None: Returns: None """ - global _global_session + global _global_session, _global_session_lock, _global_session_state + + if bigframes._config.options.is_bigquery_thread_local: + if _global_session_state.thread_local_session is not None: + _try_close_session(_global_session_state.thread_local_session) + _global_session_state.thread_local_session = None + + # Currently using thread-local options, so no global lock needed. + # Don't reset options.bigquery, as that's the responsibility + # of the context manager that started it in the first place. The user + # might have explicitly closed the session in the context manager and + # the thread-locality property needs to be retained. 
+ bigframes._config.options.bigquery._session_started = False + + # Don't close the non-thread-local session. + return with _global_session_lock: if _global_session is not None: - try: - _global_session.close() - except google.auth.exceptions.RefreshError as e: - session_id = _global_session.session_id - location = _global_session._location - project_id = _global_session._project - warnings.warn( - f"Session cleanup failed for session with id: {session_id}, " - f"location: {location}, project: {project_id}", - category=bigframes.exceptions.CleanupFailedWarning, - ) - traceback.print_tb(e.__traceback__) + _try_close_session(_global_session) _global_session = None + # This should be global, not thread-local because of the if clause + # above. bigframes._config.options.bigquery._session_started = False @@ -63,7 +87,15 @@ def get_global_session(): Creates the global session if it does not exist. """ - global _global_session, _global_session_lock + global _global_session, _global_session_lock, _global_session_state + + if bigframes._config.options.is_bigquery_thread_local: + if _global_session_state.thread_local_session is None: + _global_session_state.thread_local_session = bigframes.session.connect( + bigframes._config.options.bigquery + ) + + return _global_session_state.thread_local_session with _global_session_lock: if _global_session is None: diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py index 46a9e30637..569dae4ffc 100644 --- a/bigframes/core/indexes/base.py +++ b/bigframes/core/indexes/base.py @@ -239,7 +239,7 @@ def __repr__(self) -> str: opts = bigframes.options.display max_results = opts.max_rows if opts.repr_mode == "deferred": - return formatter.repr_query_job(self.query_job) + return formatter.repr_query_job(self._block._compute_dry_run()) pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results) self._query_job = query_job diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 
d694216ebe..1f1fb5467f 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -595,7 +595,7 @@ def __repr__(self) -> str: opts = bigframes.options.display max_results = opts.max_rows if opts.repr_mode == "deferred": - return formatter.repr_query_job(self.query_job) + return formatter.repr_query_job(self._compute_dry_run()) self._cached() # TODO(swast): pass max_columns and get the true column count back. Maybe @@ -632,9 +632,9 @@ def _repr_html_(self) -> str: many notebooks are not configured for large tables. """ opts = bigframes.options.display - max_results = bigframes.options.display.max_rows + max_results = opts.max_rows if opts.repr_mode == "deferred": - return formatter.repr_query_job_html(self.query_job) + return formatter.repr_query_job(self._compute_dry_run()) self._cached() # TODO(swast): pass max_columns and get the true column count back. Maybe diff --git a/bigframes/series.py b/bigframes/series.py index 3986d38445..aea3d60ff5 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -282,7 +282,7 @@ def __repr__(self) -> str: opts = bigframes.options.display max_results = opts.max_rows if opts.repr_mode == "deferred": - return formatter.repr_query_job(self.query_job) + return formatter.repr_query_job(self._compute_dry_run()) self._cached() pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 6f6b67597a..8a6874b178 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -55,25 +55,28 @@ def test_create_text_generator_model_default_session( ): import bigframes.pandas as bpd - bpd.close_session() - bpd.options.bigquery.bq_connection = bq_connection - bpd.options.bigquery.location = "us" - - model = llm.PaLM2TextGenerator() - assert model is not None - assert model._bqml_model is not None - assert ( - model.connection_name.casefold() - == f"{bigquery_client.project}.us.bigframes-rf-conn" - 
) - - llm_text_df = bpd.read_pandas(llm_text_pandas_df) - - df = model.predict(llm_text_df).to_pandas() - assert df.shape == (3, 4) - assert "ml_generate_text_llm_result" in df.columns - series = df["ml_generate_text_llm_result"] - assert all(series.str.len() > 20) + # Note: This starts a thread-local session. + with bpd.option_context( + "bigquery.bq_connection", + bq_connection, + "bigquery.location", + "US", + ): + model = llm.PaLM2TextGenerator() + assert model is not None + assert model._bqml_model is not None + assert ( + model.connection_name.casefold() + == f"{bigquery_client.project}.us.bigframes-rf-conn" + ) + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + df = model.predict(llm_text_df).to_pandas() + assert df.shape == (3, 4) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) @pytest.mark.flaky(retries=2) @@ -82,25 +85,28 @@ def test_create_text_generator_32k_model_default_session( ): import bigframes.pandas as bpd - bpd.close_session() - bpd.options.bigquery.bq_connection = bq_connection - bpd.options.bigquery.location = "us" - - model = llm.PaLM2TextGenerator(model_name="text-bison-32k") - assert model is not None - assert model._bqml_model is not None - assert ( - model.connection_name.casefold() - == f"{bigquery_client.project}.us.bigframes-rf-conn" - ) - - llm_text_df = bpd.read_pandas(llm_text_pandas_df) - - df = model.predict(llm_text_df).to_pandas() - assert df.shape == (3, 4) - assert "ml_generate_text_llm_result" in df.columns - series = df["ml_generate_text_llm_result"] - assert all(series.str.len() > 20) + # Note: This starts a thread-local session. 
+ with bpd.option_context( + "bigquery.bq_connection", + bq_connection, + "bigquery.location", + "US", + ): + model = llm.PaLM2TextGenerator(model_name="text-bison-32k") + assert model is not None + assert model._bqml_model is not None + assert ( + model.connection_name.casefold() + == f"{bigquery_client.project}.us.bigframes-rf-conn" + ) + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + df = model.predict(llm_text_df).to_pandas() + assert df.shape == (3, 4) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) @pytest.mark.flaky(retries=2) @@ -232,27 +238,33 @@ def test_create_embedding_generator_multilingual_model( def test_create_text_embedding_generator_model_defaults(bq_connection): import bigframes.pandas as bpd - bpd.close_session() - bpd.options.bigquery.bq_connection = bq_connection - bpd.options.bigquery.location = "us" - - model = llm.PaLM2TextEmbeddingGenerator() - assert model is not None - assert model._bqml_model is not None + # Note: This starts a thread-local session. + with bpd.option_context( + "bigquery.bq_connection", + bq_connection, + "bigquery.location", + "US", + ): + model = llm.PaLM2TextEmbeddingGenerator() + assert model is not None + assert model._bqml_model is not None def test_create_text_embedding_generator_multilingual_model_defaults(bq_connection): import bigframes.pandas as bpd - bpd.close_session() - bpd.options.bigquery.bq_connection = bq_connection - bpd.options.bigquery.location = "us" - - model = llm.PaLM2TextEmbeddingGenerator( - model_name="textembedding-gecko-multilingual" - ) - assert model is not None - assert model._bqml_model is not None + # Note: This starts a thread-local session. 
+ with bpd.option_context( + "bigquery.bq_connection", + bq_connection, + "bigquery.location", + "US", + ): + model = llm.PaLM2TextEmbeddingGenerator( + model_name="textembedding-gecko-multilingual" + ) + assert model is not None + assert model._bqml_model is not None @pytest.mark.flaky(retries=2) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index b428207314..5ed6908640 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -142,18 +142,13 @@ def test_df_construct_from_dict(): def test_df_construct_inline_respects_location(): import bigframes.pandas as bpd - bpd.close_session() - bpd.options.bigquery.location = "europe-west1" + # Note: This starts a thread-local session. + with bpd.option_context("bigquery.location", "europe-west1"): + df = bpd.DataFrame([[1, 2, 3], [4, 5, 6]]) + repr(df) - df = bpd.DataFrame([[1, 2, 3], [4, 5, 6]]) - repr(df) - - table = bpd.get_global_session().bqclient.get_table(df.query_job.destination) - assert table.location == "europe-west1" - - # Reset global session - bpd.close_session() - bpd.options.bigquery.location = "us" + table = bpd.get_global_session().bqclient.get_table(df.query_job.destination) + assert table.location == "europe-west1" def test_get_column(scalars_dfs): diff --git a/tests/system/small/test_pandas_options.py b/tests/system/small/test_pandas_options.py index dd13196981..afb75c65e3 100644 --- a/tests/system/small/test_pandas_options.py +++ b/tests/system/small/test_pandas_options.py @@ -27,8 +27,11 @@ @pytest.fixture(autouse=True) def reset_default_session_and_location(): - bpd.close_session() - bpd.options.bigquery.location = None + # Note: This starts a thread-local session and closes it once the test + # finishes. 
+ with bpd.option_context("bigquery.location", None): + bpd.options.bigquery.location = None + yield @pytest.mark.parametrize( @@ -80,7 +83,9 @@ def test_read_gbq_start_sets_session_location( ): read_method(query) - # Close global session to start over + # Close the global session to start over. + # Note: This is a thread-local operation because of the + # reset_default_session_and_location fixture above. bpd.close_session() # There should still be the previous location set in the bigquery options @@ -289,13 +294,25 @@ def test_credentials_need_reauthentication(monkeypatch): with pytest.raises(google.auth.exceptions.RefreshError): bpd.read_gbq(test_query) - # Now verify that closing the session works and we throw - # the expected warning + # Now verify that closing the session works. We look at the + # thread-local session because of the + # reset_default_session_and_location fixture and because this test mutates + # state that might otherwise be used by tests running in parallel. + assert ( + bigframes.core.global_session._global_session_state.thread_local_session + is not None + ) + with warnings.catch_warnings(record=True) as warned: bpd.close_session() # CleanupFailedWarning: can't clean up + assert len(warned) == 1 assert warned[0].category == bigframes.exceptions.CleanupFailedWarning - assert bigframes.core.global_session._global_session is None + + assert ( + bigframes.core.global_session._global_session_state.thread_local_session + is None + ) # Now verify that use is able to start over df = bpd.read_gbq(test_query) diff --git a/tests/system/small/test_progress_bar.py b/tests/system/small/test_progress_bar.py index 5ccc6db0ac..73a9743e2f 100644 --- a/tests/system/small/test_progress_bar.py +++ b/tests/system/small/test_progress_bar.py @@ -23,33 +23,37 @@ from bigframes.session import MAX_INLINE_DF_BYTES job_load_message_regex = r"\w+ job [\w-]+ is \w+\." +EXPECTED_DRY_RUN_MESSAGE = "Computation deferred. 
Computation will process" def test_progress_bar_dataframe( penguins_df_default_index: bf.dataframe.DataFrame, capsys ): - bf.options.display.progress_bar = "terminal" capsys.readouterr() # clear output - penguins_df_default_index.to_pandas() + + with bf.option_context("display.progress_bar", "terminal"): + penguins_df_default_index.to_pandas() assert_loading_msg_exist(capsys.readouterr().out) assert penguins_df_default_index.query_job is not None def test_progress_bar_series(penguins_df_default_index: bf.dataframe.DataFrame, capsys): - bf.options.display.progress_bar = "terminal" series = penguins_df_default_index["body_mass_g"].head(10) capsys.readouterr() # clear output - series.to_pandas() + + with bf.option_context("display.progress_bar", "terminal"): + series.to_pandas() assert_loading_msg_exist(capsys.readouterr().out) assert series.query_job is not None def test_progress_bar_scalar(penguins_df_default_index: bf.dataframe.DataFrame, capsys): - bf.options.display.progress_bar = "terminal" capsys.readouterr() # clear output - penguins_df_default_index["body_mass_g"].head(10).mean() + + with bf.option_context("display.progress_bar", "terminal"): + penguins_df_default_index["body_mass_g"].head(10).mean() assert_loading_msg_exist(capsys.readouterr().out) @@ -57,10 +61,11 @@ def test_progress_bar_scalar(penguins_df_default_index: bf.dataframe.DataFrame, def test_progress_bar_extract_jobs( penguins_df_default_index: bf.dataframe.DataFrame, gcs_folder, capsys ): - bf.options.display.progress_bar = "terminal" path = gcs_folder + "test_read_csv_progress_bar*.csv" capsys.readouterr() # clear output - penguins_df_default_index.to_csv(path) + + with bf.option_context("display.progress_bar", "terminal"): + penguins_df_default_index.to_csv(path) assert_loading_msg_exist(capsys.readouterr().out) @@ -73,8 +78,9 @@ def test_progress_bar_load_jobs( while len(df) < MAX_INLINE_DF_BYTES: df = pd.DataFrame(np.repeat(df.values, 2, axis=0)) - bf.options.display.progress_bar = 
"terminal" - with tempfile.TemporaryDirectory() as dir: + with bf.option_context( + "display.progress_bar", "terminal" + ), tempfile.TemporaryDirectory() as dir: path = dir + "/test_read_csv_progress_bar*.csv" df.to_csv(path, index=False) capsys.readouterr() # clear output @@ -96,11 +102,12 @@ def assert_loading_msg_exist(capystOut: str, pattern=job_load_message_regex): def test_query_job_repr_html(penguins_df_default_index: bf.dataframe.DataFrame): - bf.options.display.progress_bar = "terminal" - penguins_df_default_index.to_pandas() - query_job_repr = formatting_helpers.repr_query_job_html( - penguins_df_default_index.query_job - ).value + with bf.option_context("display.progress_bar", "terminal"): + penguins_df_default_index.to_pandas() + query_job_repr = formatting_helpers.repr_query_job_html( + penguins_df_default_index.query_job + ).value + string_checks = [ "Job Id", "Destination Table", @@ -126,3 +133,21 @@ def test_query_job_repr(penguins_df_default_index: bf.dataframe.DataFrame): ] for string in string_checks: assert string in query_job_repr + + +def test_query_job_dry_run_dataframe(penguins_df_default_index: bf.dataframe.DataFrame): + with bf.option_context("display.repr_mode", "deferred"): + df_result = repr(penguins_df_default_index) + assert EXPECTED_DRY_RUN_MESSAGE in df_result + + +def test_query_job_dry_run_index(penguins_df_default_index: bf.dataframe.DataFrame): + with bf.option_context("display.repr_mode", "deferred"): + index_result = repr(penguins_df_default_index.index) + assert EXPECTED_DRY_RUN_MESSAGE in index_result + + +def test_query_job_dry_run_series(penguins_df_default_index: bf.dataframe.DataFrame): + with bf.option_context("display.repr_mode", "deferred"): + series_result = repr(penguins_df_default_index["body_mass_g"]) + assert EXPECTED_DRY_RUN_MESSAGE in series_result diff --git a/third_party/bigframes_vendored/pandas/_config/config.py b/third_party/bigframes_vendored/pandas/_config/config.py index 1b73e649c8..13ccfdac89 100644 
--- a/third_party/bigframes_vendored/pandas/_config/config.py +++ b/third_party/bigframes_vendored/pandas/_config/config.py @@ -7,10 +7,17 @@ class option_context(contextlib.ContextDecorator): """ - Context manager to temporarily set options in the `with` statement context. + Context manager to temporarily set thread-local options in the `with` + statement context. You need to invoke as ``option_context(pat, val, [(pat, val), ...])``. + .. note:: + + `"bigquery"` options can't be changed on a running session. Setting any + of these options creates a new thread-local session that only lives for + the lifetime of the context manager. + **Examples:** >>> import bigframes @@ -29,7 +36,11 @@ def __init__(self, *args) -> None: def __enter__(self) -> None: self.undo = [ - (pat, operator.attrgetter(pat)(bigframes.options)) for pat, val in self.ops + (pat, operator.attrgetter(pat)(bigframes.options)) + for pat, _ in self.ops + # Don't try to undo changes to bigquery options. We're starting and + # closing a new thread-local session if those are set. + if not pat.startswith("bigquery.") ] for pat, val in self.ops: @@ -40,7 +51,21 @@ def __exit__(self, *args) -> None: for pat, val in self.undo: self._set_option(pat, val) + # TODO(tswast): What to do if someone nests several context managers + # with separate "bigquery" options? We might need a "stack" of + # sessions if we allow that. + if bigframes.options.is_bigquery_thread_local: + bigframes.close_session() + + # Reset bigquery_options so that we're no longer thread-local. + bigframes.options._local.bigquery_options = None + def _set_option(self, pat, val): root, attr = pat.rsplit(".", 1) + + # We are now using a thread-specific session. + if root == "bigquery": + bigframes.options._init_bigquery_thread_local() + parent = operator.attrgetter(root)(bigframes.options) setattr(parent, attr, val)