Skip to content

feat: Implement DataFrame.dot for matrix multiplication #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 19, 2023
25 changes: 20 additions & 5 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,13 +1261,28 @@ def pivot(
*,
columns: Sequence[str],
values: Sequence[str],
columns_unique_values: typing.Optional[
typing.Union[pd.Index, Sequence[object]]
] = None,
values_in_index: typing.Optional[bool] = None,
):
# Columns+index should uniquely identify rows
# Warning: This is not validated, breaking this constraint will result in silently non-deterministic behavior.
# -1 to allow for ordering column in addition to pivot columns
max_unique_value = (_BQ_MAX_COLUMNS - 1) // len(values)
columns_values = self._get_unique_values(columns, max_unique_value)
# We need the unique values from the pivot columns to turn them into
# column ids. It can be determined by running a SQL query on the
# underlying data. However, the caller can save that if they know the
# unique values upfront by providing them explicitly.
if columns_unique_values is None:
# Columns+index should uniquely identify rows
# Warning: This is not validated, breaking this constraint will
# result in silently non-deterministic behavior.
# -1 to allow for ordering column in addition to pivot columns
max_unique_value = (_BQ_MAX_COLUMNS - 1) // len(values)
columns_values = self._get_unique_values(columns, max_unique_value)
else:
columns_values = (
columns_unique_values
if isinstance(columns_unique_values, pd.Index)
else pd.Index(columns_unique_values)
)
column_index = columns_values

column_ids: list[str] = []
Expand Down
102 changes: 101 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,10 +1716,13 @@ def kurt(self, *, numeric_only: bool = False):

kurtosis = kurt

def pivot(
def _pivot(
self,
*,
columns: typing.Union[blocks.Label, Sequence[blocks.Label]],
columns_unique_values: typing.Optional[
typing.Union[pandas.Index, Sequence[object]]
] = None,
index: typing.Optional[
typing.Union[blocks.Label, Sequence[blocks.Label]]
] = None,
Expand All @@ -1743,10 +1746,24 @@ def pivot(
pivot_block = block.pivot(
columns=column_ids,
values=value_col_ids,
columns_unique_values=columns_unique_values,
values_in_index=utils.is_list_like(values),
)
return DataFrame(pivot_block)

def pivot(
    self,
    *,
    columns: typing.Union[blocks.Label, Sequence[blocks.Label]],
    index: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
    values: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
) -> DataFrame:
    """Public pivot entry point.

    Delegates to ``_pivot`` without supplying ``columns_unique_values``,
    so the distinct pivot-column values are discovered from the data.
    """
    pivoted = self._pivot(columns=columns, index=index, values=values)
    return pivoted

def stack(self, level: LevelsType = -1):
if not isinstance(self.columns, pandas.MultiIndex):
if level not in [0, -1, self.columns.name]:
Expand Down Expand Up @@ -2578,3 +2595,86 @@ def _get_block(self) -> blocks.Block:

def _cached(self) -> DataFrame:
    """Return a DataFrame wrapping a cached (materialized) copy of this block."""
    cached_block = self._block.cached()
    return DataFrame(cached_block)

# Type variable tying dot()'s return type to its operand: a DataFrame
# operand yields a DataFrame, a Series operand yields a Series.
_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")

def dot(self, other: _DataFrameOrSeries) -> _DataFrameOrSeries:
    """Compute the matrix product of this DataFrame with ``other``.

    Args:
        other (DataFrame or Series):
            The right operand of the matrix multiplication.

    Returns:
        DataFrame or Series: same kind as ``other``; rows are indexed like
        ``self`` and (for a DataFrame operand) columns match ``other``.

    Raises:
        NotImplementedError: If ``other`` is neither a DataFrame nor a
            Series, or if either operand uses a multi-index or
            multi-level columns.
        RuntimeError: If the pivot step produces columns not present in
            the right operand.
    """
    if not isinstance(other, (DataFrame, bf_series.Series)):
        raise NotImplementedError(
            f"Only DataFrame or Series operand is supported. {constants.FEEDBACK_LINK}"
        )

    if len(self.index.names) > 1 or len(other.index.names) > 1:
        raise NotImplementedError(
            f"Multi-index input is not supported. {constants.FEEDBACK_LINK}"
        )

    if len(self.columns.names) > 1 or (
        isinstance(other, DataFrame) and len(other.columns.names) > 1
    ):
        raise NotImplementedError(
            f"Multi-level column input is not supported. {constants.FEEDBACK_LINK}"
        )

    # Convert the dataframes into cell-value-decomposed representation, i.e.
    # each cell value is present in a separate row
    row_id = "row"
    col_id = "col"
    val_id = "val"
    left_suffix = "_left"
    right_suffix = "_right"
    cvd_columns = [row_id, col_id, val_id]

    def get_left_id(id):
        # Label a left-operand column gets after the merge applies suffixes.
        return f"{id}{left_suffix}"

    def get_right_id(id):
        # Label a right-operand column gets after the merge applies suffixes.
        return f"{id}{right_suffix}"

    # Treat a Series operand as a one-column DataFrame so both operand
    # kinds share the same pipeline; the Series case is unwrapped again
    # at the end.
    other_frame = other if isinstance(other, DataFrame) else other.to_frame()

    # stack() + reset_index() turns each (row label, column label, value)
    # cell into its own row of a 3-column frame.
    left = self.stack().reset_index()
    left.columns = cvd_columns

    right = other_frame.stack().reset_index()
    right.columns = cvd_columns

    # Matrix multiplication pairs left columns with right rows.
    merged = left.merge(
        right,
        left_on=col_id,
        right_on=row_id,
        suffixes=(left_suffix, right_suffix),
    )

    left_row_id = get_left_id(row_id)
    right_col_id = get_right_id(col_id)

    # Output cell (i, j) = sum over k of left[i, k] * right[k, j].
    aggregated = (
        merged.assign(
            val=merged[get_left_id(val_id)] * merged[get_right_id(val_id)]
        )[[left_row_id, right_col_id, val_id]]
        .groupby([left_row_id, right_col_id])
        .sum(numeric_only=True)
    )
    aggregated_noindex = aggregated.reset_index()
    aggregated_noindex.columns = cvd_columns
    # Pivot back to matrix shape. Passing columns_unique_values spares
    # _pivot a query to discover the distinct column labels, which are
    # already known from the right operand.
    result = aggregated_noindex._pivot(
        columns=col_id, columns_unique_values=other_frame.columns, index=row_id
    )

    # Set the index names to match the left side matrix
    result.index.names = self.index.names

    # Pivot has the result columns ordered alphabetically. It should still
    # match the columns in the right-side matrix. Let's reorder them as per
    # the right-side matrix.
    if not result.columns.difference(other_frame.columns).empty:
        raise RuntimeError(
            f"Could not construct all columns. {constants.FEEDBACK_LINK}"
        )
    result = result[other_frame.columns]

    # For a Series operand, collapse the single-column result back to a
    # Series; rename() clears the temporary column label.
    if isinstance(other, bf_series.Series):
        result = result[other.name].rename()

    return result
22 changes: 22 additions & 0 deletions tests/data/matrix_2by3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[
{
"mode": "REQUIRED",
"name": "rowindex",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "a",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "b",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "c",
"type": "INTEGER"
}
]
2 changes: 2 additions & 0 deletions tests/data/matrix_2by3.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"rowindex": 0, "a": 1, "b": 2, "c": 3}
{"rowindex": 1, "a": 2, "b": 5, "c": 7}
27 changes: 27 additions & 0 deletions tests/data/matrix_3by4.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[
{
"mode": "REQUIRED",
"name": "rowindex",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "w",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "x",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "y",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "z",
"type": "INTEGER"
}
]
3 changes: 3 additions & 0 deletions tests/data/matrix_3by4.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"rowindex": "a", "w": 2, "x": 4, "y": 8, "z": 21}
{"rowindex": "b", "w": 1, "x": 5, "y": 10, "z": -11}
{"rowindex": "c", "w": 3, "x": 6, "y": 9, "z": 0}
68 changes: 68 additions & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def load_test_data_tables(
("penguins", "penguins_schema.json", "penguins.jsonl"),
("time_series", "time_series_schema.json", "time_series.jsonl"),
("hockey_players", "hockey_players.json", "hockey_players.jsonl"),
("matrix_2by3", "matrix_2by3.json", "matrix_2by3.jsonl"),
("matrix_3by4", "matrix_3by4.json", "matrix_3by4.jsonl"),
]:
test_data_hash = hashlib.md5()
_hash_digest_file(test_data_hash, DATA_DIR / schema_filename)
Expand Down Expand Up @@ -304,6 +306,16 @@ def time_series_table_id(test_data_tables) -> str:
return test_data_tables["time_series"]


@pytest.fixture(scope="session")
def matrix_2by3_table_id(test_data_tables) -> str:
    """BigQuery table id of the loaded 2-by-3 test matrix data."""
    return test_data_tables["matrix_2by3"]


@pytest.fixture(scope="session")
def matrix_3by4_table_id(test_data_tables) -> str:
    """BigQuery table id of the loaded 3-by-4 test matrix data."""
    return test_data_tables["matrix_3by4"]


@pytest.fixture(scope="session")
def scalars_df_default_index(
scalars_df_index: bigframes.dataframe.DataFrame,
Expand Down Expand Up @@ -411,6 +423,62 @@ def hockey_pandas_df() -> pd.DataFrame:
return df


@pytest.fixture(scope="session")
def matrix_2by3_df(
    matrix_2by3_table_id: str, session: bigframes.Session
) -> bigframes.dataframe.DataFrame:
    """DataFrame pointing at a test 2-by-3 matrix data."""
    frame = session.read_gbq(matrix_2by3_table_id)
    return frame.set_index("rowindex").sort_index()


@pytest.fixture(scope="session")
def matrix_2by3_pandas_df() -> pd.DataFrame:
    """pd.DataFrame pointing at a test 2-by-3 matrix data."""
    # Every column of this matrix, including the index, is nullable Int64.
    column_dtypes = {
        name: pd.Int64Dtype() for name in ("rowindex", "a", "b", "c")
    }
    frame = pd.read_json(
        DATA_DIR / "matrix_2by3.jsonl",
        lines=True,
        dtype=column_dtypes,
    )
    frame = frame.set_index("rowindex").sort_index()
    frame.index = frame.index.astype("Int64")
    return frame


@pytest.fixture(scope="session")
def matrix_3by4_df(
    matrix_3by4_table_id: str, session: bigframes.Session
) -> bigframes.dataframe.DataFrame:
    """DataFrame pointing at a test 3-by-4 matrix data."""
    frame = session.read_gbq(matrix_3by4_table_id)
    return frame.set_index("rowindex").sort_index()


@pytest.fixture(scope="session")
def matrix_3by4_pandas_df() -> pd.DataFrame:
    """pd.DataFrame pointing at a test 3-by-4 matrix data."""
    # String index column, nullable Int64 value columns.
    column_dtypes: dict = {name: pd.Int64Dtype() for name in ("w", "x", "y", "z")}
    column_dtypes["rowindex"] = pd.StringDtype(storage="pyarrow")
    frame = pd.read_json(
        DATA_DIR / "matrix_3by4.jsonl",
        lines=True,
        dtype=column_dtypes,
    )
    return frame.set_index("rowindex").sort_index()


@pytest.fixture(scope="session")
def penguins_df_default_index(
penguins_table_id: str, session: bigframes.Session
Expand Down
54 changes: 54 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3167,3 +3167,57 @@ def test_df_cached(scalars_df_index):

df_cached_copy = df._cached()
pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas())


def test_df_dot_inline(session):
    """dot() on small inline matrices matches pandas."""
    left = pd.DataFrame([[1, 2, 3], [2, 5, 7]])
    right = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]])

    bf_left = session.read_pandas(left)
    bf_right = session.read_pandas(right)
    bf_result = bf_left.dot(bf_right).to_pandas()
    pd_result = left.dot(right)

    # Patch pandas dtypes for testing parity: BigFrames produces nullable
    # Int64 values and index where pandas produces plain int64.
    pd_result = pd_result.astype(pd.Int64Dtype())
    pd_result.index = pd_result.index.astype(pd.Int64Dtype())

    pd.testing.assert_frame_equal(bf_result, pd_result)


def test_df_dot(
    matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
):
    """DataFrame @ DataFrame via dot() matches pandas on the test matrices."""
    bf_result = matrix_2by3_df.dot(matrix_3by4_df).to_pandas()
    pd_result = matrix_2by3_pandas_df.dot(matrix_3by4_pandas_df)

    # Patch pandas dtypes for testing parity: pandas yields object dtype
    # here where BigFrames yields nullable Int64.
    pd_result = pd_result.astype(pd.Int64Dtype())

    pd.testing.assert_frame_equal(bf_result, pd_result)


def test_df_dot_series(
    matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
):
    """DataFrame @ Series via dot() matches pandas on the test matrices."""
    bf_result = matrix_2by3_df.dot(matrix_3by4_df["x"]).to_pandas()
    # Patch pandas dtypes for testing parity: pandas yields object dtype
    # here where BigFrames yields nullable Int64.
    pd_result = matrix_2by3_pandas_df.dot(matrix_3by4_pandas_df["x"]).astype(
        pd.Int64Dtype()
    )

    pd.testing.assert_series_equal(bf_result, pd_result)
Loading