Skip to content

feat: support comparison, ordering, and filtering for timedeltas #1387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 13, 2025
55 changes: 33 additions & 22 deletions bigframes/core/rewrite/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import functools
import typing
Expand All @@ -27,6 +29,14 @@ class _TypedExpr:
expr: ex.Expression
dtype: dtypes.Dtype

@classmethod
def create_op_expr(
    cls, op: typing.Union[ops.ScalarOp, ops.RowOp], *inputs: _TypedExpr
) -> _TypedExpr:
    """Apply ``op`` to ``inputs`` and return the typed result.

    The resulting expression comes from ``op.as_expr`` called with the
    expression of every input, and the resulting dtype comes from
    ``op.output_type`` called with the dtype of every input, both in
    argument order.
    """
    input_exprs = [operand.expr for operand in inputs]
    input_dtypes = [operand.dtype for operand in inputs]
    return cls(
        op.as_expr(*input_exprs),  # type: ignore
        op.output_type(*input_dtypes),
    )


def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
"""
Expand All @@ -38,12 +48,27 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
(_rewrite_expressions(expr, root.schema).expr, column_id)
for expr, column_id in root.assignments
)
root = nodes.ProjectionNode(root.child, updated_assignments)
return nodes.ProjectionNode(root.child, updated_assignments)

if isinstance(root, nodes.FilterNode):
return nodes.FilterNode(
root.child, _rewrite_expressions(root.predicate, root.schema).expr
)

if isinstance(root, nodes.OrderByNode):
by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by)
return nodes.OrderByNode(root.child, by)

# TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too.
return root


def _rewrite_ordering_expr(
    expr: nodes.OrderingExpression, schema: schema.ArraySchema
) -> nodes.OrderingExpression:
    """Return a copy of ``expr`` whose scalar expression has been rewritten.

    Only the scalar expression is passed through ``_rewrite_expressions``;
    the ``direction`` and ``na_last`` settings are carried over unchanged.
    """
    rewritten = _rewrite_expressions(expr.scalar_expression, schema)
    return nodes.OrderingExpression(rewritten.expr, expr.direction, expr.na_last)


@functools.cache
def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr:
if isinstance(expr, ex.DerefOp):
Expand Down Expand Up @@ -78,37 +103,23 @@ def _rewrite_op_expr(
if isinstance(expr.op, ops.AddOp):
return _rewrite_add_op(inputs[0], inputs[1])

input_types = tuple(map(lambda x: x.dtype, inputs))
return _TypedExpr(expr, expr.op.output_type(*input_types))
return _TypedExpr.create_op_expr(expr.op, *inputs)


def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
    """Rewrite a subtraction according to the operand dtypes.

    Args:
        left: The typed left operand.
        right: The typed right operand.

    Returns:
        A typed expression using ``ops.timestamp_diff_op`` when both operands
        are datetime-like, and the plain ``ops.sub_op`` otherwise.
    """
    if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
        return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right)

    return _TypedExpr.create_op_expr(ops.sub_op, left, right)


def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
    """Rewrite an addition according to the operand dtypes.

    Args:
        left: The typed left operand.
        right: The typed right operand.

    Returns:
        A typed expression using ``ops.timestamp_add_op`` when one operand is
        datetime-like and the other is a timedelta (in either order), and the
        plain ``ops.add_op`` otherwise.
    """
    if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
        return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right)

    if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
        # Re-arrange operands such that timestamp is always on the left and timedelta is
        # always on the right.
        return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)

    return _TypedExpr.create_op_expr(ops.add_op, left, right)
2 changes: 1 addition & 1 deletion bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def is_comparable(type_: ExpressionType) -> bool:

def is_orderable(type_: ExpressionType) -> bool:
    """Return whether values of ``type_`` can be ordered.

    A type is orderable when it is one of ``_ORDERABLE_SIMPLE_TYPES`` or when
    it is the timedelta dtype.
    """
    # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable
    return type_ in _ORDERABLE_SIMPLE_TYPES or type_ is TIMEDELTA_DTYPE


_CLUSTERABLE_SIMPLE_TYPES = set(
Expand Down
132 changes: 123 additions & 9 deletions tests/system/small/operations/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import datetime
import operator

import numpy as np
import pandas as pd
Expand All @@ -28,12 +29,23 @@ def temporal_dfs(session):
"datetime_col": [
pd.Timestamp("2025-02-01 01:00:01"),
pd.Timestamp("2019-01-02 02:00:00"),
pd.Timestamp("1997-01-01 19:00:00"),
],
"timestamp_col": [
pd.Timestamp("2023-01-01 01:00:01", tz="UTC"),
pd.Timestamp("2024-01-02 02:00:00", tz="UTC"),
pd.Timestamp("2005-03-05 02:00:00", tz="UTC"),
],
"timedelta_col_1": [
pd.Timedelta(3, "s"),
pd.Timedelta(-4, "d"),
pd.Timedelta(5, "h"),
],
"timedelta_col_2": [
pd.Timedelta(2, "s"),
pd.Timedelta(-4, "d"),
pd.Timedelta(6, "h"),
],
"timedelta_col": [pd.Timedelta(3, "s"), pd.Timedelta(-4, "d")],
}
)

Expand All @@ -53,10 +65,10 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype)
bf_df, pd_df = temporal_dfs

actual_result = (
(bf_df[column] + bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
(bf_df[column] + bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
)

expected_result = pd_df[column] + pd_df["timedelta_col"]
expected_result = pd_df[column] + pd_df["timedelta_col_1"]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand Down Expand Up @@ -94,10 +106,10 @@ def test_timestamp_add__td_series_plus_ts_series(temporal_dfs, column, pd_dtype)
bf_df, pd_df = temporal_dfs

actual_result = (
(bf_df["timedelta_col"] + bf_df[column]).to_pandas().astype(pd_dtype)
(bf_df["timedelta_col_1"] + bf_df[column]).to_pandas().astype(pd_dtype)
)

expected_result = pd_df["timedelta_col"] + pd_df[column]
expected_result = pd_df["timedelta_col_1"] + pd_df[column]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -120,10 +132,10 @@ def test_timestamp_add__ts_literal_plus_td_series(temporal_dfs):
timestamp = pd.Timestamp("2025-01-01", tz="UTC")

actual_result = (
(timestamp + bf_df["timedelta_col"]).to_pandas().astype("datetime64[ns, UTC]")
(timestamp + bf_df["timedelta_col_1"]).to_pandas().astype("datetime64[ns, UTC]")
)

expected_result = timestamp + pd_df["timedelta_col"]
expected_result = timestamp + pd_df["timedelta_col_1"]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -140,10 +152,10 @@ def test_timestamp_add_with_numpy_op(temporal_dfs, column, pd_dtype):
bf_df, pd_df = temporal_dfs

actual_result = (
np.add(bf_df[column], bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
np.add(bf_df[column], bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
)

expected_result = np.add(pd_df[column], pd_df["timedelta_col"])
expected_result = np.add(pd_df[column], pd_df["timedelta_col_1"])
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -164,3 +176,105 @@ def test_timestamp_add_dataframes(temporal_dfs):
pandas.testing.assert_frame_equal(
actual_result, expected_result, check_index_type=False
)


@pytest.mark.parametrize(
    "compare_func",
    [
        pytest.param(getattr(operator, op_name), id=op_name)
        for op_name in ("gt", "ge", "eq", "ne", "lt", "le")
    ],
)
def test_timedelta_series_comparison(temporal_dfs, compare_func):
    """Comparing two timedelta columns gives the same result as pandas."""
    bf_df, pd_df = temporal_dfs
    left_col, right_col = "timedelta_col_1", "timedelta_col_2"

    bf_compared = compare_func(bf_df[left_col], bf_df[right_col])
    pd_compared = compare_func(pd_df[left_col], pd_df[right_col])

    # The pandas result is cast to the nullable "boolean" dtype before comparing.
    pandas.testing.assert_series_equal(
        bf_compared.to_pandas(),
        pd_compared.astype("boolean"),
        check_index_type=False,
    )


@pytest.mark.parametrize(
    "compare_func",
    [
        pytest.param(getattr(operator, op_name), id=op_name)
        for op_name in ("gt", "ge", "eq", "ne", "lt", "le")
    ],
)
def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func):
    """Comparing a timedelta literal with a timedelta column matches pandas."""
    bf_df, pd_df = temporal_dfs
    literal = pd.Timedelta(3, "s")
    column = "timedelta_col_2"

    # The literal is deliberately the left-hand operand of the comparison.
    bf_compared = compare_func(literal, bf_df[column])
    pd_compared = compare_func(literal, pd_df[column])

    pandas.testing.assert_series_equal(
        bf_compared.to_pandas(),
        pd_compared.astype("boolean"),
        check_index_type=False,
    )


def test_timedelta_filtering(session):
    """Filtering a series by a timedelta comparison gives the same rows as pandas."""
    pd_series = pd.Series(
        [
            pd.Timestamp("2025-01-01 01:00:00"),
            pd.Timestamp("2025-01-01 02:00:00"),
            pd.Timestamp("2025-01-01 03:00:00"),
        ]
    )
    bf_series = session.read_pandas(pd_series)
    timestamp = pd.Timestamp("2025-01-01, 00:00:01")
    threshold = pd.Timedelta(1, "h")

    # Keep only the rows lying more than `threshold` after `timestamp`.
    bf_mask = (bf_series - timestamp) > threshold
    pd_mask = (pd_series - timestamp) > threshold

    actual_result = bf_series[bf_mask].to_pandas().astype("<M8[ns]")
    expected_result = pd_series[pd_mask]

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )


def test_timedelta_ordering(session):
    """Sorting a timestamp difference gives the same order as pandas."""
    col_1 = [
        pd.Timestamp("2025-01-01 01:00:00"),
        pd.Timestamp("2025-01-01 02:00:00"),
        pd.Timestamp("2025-01-01 03:00:00"),
    ]
    col_2 = [
        pd.Timestamp("2025-01-01 01:00:02"),
        pd.Timestamp("2025-01-01 02:00:01"),
        pd.Timestamp("2025-01-01 02:59:59"),
    ]
    pd_df = pd.DataFrame({"col_1": col_1, "col_2": col_2})
    bf_df = session.read_pandas(pd_df)

    bf_diff = bf_df["col_2"] - bf_df["col_1"]
    pd_diff = pd_df["col_2"] - pd_df["col_1"]

    actual_result = bf_diff.sort_values().to_pandas().astype("timedelta64[ns]")
    expected_result = pd_diff.sort_values()

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )