Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions bigframes/core/compile/polars/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses

from bigframes import dtypes
from bigframes.core import bigframe_node, expression
from bigframes.core.rewrite import op_lowering
from bigframes.operations import numeric_ops
from bigframes.operations import comparison_ops, numeric_ops
import bigframes.operations as ops

# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)


@dataclasses.dataclass
class CoerceArgsRule(op_lowering.OpLoweringRule):
op_type: type[ops.BinaryOp]

@property
def op(self) -> type[ops.ScalarOp]:
return self.op_type

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, self.op_type)
larg, rarg = _coerce_comparables(expr.children[0], expr.children[1])
return expr.op.as_expr(larg, rarg)


class LowerFloorDivRule(op_lowering.OpLoweringRule):
@property
def op(self) -> type[ops.ScalarOp]:
Expand All @@ -40,7 +56,42 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
return ops.where_op.as_expr(zero_result, divisor_is_zero, expr)


POLARS_LOWERING_RULES = (LowerFloorDivRule(),)
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):

target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
if expr1.output_type != target_type:
expr1 = _lower_cast(ops.AsTypeOp(target_type), expr1)
if expr2.output_type != target_type:
expr2 = _lower_cast(ops.AsTypeOp(target_type), expr2)
return expr1, expr2


# TODO: Need to handle bool->string cast to get capitalization correct
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type):
# bool -> decimal needs two-step cast
new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
return cast_op.as_expr(new_arg)
return cast_op.as_expr(arg)


LOWER_COMPARISONS = tuple(
CoerceArgsRule(op)
for op in (
comparison_ops.EqOp,
comparison_ops.EqNullsMatchOp,
comparison_ops.NeOp,
comparison_ops.LtOp,
comparison_ops.GtOp,
comparison_ops.LeOp,
comparison_ops.GeOp,
)
)

POLARS_LOWERING_RULES = (
*LOWER_COMPARISONS,
LowerFloorDivRule(),
)


def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
Expand Down
18 changes: 18 additions & 0 deletions bigframes/core/compile/scalar_op_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,7 @@ def eq_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x == y


Expand All @@ -1507,6 +1508,7 @@ def eq_nulls_match_op(
y: ibis_types.Value,
):
"""Variant of eq_op where nulls match each other. Only use where dtypes are known to be same."""
x, y = _coerce_comparables(x, y)
literal = ibis_types.literal("$NULL_SENTINEL$")
if hasattr(x, "fill_null"):
left = x.cast(ibis_dtypes.str).fill_null(literal)
Expand All @@ -1523,6 +1525,7 @@ def ne_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x != y


Expand All @@ -1534,6 +1537,17 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue
)


def _coerce_comparables(
x: ibis_types.Value,
y: ibis_types.Value,
):
if x.type().is_boolean() and not y.type().is_boolean():
x = x.cast(ibis_dtypes.int64)
elif y.type().is_boolean() and not x.type().is_boolean():
y = y.cast(ibis_dtypes.int64)
return x, y


@scalar_op_compiler.register_binary_op(ops.and_op)
def and_op(
x: ibis_types.Value,
Expand Down Expand Up @@ -1735,6 +1749,7 @@ def lt_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x < y


Expand All @@ -1744,6 +1759,7 @@ def le_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x <= y


Expand All @@ -1753,6 +1769,7 @@ def gt_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x > y


Expand All @@ -1762,6 +1779,7 @@ def ge_op(
x: ibis_types.Value,
y: ibis_types.Value,
):
x, y = _coerce_comparables(x, y)
return x >= y


Expand Down
11 changes: 10 additions & 1 deletion bigframes/session/polars_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@
nodes.OrderByNode,
nodes.ReversedNode,
nodes.SelectionNode,
nodes.ProjectionNode,
nodes.SliceNode,
nodes.AggregateNode,
)

_COMPATIBLE_SCALAR_OPS = ()
_COMPATIBLE_SCALAR_OPS = (
bigframes.operations.eq_op,
bigframes.operations.eq_null_match_op,
bigframes.operations.ne_op,
bigframes.operations.gt_op,
bigframes.operations.lt_op,
bigframes.operations.ge_op,
bigframes.operations.le_op,
)
_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp)


Expand Down
70 changes: 70 additions & 0 deletions tests/system/small/engines/test_comparison_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import pytest

from bigframes.core import array_value
import bigframes.operations as ops
from bigframes.session import polars_executor
from bigframes.testing.engine_utils import assert_equivalence_execution

pytest.importorskip("polars")

# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree.
REFERENCE_ENGINE = polars_executor.PolarsExecutor()

# numeric domain


def apply_op_pairwise(
array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[]
) -> array_value.ArrayValue:
exprs = []
for l_arg, r_arg in itertools.permutations(array.column_ids, 2):
if (l_arg in excluded_cols) or (r_arg in excluded_cols):
continue
try:
_ = op.output_type(
array.get_column_type(l_arg), array.get_column_type(r_arg)
)
exprs.append(op.as_expr(l_arg, r_arg))
except TypeError:
continue
assert len(exprs) > 0
new_arr, _ = array.compute_values(exprs)
return new_arr


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
"op",
[
ops.eq_op,
ops.eq_null_match_op,
ops.ne_op,
ops.gt_op,
ops.lt_op,
ops.le_op,
ops.ge_op,
],
)
def test_engines_project_comparison_op(
scalars_array_value: array_value.ArrayValue, engine, op
):
# exclude string cols as does not contain dates
# bool col actually doesn't work properly for bq engine
arr = apply_op_pairwise(scalars_array_value, op, excluded_cols=["string_col"])
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)