Skip to content

Fix indexing #400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 75 additions & 75 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
k: v
for k, v in dict(sys._getframe(_frame_depth).f_locals).items()
if (
(hasattr(v, "shape") or np.isscalar(v))
and
# Do not overwrite the local_dict with the expression variables
not (k in local_dict or k in ("_where_x", "_where_y"))
(hasattr(v, "shape") or np.isscalar(v))
and
# Do not overwrite the local_dict with the expression variables
not (k in local_dict or k in ("_where_x", "_where_y"))
)
}
if blosc2.IS_WASM:
Expand Down Expand Up @@ -472,7 +472,7 @@ def convert_inputs(inputs):
inputs_ = []
for obj in inputs:
if not isinstance(
obj, np.ndarray | blosc2.NDArray | blosc2.NDField | blosc2.C2Array
obj, np.ndarray | blosc2.NDArray | blosc2.NDField | blosc2.C2Array
) and not np.isscalar(obj):
try:
obj = np.asarray(obj)
Expand Down Expand Up @@ -638,10 +638,10 @@ def __init__(self):
def visit_Call(self, node):
# Check if the call is a numpy type-casting call
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id in ["np", "numpy"]
and isinstance(node.args[0], ast.Constant)
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id in ["np", "numpy"]
and isinstance(node.args[0], ast.Constant)
):
# Create a new temporary variable name
tmp_var = f"tmp{self.tmp_counter}"
Expand Down Expand Up @@ -863,7 +863,7 @@ def read_nchunk(arrs, info):


def fill_chunk_operands( # noqa: C901
operands, slice_, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands, reduc=False
operands, slice_, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands, reduc=False
):
"""Retrieve the chunk operands for evaluating an expression.

Expand Down Expand Up @@ -938,9 +938,9 @@ def fill_chunk_operands( # noqa: C901

# If key is in operands, we can reuse the buffer
if (
key in chunk_operands
and chunks_ == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
key in chunk_operands
and chunks_ == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
):
value.get_slice_numpy(chunk_operands[key], (starts, stops))
continue
Expand All @@ -955,10 +955,10 @@ def fill_chunk_operands( # noqa: C901


def fast_eval( # noqa: C901
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
getitem: bool,
**kwargs,
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
getitem: bool,
**kwargs,
) -> blosc2.NDArray | np.ndarray:
"""Evaluate the expression in chunks of operands using a fast path.

Expand Down Expand Up @@ -1125,11 +1125,11 @@ def compute_start_index(shape, slice_obj):


def slices_eval( # noqa: C901
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
getitem: bool,
_slice=None,
**kwargs,
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
getitem: bool,
_slice=None,
**kwargs,
) -> blosc2.NDArray | np.ndarray:
"""Evaluate the expression in chunks of operands.

Expand Down Expand Up @@ -1259,9 +1259,9 @@ def slices_eval( # noqa: C901
continue
# If key is in operands, we can reuse the buffer
if (
key in chunk_operands
and slice_shape == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
key in chunk_operands
and slice_shape == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
):
value.get_slice_numpy(chunk_operands[key], (starts, stops))
continue
Expand Down Expand Up @@ -1345,9 +1345,9 @@ def slices_eval( # noqa: C901
out[slice_] = result
elif len(where) == 1:
lenres = len(result)
out[lenout : lenout + lenres] = result
out[lenout: lenout + lenres] = result
if _order is not None:
indices_[lenout : lenout + lenres] = chunk_indices
indices_[lenout: lenout + lenres] = chunk_indices
lenout += lenres
else:
raise ValueError("The where condition must be a tuple with one or two elements")
Expand Down Expand Up @@ -1394,11 +1394,11 @@ def infer_reduction_dtype(dtype, operation):


def reduce_slices( # noqa: C901
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
reduce_args,
_slice=None,
**kwargs,
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None],
operands: dict,
reduce_args,
_slice=None,
**kwargs,
) -> blosc2.NDArray | np.ndarray:
"""Evaluate the expression in chunks of operands.

Expand Down Expand Up @@ -1544,9 +1544,9 @@ def reduce_slices( # noqa: C901
continue
# If key is in operands, we can reuse the buffer
if (
key in chunk_operands
and chunks_ == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
key in chunk_operands
and chunks_ == chunk_operands[key].shape
and isinstance(value, blosc2.NDArray)
):
value.get_slice_numpy(chunk_operands[key], (starts, stops))
continue
Expand Down Expand Up @@ -1657,7 +1657,7 @@ def convert_none_out(dtype, reduce_op, reduced_shape):


def chunked_eval( # noqa: C901
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None], operands: dict, item=None, **kwargs
expression: str | Callable[[tuple, np.ndarray, tuple[int]], None], operands: dict, item=None, **kwargs
):
"""
Evaluate the expression in chunks of operands.
Expand Down Expand Up @@ -1716,7 +1716,7 @@ def chunked_eval( # noqa: C901
# When using getitem, taking the fast path is always possible
return fast_eval(expression, operands, getitem=True, **kwargs)
elif (kwargs.get("chunks") is None and kwargs.get("blocks") is None) and (
out is None or isinstance(out, blosc2.NDArray)
out is None or isinstance(out, blosc2.NDArray)
):
# If not, the conditions to use the fast path are a bit more restrictive
# e.g. the user cannot specify chunks or blocks, or an output that is not
Expand Down Expand Up @@ -1773,7 +1773,7 @@ def fuse_expressions(expr, new_base, dup_op):
break
if expr[i + j] == ")":
j -= 1
old_pos = int(expr[i + 1 : i + j + 1])
old_pos = int(expr[i + 1: i + j + 1])
old_op = f"o{old_pos}"
if old_op not in dup_op:
if old_pos in prev_pos:
Expand Down Expand Up @@ -1977,9 +1977,9 @@ def dtype(self):
# In some situations, we already know the dtype
return self._dtype
if (
hasattr(self, "_dtype_")
and hasattr(self, "_expression_")
and self._expression_ == self.expression
hasattr(self, "_dtype_")
and hasattr(self, "_expression_")
and self._expression_ == self.expression
):
# Use the cached dtype
return self._dtype_
Expand Down Expand Up @@ -2019,9 +2019,9 @@ def shape(self):
if hasattr(self, "_shape"):
return self._shape
if (
hasattr(self, "_shape_")
and hasattr(self, "_expression_")
and self._expression_ == self.expression
hasattr(self, "_shape_")
and hasattr(self, "_expression_")
and self._expression_ == self.expression
):
# Use the cached shape
return self._shape_
Expand Down Expand Up @@ -2396,14 +2396,14 @@ def find_args(expr):
idx = expression.find(f"{constructor}")
# Find the arguments of the constructor function
try:
args, idx2 = find_args(expression[idx + len(constructor) :])
args, idx2 = find_args(expression[idx + len(constructor):])
except ValueError as err:
raise ValueError(f"Unbalanced parenthesis in expression: {expression}") from err
idx2 = idx + len(constructor) + idx2

# Give a chance to a possible .reshape() method
if expression[idx2 : idx2 + len(".reshape(")] == ".reshape(":
args2, idx3 = find_args(expression[idx2 + len("reshape(") :])
if expression[idx2: idx2 + len(".reshape(")] == ".reshape(":
args2, idx3 = find_args(expression[idx2 + len("reshape("):])
# Remove a possible shape= from the reshape call (due to rewriting the expression
# via extract_numpy_scalars(), other variants like .reshape(shape = shape_) work too)
args2 = args2.replace("shape=", "")
Expand Down Expand Up @@ -2440,16 +2440,16 @@ def _compute_expr(self, item, kwargs): # noqa: C901
if len(self._where_args) == 1:
# We have a single argument
where_x = self._where_args["_where_x"]
return where_x[:][lazy_expr]
return (where_x[:][lazy_expr])[item]
if len(self._where_args) == 2:
# We have two arguments
where_x = self._where_args["_where_x"]
where_y = self._where_args["_where_y"]
return np.where(lazy_expr, where_x, where_y)
return np.where(lazy_expr, where_x, where_y)[item]
if hasattr(self, "_output"):
# This is not exactly optimized, but it works for now
self._output[:] = lazy_expr
return lazy_expr
self._output[:] = lazy_expr[item]
return lazy_expr[item]

return chunked_eval(lazy_expr.expression, lazy_expr.operands, item, **kwargs)

Expand Down Expand Up @@ -2534,10 +2534,10 @@ def compute(self, item=None, **kwargs) -> blosc2.NDArray:
x = self._where_args["_where_x"]
result = x[result] # always a numpy array; TODO: optimize this for _getitem not in kwargs
if (
"_getitem" not in kwargs
and "_output" not in kwargs
and "_reduce_args" not in kwargs
and not isinstance(result, blosc2.NDArray)
"_getitem" not in kwargs
and "_output" not in kwargs
and "_reduce_args" not in kwargs
and not isinstance(result, blosc2.NDArray)
):
# Get rid of all the extra kwargs that are not accepted by blosc2.asarray
kwargs_not_accepted = {"_where_args", "_indices", "_order", "_ne_args", "dtype"}
Expand Down Expand Up @@ -2824,7 +2824,7 @@ def compute(self, item=None, **kwargs):
_ = kwargs.pop("dparams", None)
urlpath = kwargs.get("urlpath")
if urlpath is not None and urlpath == aux_kwargs.get(
"urlpath",
"urlpath",
):
raise ValueError("Cannot use the same urlpath for LazyArray and eval NDArray")
_ = aux_kwargs.pop("urlpath", None)
Expand Down Expand Up @@ -2875,12 +2875,12 @@ def save(self, **kwargs):


def lazyudf(
func: Callable[[tuple, np.ndarray, tuple[int]], None],
inputs: tuple | list | None,
dtype: np.dtype,
shape: tuple | list | None = None,
chunked_eval: bool = True,
**kwargs: Any,
func: Callable[[tuple, np.ndarray, tuple[int]], None],
inputs: tuple | list | None,
dtype: np.dtype,
shape: tuple | list | None = None,
chunked_eval: bool = True,
**kwargs: Any,
) -> LazyUDF:
"""
Get a LazyUDF from a python user-defined function.
Expand Down Expand Up @@ -2980,14 +2980,14 @@ def seek_operands(names, local_dict=None, global_dict=None, _frame_depth: int =


def lazyexpr(
expression: str | bytes | LazyExpr | blosc2.NDArray,
operands: dict | None = None,
out: blosc2.NDArray | np.ndarray = None,
where: tuple | list | None = None,
local_dict: dict | None = None,
global_dict: dict | None = None,
ne_args: dict | None = None,
_frame_depth: int = 2,
expression: str | bytes | LazyExpr | blosc2.NDArray,
operands: dict | None = None,
out: blosc2.NDArray | np.ndarray = None,
where: tuple | list | None = None,
local_dict: dict | None = None,
global_dict: dict | None = None,
ne_args: dict | None = None,
_frame_depth: int = 2,
) -> LazyExpr:
"""
Get a LazyExpr from an expression.
Expand Down Expand Up @@ -3125,11 +3125,11 @@ def _open_lazyarray(array):

# Mimic numexpr's evaluate function
def evaluate(
ex: str,
local_dict: dict | None = None,
global_dict: dict | None = None,
out: np.ndarray | blosc2.NDArray = None,
**kwargs: Any,
ex: str,
local_dict: dict | None = None,
global_dict: dict | None = None,
out: np.ndarray | blosc2.NDArray = None,
**kwargs: Any,
) -> np.ndarray | blosc2.NDArray:
"""
Evaluate a string expression using the Blosc2 compute engine.
Expand Down
18 changes: 12 additions & 6 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_proxy_simple_expression(array_fixture):

def test_iXXX(array_fixture):
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
expr = a1**3 + a2**2 + a3**3 - a4 + 3
expr = a1 ** 3 + a2 ** 2 + a3 ** 3 - a4 + 3
expr += 5 # __iadd__
expr -= 15 # __isub__
expr *= 2 # __imul__
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def test_fill_disk_operands(chunks, blocks, disk, fill_value):
b = blosc2.open("b.b2nd")
c = blosc2.open("c.b2nd")

expr = ((a**3 + blosc2.sin(c * 2)) < b) & (c > 0)
expr = ((a ** 3 + blosc2.sin(c * 2)) < b) & (c > 0)

out = expr.compute()
assert out.shape == (N, N)
Expand Down Expand Up @@ -1256,6 +1256,13 @@ def test_indices():
expr.indices().compute()


def test_reduction_index():
    """Regression test for indexing the result of a lazy reduction.

    Presumably ``blosc2.sum(a, axis=0)`` returns a lazy/deferred object
    (this PR is titled "Fix indexing"); indexing it must behave like
    indexing the equivalent NumPy result: a slice keeps one dimension and
    an integer index collapses to a 0-d (scalar) shape.
    """
    shape = (20, 20)
    # Operand: evenly spaced values laid out into a (20, 20) array.
    a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
    # Reducing over axis 0 should leave a length-20, 1-D result.
    expr = blosc2.sum(a, axis=0)
    # Slicing must honor the slice length, not the full reduced shape.
    assert expr[:10].shape == (10,)
    # Integer indexing must yield a scalar (0-d) shape.
    assert expr[0].shape == ()

def test_sort():
shape = (20,)
na = np.arange(shape[0])
Expand Down Expand Up @@ -1307,9 +1314,9 @@ def test_only_ndarrays_or_constructors(obj, getitem, item):
def test_numpy_funcs(array_fixture, func):
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
npfunc = getattr(np, func)
d_blosc2 = npfunc(((a1**3 + blosc2.sin(na2 * 2)) < a3) & (na2 > 0), axis=0)
d_blosc2 = npfunc(((a1 ** 3 + blosc2.sin(na2 * 2)) < a3) & (na2 > 0), axis=0)
npfunc = getattr(np, func)
d_numpy = npfunc(((na1**3 + np.sin(na2 * 2)) < na3) & (na2 > 0), axis=0)
d_numpy = npfunc(((na1 ** 3 + np.sin(na2 * 2)) < na3) & (na2 > 0), axis=0)
np.testing.assert_equal(d_blosc2, d_numpy)


Expand Down Expand Up @@ -1346,7 +1353,7 @@ def test_chain_expressions():
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N))
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,))

le1 = a**3 + blosc2.sin(a**2)
le1 = a ** 3 + blosc2.sin(a ** 2)
le2 = le1 < c
le3 = le2 & (b < 0)
le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
Expand All @@ -1365,7 +1372,6 @@ def test_chain_expressions():
# le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": le2_, "le3": le3_})
# assert (le4_[:] == le4[:]).all()


# TODO: Test the chaining of multiple persistent lazy expressions
# def test_chain_persistentexpressions():
# N = 1_000
Expand Down
Loading