
Commit b4a9355

Added some functions to iter.py:
- first_i: an index-version of first
- isin_sorted: a numba-compiled 1D version of np.isin optimised for sorted arrays (10x faster)
- isin_sorted_intervals: similar to the above but returning the indices of subsequences of the first array also found in the second; allows non-strictness (i.e. ignoring duplicates on both sides)
- complement_intervals: inverts a sequence of intervals, mostly written for use with isin_sorted_intervals

The isin_* functions are the first use of numba in iter.py; could move them elsewhere in the future, but for these tasks it does make sense to use arrays and not lists.
1 parent 9702612 commit b4a9355
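To make the two-pointer idea behind `isin_sorted` concrete, here is a minimal sketch on plain Python lists, so it runs without numba or numpy. The name `isin_sorted_py` is hypothetical and not part of the package; the sketch assumes both inputs are sorted ascending, as the committed function does.

```python
def isin_sorted_py(xs, ys):
    """Hypothetical list-based sketch: for each x in xs, report whether it occurs in ys.

    Assumes BOTH xs and ys are sorted in ascending order.
    """
    res = [False] * len(xs)
    i = j = 0
    while i < len(xs) and j < len(ys):
        if ys[j] < xs[i]:
            j += 1            # ys is behind: advance it
        elif ys[j] == xs[i]:
            res[i] = True     # match: mark it, but keep j in case xs repeats the value
            i += 1
        else:
            i += 1            # xs is behind: this x cannot be in ys
    return res


xs = [1, 2, 2, 5, 8, 9]
ys = [2, 3, 8, 10]
mask = isin_sorted_py(xs, ys)
```

Each index only ever moves forward, so the scan is linear in `len(xs) + len(ys)` and needs no sorting or hashing step.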

File tree

7 files changed

+112
-35
lines changed


README.md

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,24 +22,24 @@ It also contains a variety of convenience functions for the
2222
[numba](https://numba.pydata.org/) JIT compiler library.
2323

2424
See the [documentation][rtd-link] for details, but functions are grouped as follows:
25-
- Generic_Util.benchmarking: functions covering typical code-timing scenarios,
25+
- `Generic_Util.benchmarking`: functions covering typical code-timing scenarios,
2626
such as a "with" statement context, an n-executions timer,
2727
and a convenient function for comparing and summarising n execution times of different implementations
2828
of the same function.
29-
- Generic_Util.iter: iterable-focussed functions, covering multiple varieties of
29+
- `Generic_Util.iter`: iterable-focussed functions, covering multiple varieties of
3030
flattening, iterable combining,
3131
grouping and predicate/property-based processing (including topological sorting),
3232
element-comparison-based operations, value interspersal, and finally batching.
33-
- Generic_Util.operator: functions regarding item retrieval, and syntactic sugar for patterns of function application.
34-
- Generic_Util.misc: functions with less generic purpose than in the above; currently mostly to do with min/max-based operations.
33+
- `Generic_Util.operator`: functions regarding item retrieval, and syntactic sugar for patterns of function application.
34+
- `Generic_Util.misc`: functions with less generic purpose than in the above; currently mostly to do with min/max-based operations.
3535

3636
Then a sub-package is dedicated to utility functions for the [numba](https://numba.pydata.org/) JIT compiler library:
37-
- Generic_Util.numba.benchmarking: functions comparing execution times of (semi-automatically-generated) varieties of
37+
- `Generic_Util.numba.benchmarking`: functions comparing execution times of (semi-automatically-generated) varieties of
3838
numba-compilations of a given function, including
3939
lazy vs eager compilation, vectorisation, parallelisation, as well as varieties of rolling (see Generic_Util.numba.higher_order).
40-
- Generic_Util.numba.higher_order: higher-order numba-compilation functions, currently only functions to "roll" simpler functions
40+
- `Generic_Util.numba.higher_order`: higher-order numba-compilation functions, currently only functions to "roll" simpler functions
4141
(1d-to-scalar or 2d-to-scalar/1d) over arrays, with a few combinations of input and output type signatures.
42-
- Generic_Util.numba.types: convenient shorthands for frequently used numba (and respective numpy) types, with a focus on
42+
- `Generic_Util.numba.types`: convenient shorthands for frequently used numba (and respective numpy) types, with a focus on
4343
C-contiguity of arrays; these are useful in declaring eager-compilation function signatures.
4444

4545
Many functions which would have been included in this package were dropped in favour of using those in the wonderful

docs/index.rst

Lines changed: 10 additions & 9 deletions
@@ -17,16 +17,16 @@ It also contains a variety of convenience functions for the
1717

1818
The functions are grouped as follows:
1919

20-
- Generic_Util.benchmarking: functions covering typical code-timing scenarios,
20+
- :py:mod:`Generic_Util.benchmarking`: functions covering typical code-timing scenarios,
2121
such as a "with" statement context, an n-executions timer,
2222
and a convenient function for comparing and summarising n execution times of different implementations
2323
of the same function.
24-
- Generic_Util.iter: iterable-focussed functions, covering multiple varieties of
24+
- :py:mod:`Generic_Util.iter`: iterable-focussed functions, covering multiple varieties of
2525
flattening, iterable combining,
2626
grouping and predicate/property-based processing (including topological sorting),
2727
element-comparison-based operations, value interspersal, and finally batching.
28-
- Generic_Util.operator: functions regarding item retrieval, and syntactic sugar for patterns of function application.
29-
- Generic_Util.misc: functions with less generic purpose than in the above; currently mostly to do with min/max-based operations.
28+
- :py:mod:`Generic_Util.operator`: functions regarding item retrieval, and syntactic sugar for patterns of function application.
29+
- :py:mod:`Generic_Util.misc`: functions with less generic purpose than in the above; currently mostly to do with min/max-based operations.
3030

3131
Many functions which would have been included here were dropped in favour of using those in the wonderful
3232
`more-itertools <https://github.com/more-itertools/more-itertools>`_ package
@@ -37,13 +37,13 @@ source of algorithm-simplifying ingredients.
3737

3838
Then a sub-package is dedicated to utility functions for the `numba <https://numba.pydata.org/>`_ JIT compiler library:
3939

40-
- Generic_Util.numba.benchmarking: functions comparing execution times of (semi-automatically-generated) varieties of
40+
- :py:mod:`Generic_Util.numba.benchmarking`: functions comparing execution times of (semi-automatically-generated) varieties of
4141
numba-compilations of a given function, including
4242
lazy vs eager compilation, vectorisation, parallelisation, as well as varieties of rolling (see Generic_Util.numba.higher_order).
43-
- Generic_Util.numba.higher_order: higher-order numba-compilation functions, currently only functions to "roll" simpler functions
43+
- :py:mod:`Generic_Util.numba.higher_order`: higher-order numba-compilation functions, currently only functions to "roll" simpler functions
4444
(1d-to-scalar or 2d-to-scalar/1d) over arrays, with a few combinations of input and output type signatures.
45-
- Generic_Util.numba.types: convenient shorthands for frequently used numba (and respective numpy) types, with a focus on
46-
C-contiguity of arrays; these are useful in declaring eager-compilation function signatures.
45+
- :py:mod:`Generic_Util.numba.types`: convenient shorthands for frequently used numba (and respective numpy) types, with a focus on
46+
C-contiguity of arrays; these are useful in declaring eager-compilation function signatures.
4747

4848

4949

@@ -62,4 +62,5 @@ Indices and tables
6262

6363
* :ref:`genindex`
6464
* :ref:`modindex`
65-
* :ref:`search`
65+
..
66+
* :ref:`search`

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ dependencies = [
2626
"typing_extensions >=3.7; python_version<'3.8'",
2727
"pandas>=1.5.3",
2828
"numpy>=1.23.5",
29+
"numba>=0.56.4",
2930
"sortedcontainers>=2.4.0",
3031
]
3132

src/Generic_Util/benchmarking.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
1212

1313
@contextmanager
1414
def time_context(name: str = None):
15-
'''"with" statement context for timing the execution of the enclosed code block, i.e. `with time_context('Name of code block'): ...` '''
15+
'''"with" statement context for timing the execution of the enclosed code block, i.e. ``with time_context('Name of code block'): ...`` '''
1616
start = time.perf_counter()
1717
yield # No need to yield anything or for the above to be in a "try" and the below in a customary "finally" since there is no dangling resource
1818
end = time.perf_counter()
@@ -58,7 +58,7 @@ def time_n(f: Callable, n = 2, *args, **kwargs):
5858
def compare_implementations(fs_with_shared_args: dict[str, Callable], n = 200, wait = 1, verbose = True,
5959
fs_with_own_args: dict[str, tuple[Callable, list, dict]] = None, args: list = None, kwargs: dict = None):
6060
'''Benchmark multiple implementations of the same function called n times (each with the same args and kwargs), with a break between functions.
61-
Recommended later output view if verbose is False: `print(table.to_markdown(index = False))`.
61+
Recommended later output view if verbose is False: ``print(table.to_markdown(index = False))``.
6262
:param fs_with_own_args: alternative to fs_with_shared_args, args and kwargs arguments: meant for additional functions taking different *args and **kwargs.'''
6363
assert n >= 3
6464
table = []
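The timing-context pattern used by `time_context` can be sketched in a few lines. This is not the package's implementation: `timed` and its `sink` argument are hypothetical names, chosen so that the measured duration can be inspected instead of printed.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, sink):
    """Hypothetical helper: store the wall-clock duration of the enclosed block in sink[label]."""
    start = time.perf_counter()
    yield  # the body of the "with" block runs here
    sink[label] = time.perf_counter() - start


durations = {}
with timed('sum of squares', durations):
    total = sum(i * i for i in range(10_000))
```

As in the original, there is no `try`/`finally`: if the block raises, no duration is recorded, which is acceptable because no resource is left dangling.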

src/Generic_Util/iter.py

Lines changed: 87 additions & 12 deletions
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
from numpy.typing import NDArray
1313

14+
from Generic_Util.numba.types import njit, b1, b1A, b1_NP, i8, i8A, f8A, i8A2, i8_NP
15+
1416
from typing import TypeVar, Callable, Union, Sequence, Iterable, Iterator, Generator, Any, Generic, Mapping
1517
_a = TypeVar('_a')
1618
_b = TypeVar('_b')
@@ -32,7 +34,7 @@ def deep_flatten(xss_: Iterable) -> Generator:
3234

3335
def deep_extract(xss_: Iterable[Iterable], *key_path) -> Generator:
3436
'''Given a nested combination of iterables and a path of keys in it, return a deep_flatten-ed list of entries under said path.
35-
Note: `deep_extract(xss_, *key_path) == deep_flatten(Generic_Util.operator.get_nested(xss_, *key_path))`'''
37+
Note: ``deep_extract(xss_, *key_path) == deep_flatten(Generic_Util.operator.get_nested(xss_, *key_path))``'''
3638
level = xss_
3739
for k in key_path: level = level[k]
3840
return deep_flatten(level)
@@ -64,16 +66,16 @@ def all_combinations(xs: Sequence, min_size = 1, max_size = None) -> list:
6466
## Predicate/Property-Based Functions
6567

6668
def partition(p: Callable[[_a], bool], xs: Iterable[_a]) -> tuple[Iterable[_a], Iterable[_a]]:
67-
'''Haskell's partition function, partitioning xs by some boolean predicate p: `partition p xs == (filter p xs, filter (not . p) xs)`.'''
69+
'''Haskell's partition function, partitioning xs by some boolean predicate p: ``partition p xs == (filter p xs, filter (not . p) xs)``.'''
6870
acc = ([],[])
6971
for x in xs: acc[not p(x)].append(x)
7072
return acc
7173

7274
def group_by(f: Callable[[_a], _b], xs: Iterable[_a]) -> dict[_b, list[_a]]:
7375
'''Generalisation of partition to any-output key-function.
74-
Notes:
75-
- 'Retrieval' functions from the operator package are typical f values (`itemgetter(...)`, `attrgetter(...)` or `methodcaller(...)`)
76-
- This is NOT Haskell's groupBy function'''
76+
Notes:
77+
- 'Retrieval' functions from the operator package are typical f values (``itemgetter(...)``, ``attrgetter(...)`` or ``methodcaller(...)``)
78+
- This is NOT Haskell's groupBy function'''
7779
acc = defaultdict(list)
7880
for x in xs: acc[f(x)].append(x)
7981
return acc
@@ -82,15 +84,19 @@ def first(c: Callable[[_a], bool], xs: Iterable[_a], default: _a = None) -> _a:
8284
'''Return the first value in xs which satisfies condition c.'''
8385
return next((x for x in xs if c(x)), default)
8486

87+
def first_i(c: Callable[[_a], bool], xs: Sequence[_a], default: int = None) -> int:
88+
'''Return the index of the first value in xs which satisfies condition c.'''
89+
return next((i for i in range(len(xs)) if c(xs[i])), default)
90+
8591
def foldq(f: Callable[[_b, _a], _b], g: Callable[[_b, _a, list[_a]], list[_a]], c: Callable[[_a], bool], xs: Sequence[_a], acc: _b) -> tuple[_b, list[_a]]:
8692
'''
8793
Fold-like higher-order function where xs is traversed by consumption conditional on c, and remaining xs are updated by g
8894
(therefore consumption order is not known a priori):
8995
- the first/next item to be ingested is the first in the remaining xs to fulfil condition c
9096
- at every x ingestion, the item is removed from (a copy of) xs, and all the remaining ones are potentially modified by function g
91-
- this function always returns a tuple of `(acc, remaining_xs)`, unlike the stricter `foldq_`, which raises an exception for leftover xs
97+
- this function always returns a tuple of ``(acc, remaining_xs)``, unlike the stricter ``foldq_``, which raises an exception for leftover xs
9298
93-
Note: `fold(f, xs, acc) == foldq(f, lambda acc, x, xs: xs, lambda x: True, xs, acc)`.
99+
Note: ``fold(f, xs, acc) == foldq(f, lambda acc, x, xs: xs, lambda x: True, xs, acc)``.
94100
95101
Sequence of suitable names leading to the current one: consumption_fold, condition_update_fold, cu_fold, q_fold, qfold or foldq
96102
:param f: 'Traditional' fold function :: acc -> x -> acc
@@ -119,7 +125,7 @@ def full_step(acc, xs): # Alternative implementation: move function content insi
119125
return acc, xs
120126

121127
def foldq_(f: Callable[[_b, _a], _b], g: Callable[[_b, _a, list[_a]], list[_a]], c: Callable[[_a], bool], xs: list[_a], acc: _b) -> _b:
122-
r'''Stricter version of foldq (see its description for details); only returns the accumulator and raises an exception on leftover xs.
128+
'''Stricter version of foldq (see its description for details); only returns the accumulator and raises an exception on leftover xs.
123129
:raises ValueError on leftover xs'''
124130
acc, xs = foldq(f, g, c, xs, acc)
125131
if xs: raise ValueError('No suitable next element found for given condition while elements remain')
@@ -157,7 +163,7 @@ def unique_by(f: Callable[[_a], Any], xs: Iterable[_a]) -> list[_a]:
157163
return [x for x in xs if (fx := f(x), ) if fx not in seen and not seen.append(fx)] # Neat true-tuple assignment and neat short-circuit 'and' trick
158164

159165
def eq_elems(xs: Iterable[_a], ys: Iterable[_a]) -> bool:
160-
'''Equality of iterables by their elements'''
166+
'''Equality of iterables by their elements.'''
161167
cys = list(ys) # make a mutable copy
162168
try:
163169
for x in xs: cys.remove(x)
@@ -166,15 +172,84 @@ def eq_elems(xs: Iterable[_a], ys: Iterable[_a]) -> bool:
166172

167173
def diff(xs: Iterable[_a], ys: Iterable[_a]) -> list[_a]:
168174
'''Difference of iterables.
169-
Notes:
170-
- not a set difference, so strictly removing as many xs duplicate entries as there are in ys
171-
- preserves order in xs'''
175+
Notes:
176+
- not a set difference, so strictly removing as many xs duplicate entries as there are in ys
177+
- preserves order in xs'''
172178
cxs = list(xs) # make a mutable copy
173179
try:
174180
for y in ys: cxs.remove(y)
175181
except ValueError: pass
176182
return cxs
177183

184+
@njit([b1A(i8A, i8A), b1A(f8A, f8A)])
185+
def isin_sorted(xs: NDArray[_a], ys: NDArray[_a]) -> NDArray[b1_NP]:
186+
'''Optimised (10x faster) 1D version of np.isin assuming BOTH xs and ys are (ascendingly) sorted arrays of the same type (both int or both float):
187+
return a boolean array of the same length as xs indicating whether the corresponding element at that index is present in ys.'''
188+
res = np.zeros_like(xs, dtype = b1_NP)
189+
i = j = 0
190+
while i < len(xs) and j < len(ys):
191+
if ys[j] < xs[i]: j += 1 # Let the ys catch-up to the xs
192+
elif ys[j] == xs[i]: res[i], i = True, i + 1 # Could increase j here as well, but it would imply assuming ys is strictly increasing
193+
else: i += 1 # Let the xs catch up to the ys
194+
return res
195+
196+
@njit([i8A2(i8A, i8A, b1), i8A2(f8A, f8A, b1)])
197+
def isin_sorted_intervals(xs: NDArray[_a], ys: NDArray[_a], strict = True) -> NDArray[i8_NP]:
198+
'''Assuming BOTH xs and ys are (ascendingly) sorted arrays of the same type (both int or both float):
199+
return a 2D array of the intervals (in terms of indices of xs) of subsequences shared with ys.
200+
201+
Notes:
202+
- The interval-end indices are of the last matching value, not of the next (non-matching) one; run ``res[:,1] += 1`` to switch the behaviour
203+
- If the desired intervals are the opposite of the matching ones, call ``complement_intervals(intervals, len(xs), closed_interval_ends = True)``
204+
205+
:param strict: Whether xs and ys subsequences need to match exactly (e.g. 1223 will not match 123 and vice-versa if strict).
206+
If the strictness of the assumption is not respected, the function will produce extra len-1 intervals for each duplicate.'''
207+
intervals = np.empty((len(xs), 2), dtype = i8_NP) # Worst case is one length-1 interval per value of xs (e.g. xs = [1, 3] and ys = [1, 2, 3])
208+
i = j = k = 0
209+
if strict: # xs and ys subsequences need to match exactly
210+
while i < len(xs) and j < len(ys):
211+
if ys[j] < xs[i]: j += 1 # Let the ys catch-up to the xs
212+
elif ys[j] == xs[i]: # elif: after j += 1 above, j may equal len(ys), so the loop condition must be re-checked first
213+
intervals[k, 0] = i
214+
while i < len(xs) and j < len(ys) and xs[i] == ys[j]: i, j = i + 1, j + 1
215+
intervals[k, 1] = i - 1
216+
k += 1
217+
else: i += 1 # Let the xs catch up to the ys
218+
else: # xs and ys subsequences tolerate non-matching duplicates
219+
while i < len(xs) and j < len(ys):
220+
if ys[j] < xs[i]: j += 1 # Let the ys catch-up to the xs
221+
elif ys[j] == xs[i]: # elif: after j += 1 above, j may equal len(ys), so the loop condition must be re-checked first
222+
intervals[k, 0] = i
223+
while i < len(xs) and j < len(ys):
224+
if xs[i] == ys[j]: i, j = i + 1, j + 1 # Guaranteed outcome in first iteration, hence reasoning with -1s below
225+
elif xs[i] == xs[i - 1]: i += 1
226+
elif ys[j] == ys[j - 1]: j += 1
227+
else: break
228+
intervals[k, 1] = i - 1
229+
k += 1
230+
else: i += 1 # Let the xs catch up to the ys
231+
return intervals[:k,...] # not k-1 since already ++ed
232+
233+
def complement_intervals(intervals: NDArray[int], true_length: int, closed_interval_ends = True) -> NDArray[int]:
234+
'''Invert a series of index intervals (in the form of an (n,2)-array),
235+
i.e. return an (m,2)-array of intervals starting after the ends of the given ones and ending before their starts.
236+
237+
:param true_length: 1 more than the final index for the overall range these intervals are within (which might be ``intervals[-1,1]``); if coming from isin_sorted_intervals, then simply ``len(xs)``
238+
:param closed_interval_ends: whether BOTH input and output intervals are (and will be) closed, i.e. their end-index is INCLUDED in the interval, rather than being the first index after it
239+
'''
240+
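if closed_interval_ends: intervals = intervals + np.array([0, 1]) # Shift the end column on a new array, leaving the caller's intervals unmodified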
241+
242+
# Drop to 1D and toggle the presence of initial and final indices
243+
indices = np.reshape(intervals, intervals.size)
244+
indices = indices[1:] if indices[0] == 0 else np.insert(indices, 0, 0)
245+
indices = indices[:-1] if indices[-1] == true_length else np.append(indices, true_length)
246+
247+
# Return to 2D (cardinality is even again, as both previous steps changed it by 1)
248+
flipped_intervals = np.reshape(indices, (len(indices) // 2, 2))
249+
250+
if closed_interval_ends: flipped_intervals[:, 1] -= 1
251+
return flipped_intervals
252+
178253

179254

180255
## Interspersing Functions
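The interval functions are easier to follow on a small input. The sketch below uses plain lists and hypothetical names (`shared_intervals`, `complement`); it mirrors only the strict mode of `isin_sorted_intervals`, and its complement step is a simple loop rather than the reshape-based approach in the diff. Intervals are closed, i.e. the end index is included.

```python
def shared_intervals(xs, ys):
    """Hypothetical strict-mode sketch: closed (start, end) index intervals of xs
    whose values appear as a contiguous run in ys. Assumes both inputs are sorted ascending."""
    intervals = []
    i = j = 0
    while i < len(xs) and j < len(ys):
        if ys[j] < xs[i]:
            j += 1                      # ys is behind: advance it
        elif ys[j] == xs[i]:
            start = i                   # a shared run begins here
            while i < len(xs) and j < len(ys) and xs[i] == ys[j]:
                i, j = i + 1, j + 1
            intervals.append((start, i - 1))
        else:
            i += 1                      # xs is behind: advance it
    return intervals


def complement(intervals, true_length):
    """Hypothetical sketch: closed intervals covering the indices of
    range(true_length) that `intervals` does not cover."""
    gaps = []
    next_free = 0
    for start, end in intervals:
        if start > next_free:
            gaps.append((next_free, start - 1))
        next_free = end + 1
    if next_free < true_length:
        gaps.append((next_free, true_length - 1))
    return gaps


xs = [1, 2, 3, 6, 7, 9]
ys = [2, 3, 4, 7, 9, 10]
matched = shared_intervals(xs, ys)    # runs of xs shared with ys
gaps = complement(matched, len(xs))   # everything else
```

Here the values 2, 3 and 7, 9 are the shared runs, so `matched` covers indices 1 to 2 and 4 to 5 of `xs`, and `gaps` covers the single indices 0 and 3.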

src/Generic_Util/misc.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def interval_overlap(ab: tuple[float, float], cd: tuple[float, float]) -> float:
1515
def min_max(xs: Sequence[_a]) -> tuple[_a, _a]:
1616
'''Mathematically most efficient joint identification of min and max (minimum comparisons = 3n/2 - 2).
1717
Note:
18-
- This function is numba-compilable, e.g. as `njit(nTup(f8,f8)(f8[::1]))(min_max)` (see `Generic_Util.numba.types` for `nTup` shorthand),
18+
- This function is numba-compilable, e.g. as ``njit(nTup(f8,f8)(f8[::1]))(min_max)`` (see ``Generic_Util.numba.types`` for ``nTup`` shorthand),
1919
- If using numpy arrays, min and max are cached for O(1) lookup, and one would imagine this is the used algorithm'''
2020
if xs[0] > xs[1]: min, max = xs[1], xs[0] # Initialise
2121
else: min, max = xs[0], xs[1]
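The `3n/2 - 2` comparison count quoted in the docstring comes from processing elements in pairs. Only the initialisation of `min_max` is visible in this diff, so the following is a sketch of the standard pairwise scheme under that assumption, not the package's code; `min_max_pairs` and its comparison counter are hypothetical.

```python
def min_max_pairs(xs):
    """Hypothetical sketch: return (minimum, maximum, comparisons made). Assumes len(xs) >= 2."""
    comparisons = 1
    if xs[0] > xs[1]:
        lo, hi = xs[1], xs[0]
    else:
        lo, hi = xs[0], xs[1]
    i = 2
    while i + 1 < len(xs):
        a, b = xs[i], xs[i + 1]
        comparisons += 3          # one within the pair, one against lo, one against hi
        if a > b:
            a, b = b, a           # now a <= b
        if a < lo:
            lo = a
        if b > hi:
            hi = b
        i += 2
    if i < len(xs):               # odd length: one leftover element
        comparisons += 2
        if xs[i] < lo:
            lo = xs[i]
        if xs[i] > hi:
            hi = xs[i]
    return lo, hi, comparisons


lo, hi, comparisons = min_max_pairs([5, 3, 9, 1, 7, 2, 8, 6])
```

For an even length n, the count is 1 for the first pair plus 3 for each of the remaining (n - 2) / 2 pairs, which gives 3n/2 - 2; a naive separate scan for min and max would use 2n - 2.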

src/Generic_Util/operator.py

Lines changed: 4 additions & 4 deletions
@@ -27,21 +27,21 @@ def get_nested(xss_: Iterable[Iterable], *key_path) -> Generator:
2727

2828
def on(f: Callable, xs: Iterable[_a], g: Callable[[_a, ...], _b], *args, **kwargs):
2929
'''Transform xs by element-wise application of g and call f with them as its arguments.
30-
E.g. `on(operator.gt, (a, b), len)`
30+
E.g. ``on(operator.gt, (a, b), len)``
3131
Notes:
3232
- *args, **kwargs are for g, not f
33-
- 'Retrieval' functions from the operator package are reasonable g values (`itemgetter(...)`, `attrgetter(...)` or `methodcaller(...)`),
33+
- 'Retrieval' functions from the operator package are reasonable g values (``itemgetter(...)``, ``attrgetter(...)`` or ``methodcaller(...)``),
3434
BUT on_a and on_m are shorthands for the attribute and method cases'''
3535
return f(*[g(x, *args, **kwargs) for x in xs])
3636

3737
def on_a(f: Callable, xs: Iterable, a: str):
3838
'''Extract attribute a from xs elements and call f with them as its arguments.
39-
E.g. `on_a(operator.eq, [a, b], '__class__')`'''
39+
E.g. ``on_a(operator.eq, [a, b], '__class__')``'''
4040
return f(*[getattr(x, a) for x in xs])
4141

4242
def on_m(f: Callable, xs: Iterable, m: str, *args, **kwargs):
4343
'''Call method m on xs elements and call f with their results as its arguments.
44-
E.g. `on_m(operator.gt, [a, b], 'count', 'hello')`
44+
E.g. ``on_m(operator.gt, [a, b], 'count', 'hello')``
4545
Notes:
4646
- *args, **kwargs are for method m, not f'''
4747
return f(*[getattr(x, m)(*args, **kwargs) for x in xs])
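The three helpers above can be tried in isolation. The definitions below are copied from the return statements shown in the diff, with type annotations dropped for brevity; the sample values are illustrative only.

```python
import operator


def on(f, xs, g, *args, **kwargs):
    # Apply g to every element, then pass the results to f as positional arguments
    return f(*[g(x, *args, **kwargs) for x in xs])


def on_a(f, xs, a):
    # Same idea, reading attribute a from every element
    return f(*[getattr(x, a) for x in xs])


def on_m(f, xs, m, *args, **kwargs):
    # Same idea, calling method m (with args and kwargs) on every element
    return f(*[getattr(x, m)(*args, **kwargs) for x in xs])


longer = on(operator.gt, ('hello', 'hi'), len)                       # len('hello') > len('hi')
same_class = on_a(operator.eq, [1, 2], '__class__')                  # int == int
more_ls = on_m(operator.gt, ['hello world', 'hold'], 'count', 'l')   # 3 > 1
```

In each call the extra arguments go to the per-element step (`g` or method `m`), never to `f`, which only receives the transformed elements.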
