Skip to content

fix: sampling plot cannot preserve ordering if index is not ordered #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import itertools
import random
import typing
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple
import warnings

import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -555,7 +555,7 @@ def _downsample(
block = self._split(
fracs=(fraction,),
random_state=random_state,
preserve_order=True,
sort=False,
)[0]
return block
else:
Expand All @@ -571,7 +571,7 @@ def _split(
fracs: Iterable[float] = (),
*,
random_state: Optional[int] = None,
preserve_order: Optional[bool] = False,
sort: Optional[bool | Literal["random"]] = "random",
) -> List[Block]:
"""Internal function to support splitting Block to multiple parts along index axis.

Expand Down Expand Up @@ -623,7 +623,18 @@ def _split(
typing.cast(Block, block.slice(start=lower, stop=upper))
for lower, upper in intervals
]
if preserve_order:

if sort is True:
sliced_blocks = [
sliced_block.order_by(
[
ordering.OrderingColumnReference(idx_col)
for idx_col in sliced_block.index_columns
]
)
for sliced_block in sliced_blocks
]
elif sort is False:
sliced_blocks = [
sliced_block.order_by([ordering.OrderingColumnReference(ordering_col)])
for sliced_block in sliced_blocks
Expand Down
5 changes: 4 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,14 +2504,17 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
) -> DataFrame:
if n is not None and frac is not None:
raise ValueError("Only one of 'n' or 'frac' parameter can be specified.")

ns = (n,) if n is not None else ()
fracs = (frac,) if frac is not None else ()
return DataFrame(
self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0]
self._block._split(
ns=ns, fracs=fracs, random_state=random_state, sort=sort
)[0]
)

def _split(
Expand Down
10 changes: 5 additions & 5 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def _compute_plot_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", 100)
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
return (
data.sample(n=sampling_n, random_state=sampling_random_state)
.to_pandas()
.sort_index()
)
return data.sample(
n=sampling_n,
random_state=sampling_random_state,
sort=False,
).to_pandas()


class LinePlot(SamplingPlot):
Expand Down
7 changes: 5 additions & 2 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import textwrap
import typing
from typing import Any, Mapping, Optional, Tuple, Union
from typing import Any, Literal, Mapping, Optional, Tuple, Union

import bigframes_vendored.pandas.core.series as vendored_pandas_series
import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -1535,14 +1535,17 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
) -> Series:
if n is not None and frac is not None:
raise ValueError("Only one of 'n' or 'frac' parameter can be specified.")

ns = (n,) if n is not None else ()
fracs = (frac,) if frac is not None else ()
return Series(
self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0]
self._block._split(
ns=ns, fracs=fracs, random_state=random_state, sort=sort
)[0]
)

def __array_ufunc__(
Expand Down
15 changes: 14 additions & 1 deletion tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pandas as pd
import pandas._testing as tm
import pytest

Expand Down Expand Up @@ -235,6 +236,18 @@ def test_sampling_plot_args_random_state():
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])


def test_sampling_preserve_ordering():
df = bpd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0])
pd_df = pd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0])
ax = df.plot.line()
pd_ax = pd_df.plot.line()
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
for line, pd_line in zip(ax.lines, pd_ax.lines):
# Compare y coordinates between the lines
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


@pytest.mark.parametrize(
("kind", "col_names", "kwargs"),
[
Expand All @@ -251,7 +264,7 @@ def test_sampling_plot_args_random_state():
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
"uknown",
"bar",
["int64_col", "int64_too"],
{},
marks=pytest.mark.xfail(raises=NotImplementedError),
Expand Down
22 changes: 22 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3049,6 +3049,28 @@ def test_sample_raises_value_error(scalars_dfs):
scalars_df.sample(frac=0.5, n=4)


def test_sample_args_sort(scalars_dfs):
scalars_df, _ = scalars_dfs
index = [4, 3, 2, 5, 1, 0]
scalars_df = scalars_df.iloc[index]

kwargs = {"frac": 1.0, "random_state": 333}

df = scalars_df.sample(**kwargs).to_pandas()
assert df.index.values != index
assert df.index.values != sorted(index)

df = scalars_df.sample(sort="random", **kwargs).to_pandas()
assert df.index.values != index
assert df.index.values != sorted(index)

df = scalars_df.sample(sort=True, **kwargs).to_pandas()
assert df.index.values == sorted(index)

df = scalars_df.sample(sort=False, **kwargs).to_pandas()
assert df.index.values == index


@pytest.mark.parametrize(
("axis",),
[
Expand Down
7 changes: 7 additions & 0 deletions third_party/bigframes_vendored/pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
):
"""Return a random sample of items from an axis of object.

Expand Down Expand Up @@ -530,6 +531,12 @@ def sample(
Fraction of axis items to return. Cannot be used with `n`.
random_state (Optional[int], default None):
Seed for random number generator.
sort (Optional[bool|Literal["random"]], default "random"):

- 'random' (default): No specific ordering will be applied after
sampling.
- 'True' : Index columns will determine the sample's order.
- 'False': The sample will retain the original object's order.

Returns:
A new object of same type as caller containing `n` items randomly
Expand Down