Skip to content

Enable chaining of lazy expressions for logical operators #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,6 @@ def update_expr(self, new_op): # noqa: C901
if hasattr(value2, "_where_args"):
value2 = value2.compute()

self._dtype = infer_dtype(op, value1, value2)
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
# We converted some of the operands to NDArray (where() handling above)
new_operands = {"o0": value1, "o1": value2}
Expand Down Expand Up @@ -2677,8 +2676,8 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
_shape = new_expr.shape
if isinstance(new_expr, blosc2.LazyExpr):
# Restore the original expression and operands
new_expr.expression = _expression
new_expr.expression_tosave = expression
new_expr.expression = f"({_expression})" # forcibly add parenthesis
new_expr.expression_tosave = new_expr.expression
new_expr.operands = _operands
new_expr.operands_tosave = operands
else:
Expand Down
22 changes: 22 additions & 0 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,10 @@ def test_dtype_infer(dtype1, dtype2, scalar):
np.testing.assert_allclose(res, nres)
assert res.dtype == nres.dtype

# Check dtype not changed by expression creation (bug fix)
assert a.dtype == dtype1
assert b.dtype == dtype2


@pytest.mark.parametrize(
"cfunc", ["np.int8", "np.int16", "np.int32", "np.int64", "np.float32", "np.float64"]
Expand Down Expand Up @@ -1330,3 +1334,21 @@ def test_missing_operator():
# Clean up
blosc2.remove_urlpath("a.b2nd")
blosc2.remove_urlpath("expr.b2nd")


# Test the chaining of multiple lazy expressions
def test_chain_expressions():
N = 1_000
dtype = "float64"
a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N))
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N))
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,))

le1 = a**3 + blosc2.sin(a**2)
le2 = le1 < c
le3 = le2 & (b < 0)

le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
le3_ = blosc2.lazyexpr("(le2 & (b < 0))", {"le2": le2_, "b": b})
assert (le3_[:] == le3[:]).all()