diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 517176da89..af05f4423c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,4 +38,4 @@ repos: rev: v1.1.1 hooks: - id: mypy - additional_dependencies: [types-requests, types-tabulate] + additional_dependencies: [types-requests, types-tabulate, pandas-stubs] diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index fc008f36e5..4b0ac4310c 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -577,7 +577,22 @@ def read_gbq_table( read_gbq_table.__doc__ = inspect.getdoc(bigframes.session.Session.read_gbq_table) +@typing.overload def read_pandas(pandas_dataframe: pandas.DataFrame) -> bigframes.dataframe.DataFrame: + ... + + +@typing.overload +def read_pandas(pandas_dataframe: pandas.Series) -> bigframes.series.Series: + ... + + +@typing.overload +def read_pandas(pandas_dataframe: pandas.Index) -> bigframes.core.indexes.Index: + ... + + +def read_pandas(pandas_dataframe: Union[pandas.DataFrame, pandas.Series, pandas.Index]): return global_session.with_default_session( bigframes.session.Session.read_pandas, pandas_dataframe, diff --git a/bigframes/series.py b/bigframes/series.py index e7b358c2fe..7e2b0408b7 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -1514,7 +1514,7 @@ def map( map_df = map_df.rename(columns={arg.name: self.name}) elif isinstance(arg, Mapping): map_df = bigframes.dataframe.DataFrame( - {"keys": list(arg.keys()), self.name: list(arg.values())}, + {"keys": list(arg.keys()), self.name: list(arg.values())}, # type: ignore session=self._get_block().expr.session, ) map_df = map_df.set_index("keys") diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index ac266da3bd..c7605e89d7 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -95,7 +95,9 @@ # Avoid circular imports. if typing.TYPE_CHECKING: + import bigframes.core.indexes import bigframes.dataframe as dataframe + import bigframes.series _BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" @@ -963,7 +965,23 @@ def read_gbq_model(self, model_name: str): model = self.bqclient.get_model(model_ref) return bigframes.ml.loader.from_bq(self, model) + @typing.overload + def read_pandas( + self, pandas_dataframe: pandas.Index + ) -> bigframes.core.indexes.Index: + ... + + @typing.overload + def read_pandas(self, pandas_dataframe: pandas.Series) -> bigframes.series.Series: + ... + + @typing.overload def read_pandas(self, pandas_dataframe: pandas.DataFrame) -> dataframe.DataFrame: + ... + + def read_pandas( + self, pandas_dataframe: Union[pandas.DataFrame, pandas.Series, pandas.Index] + ): """Loads DataFrame from a pandas DataFrame. The pandas DataFrame will be persisted as a temporary BigQuery table, which can be @@ -986,13 +1004,31 @@ def read_pandas(self, pandas_dataframe: pandas.DataFrame) -> dataframe.DataFrame [2 rows x 2 columns] Args: - pandas_dataframe (pandas.DataFrame): - a pandas DataFrame object to be loaded. + pandas_dataframe (pandas.DataFrame, pandas.Series, or pandas.Index): + a pandas DataFrame/Series/Index object to be loaded. Returns: - bigframes.dataframe.DataFrame: The BigQuery DataFrame. + An equivalent bigframes.pandas.(DataFrame/Series/Index) object """ - return self._read_pandas(pandas_dataframe, "read_pandas") + import bigframes.series as series + + # Try to handle non-dataframe pandas objects as well + if isinstance(pandas_dataframe, pandas.Series): + bf_df = self._read_pandas(pandas.DataFrame(pandas_dataframe), "read_pandas") + bf_series = typing.cast(series.Series, bf_df[bf_df.columns[0]]) + # wrapping into df can set name to 0 so reset to original object name + bf_series.name = pandas_dataframe.name + return bf_series + if isinstance(pandas_dataframe, pandas.Index): + return self._read_pandas( + pandas.DataFrame(index=pandas_dataframe), "read_pandas" + ).index + if isinstance(pandas_dataframe, pandas.DataFrame): + return self._read_pandas(pandas_dataframe, "read_pandas") + else: + raise ValueError( + f"read_pandas() expects a pandas.DataFrame, but got a {type(pandas_dataframe)}" + ) def _read_pandas( self, pandas_dataframe: pandas.DataFrame, api_name: str diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 28a3f03860..eb6a0a8dd9 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -421,6 +421,21 @@ def test_read_pandas(session, scalars_dfs): pd.testing.assert_frame_equal(result, expected) +def test_read_pandas_series(session): + idx = pd.Index([2, 7, 1, 2, 8], dtype=pd.Int64Dtype()) + pd_series = pd.Series([3, 1, 4, 1, 5], dtype=pd.Int64Dtype(), index=idx) + bf_series = session.read_pandas(pd_series) + + pd.testing.assert_series_equal(bf_series.to_pandas(), pd_series) + + +def test_read_pandas_index(session): + pd_idx = pd.Index([2, 7, 1, 2, 8], dtype=pd.Int64Dtype()) + bf_idx = session.read_pandas(pd_idx) + + pd.testing.assert_index_equal(bf_idx.to_pandas(), pd_idx) + + def test_read_pandas_inline_respects_location(): options = bigframes.BigQueryOptions(location="europe-west1") session = bigframes.Session(options)