Skip to content
55 changes: 33 additions & 22 deletions bigframes/core/rewrite/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import functools
import typing
Expand All @@ -27,6 +29,14 @@ class _TypedExpr:
expr: ex.Expression
dtype: dtypes.Dtype

@classmethod
def create_op_expr(
    cls, op: typing.Union[ops.ScalarOp, ops.RowOp], *inputs: _TypedExpr
) -> _TypedExpr:
    """Build a typed expression that applies ``op`` to ``inputs``.

    The resulting expression comes from ``op.as_expr`` called with the
    operands' expressions, and the resulting dtype from ``op.output_type``
    called with the operands' dtypes, both in the order the operands were
    passed in.
    """
    built_expr = op.as_expr(*(operand.expr for operand in inputs))  # type: ignore
    result_dtype = op.output_type(*(operand.dtype for operand in inputs))
    return cls(built_expr, result_dtype)


def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
"""
Expand All @@ -38,12 +48,27 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
(_rewrite_expressions(expr, root.schema).expr, column_id)
for expr, column_id in root.assignments
)
root = nodes.ProjectionNode(root.child, updated_assignments)
return nodes.ProjectionNode(root.child, updated_assignments)

if isinstance(root, nodes.FilterNode):
return nodes.FilterNode(
root.child, _rewrite_expressions(root.predicate, root.schema).expr
)

if isinstance(root, nodes.OrderByNode):
by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by)
return nodes.OrderByNode(root.child, by)

# TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too.
return root


def _rewrite_ordering_expr(
    expr: nodes.OrderingExpression, schema: schema.ArraySchema
) -> nodes.OrderingExpression:
    """Return an ordering expression whose scalar expression has been rewritten.

    Only ``expr.scalar_expression`` goes through ``_rewrite_expressions``; the
    ``direction`` and ``na_last`` attributes are passed through unchanged.
    """
    rewritten = _rewrite_expressions(expr.scalar_expression, schema)
    return nodes.OrderingExpression(rewritten.expr, expr.direction, expr.na_last)


@functools.cache
def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr:
if isinstance(expr, ex.DerefOp):
Expand Down Expand Up @@ -78,37 +103,23 @@ def _rewrite_op_expr(
if isinstance(expr.op, ops.AddOp):
return _rewrite_add_op(inputs[0], inputs[1])

input_types = tuple(map(lambda x: x.dtype, inputs))
return _TypedExpr(expr, expr.op.output_type(*input_types))
return _TypedExpr.create_op_expr(expr.op, *inputs)


def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
    """Rewrite a subtraction of two typed operands.

    When both operands are datetime-like, the subtraction is expressed with
    ``ops.timestamp_diff_op``; otherwise the plain ``ops.sub_op`` is kept.
    The result's expression and dtype are derived by
    ``_TypedExpr.create_op_expr`` from the chosen op and the operands.
    """
    if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
        return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right)

    return _TypedExpr.create_op_expr(ops.sub_op, left, right)


def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
    """Rewrite an addition of two typed operands.

    A datetime-like operand added to a timedelta operand (in either order) is
    expressed with ``ops.timestamp_add_op``; any other combination keeps the
    plain ``ops.add_op``. The result's expression and dtype are derived by
    ``_TypedExpr.create_op_expr`` from the chosen op and the operands.
    """
    if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
        return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right)

    if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
        # Re-arrange operands such that timestamp is always on the left and timedelta is
        # always on the right.
        return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)

    return _TypedExpr.create_op_expr(ops.add_op, left, right)
2 changes: 1 addition & 1 deletion bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def is_comparable(type_: ExpressionType) -> bool:

def is_orderable(type_: ExpressionType) -> bool:
    """Return whether values of ``type_`` can be ordered.

    True for members of ``_ORDERABLE_SIMPLE_TYPES`` and for
    ``TIMEDELTA_DTYPE``.
    """
    # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable
    return type_ in _ORDERABLE_SIMPLE_TYPES or type_ is TIMEDELTA_DTYPE


_CLUSTERABLE_SIMPLE_TYPES = set(
Expand Down
132 changes: 123 additions & 9 deletions tests/system/small/operations/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import datetime
import operator

import numpy as np
import pandas as pd
Expand All @@ -28,12 +29,23 @@ def temporal_dfs(session):
"datetime_col": [
pd.Timestamp("2025-02-01 01:00:01"),
pd.Timestamp("2019-01-02 02:00:00"),
pd.Timestamp("1997-01-01 19:00:00"),
],
"timestamp_col": [
pd.Timestamp("2023-01-01 01:00:01", tz="UTC"),
pd.Timestamp("2024-01-02 02:00:00", tz="UTC"),
pd.Timestamp("2005-03-05 02:00:00", tz="UTC"),
],
"timedelta_col_1": [
pd.Timedelta(3, "s"),
pd.Timedelta(-4, "d"),
pd.Timedelta(5, "h"),
],
"timedelta_col_2": [
pd.Timedelta(2, "s"),
pd.Timedelta(-4, "d"),
pd.Timedelta(6, "h"),
],
"timedelta_col": [pd.Timedelta(3, "s"), pd.Timedelta(-4, "d")],
}
)

Expand All @@ -53,10 +65,10 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype)
bf_df, pd_df = temporal_dfs

actual_result = (
(bf_df[column] + bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
(bf_df[column] + bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
)

expected_result = pd_df[column] + pd_df["timedelta_col"]
expected_result = pd_df[column] + pd_df["timedelta_col_1"]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand Down Expand Up @@ -94,10 +106,10 @@ def test_timestamp_add__td_series_plus_ts_series(temporal_dfs, column, pd_dtype)
bf_df, pd_df = temporal_dfs

actual_result = (
(bf_df["timedelta_col"] + bf_df[column]).to_pandas().astype(pd_dtype)
(bf_df["timedelta_col_1"] + bf_df[column]).to_pandas().astype(pd_dtype)
)

expected_result = pd_df["timedelta_col"] + pd_df[column]
expected_result = pd_df["timedelta_col_1"] + pd_df[column]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -120,10 +132,10 @@ def test_timestamp_add__ts_literal_plus_td_series(temporal_dfs):
timestamp = pd.Timestamp("2025-01-01", tz="UTC")

actual_result = (
(timestamp + bf_df["timedelta_col"]).to_pandas().astype("datetime64[ns, UTC]")
(timestamp + bf_df["timedelta_col_1"]).to_pandas().astype("datetime64[ns, UTC]")
)

expected_result = timestamp + pd_df["timedelta_col"]
expected_result = timestamp + pd_df["timedelta_col_1"]
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -140,10 +152,10 @@ def test_timestamp_add_with_numpy_op(temporal_dfs, column, pd_dtype):
bf_df, pd_df = temporal_dfs

actual_result = (
np.add(bf_df[column], bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
np.add(bf_df[column], bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
)

expected_result = np.add(pd_df[column], pd_df["timedelta_col"])
expected_result = np.add(pd_df[column], pd_df["timedelta_col_1"])
pandas.testing.assert_series_equal(
actual_result, expected_result, check_index_type=False
)
Expand All @@ -164,3 +176,105 @@ def test_timestamp_add_dataframes(temporal_dfs):
pandas.testing.assert_frame_equal(
actual_result, expected_result, check_index_type=False
)


@pytest.mark.parametrize(
    "compare_func",
    [
        pytest.param(getattr(operator, op_name), id=op_name)
        for op_name in ("gt", "ge", "eq", "ne", "lt", "le")
    ],
)
def test_timedelta_series_comparison(temporal_dfs, compare_func):
    """Comparing two timedelta columns gives the same result as pandas."""
    bf_df, pd_df = temporal_dfs

    bf_left = bf_df["timedelta_col_1"]
    bf_right = bf_df["timedelta_col_2"]
    actual_result = compare_func(bf_left, bf_right).to_pandas()

    pd_left = pd_df["timedelta_col_1"]
    pd_right = pd_df["timedelta_col_2"]
    # The pandas result is cast to the nullable "boolean" dtype before comparing.
    expected_result = compare_func(pd_left, pd_right).astype("boolean")

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )


@pytest.mark.parametrize(
    "compare_func",
    [
        pytest.param(getattr(operator, op_name), id=op_name)
        for op_name in ("gt", "ge", "eq", "ne", "lt", "le")
    ],
)
def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func):
    """Comparing a timedelta literal with a timedelta column matches pandas."""
    bf_df, pd_df = temporal_dfs
    three_seconds = pd.Timedelta(3, "s")

    # The literal is always the left-hand operand of the comparison.
    bf_compared = compare_func(three_seconds, bf_df["timedelta_col_2"])
    actual_result = bf_compared.to_pandas()

    pd_compared = compare_func(three_seconds, pd_df["timedelta_col_2"])
    expected_result = pd_compared.astype("boolean")

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )


def test_timedelta_filtering(session):
    """Filtering on a timedelta comparison keeps the same rows as pandas.

    The filter predicate is ``(series - timestamp) > Timedelta(1 hour)``.
    """
    local_series = pd.Series(
        [
            pd.Timestamp("2025-01-01 01:00:00"),
            pd.Timestamp("2025-01-01 02:00:00"),
            pd.Timestamp("2025-01-01 03:00:00"),
        ]
    )
    remote_series = session.read_pandas(local_series)
    reference = pd.Timestamp("2025-01-01, 00:00:01")
    threshold = pd.Timedelta(1, "h")

    remote_mask = (remote_series - reference) > threshold
    actual_result = remote_series[remote_mask].to_pandas().astype("<M8[ns]")

    local_mask = (local_series - reference) > threshold
    expected_result = local_series[local_mask]

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )


def test_timedelta_ordering(session):
    """Sorting the difference of two timestamp columns matches pandas."""
    col_1_values = [
        pd.Timestamp("2025-01-01 01:00:00"),
        pd.Timestamp("2025-01-01 02:00:00"),
        pd.Timestamp("2025-01-01 03:00:00"),
    ]
    col_2_values = [
        pd.Timestamp("2025-01-01 01:00:02"),
        pd.Timestamp("2025-01-01 02:00:01"),
        pd.Timestamp("2025-01-01 02:59:59"),
    ]
    local_df = pd.DataFrame({"col_1": col_1_values, "col_2": col_2_values})
    remote_df = session.read_pandas(local_df)

    remote_delta = remote_df["col_2"] - remote_df["col_1"]
    actual_result = remote_delta.sort_values().to_pandas().astype("timedelta64[ns]")

    local_delta = local_df["col_2"] - local_df["col_1"]
    expected_result = local_delta.sort_values()

    pandas.testing.assert_series_equal(
        actual_result, expected_result, check_index_type=False
    )