1111import numpy as np
1212from numpy .typing import NDArray
1313
14+ from Generic_Util .numba .types import njit , b1 , b1A , b1_NP , i8 , i8A , f8A , i8A2 , i8_NP
15+
1416from typing import TypeVar , Callable , Union , Sequence , Iterable , Iterator , Generator , Any , Generic , Mapping
1517_a = TypeVar ('_a' )
1618_b = TypeVar ('_b' )
@@ -32,7 +34,7 @@ def deep_flatten(xss_: Iterable) -> Generator:
3234
def deep_extract(xss_: Iterable[Iterable], *key_path) -> Generator:
    '''Descend into a nested structure along key_path, then deep_flatten whatever sits at the end of that path.

    Note: ``deep_extract(xss_, *key_path) == deep_flatten(Generic_Util.operator.get_nested(xss_, *key_path))``

    :param xss_: Nested combination of subscriptable containers
    :param key_path: Keys/indices applied one after another by subscription, outermost first; an empty path flattens xss_ itself
    :return: The output of deep_flatten applied to the entry reached at the end of key_path'''
    node = xss_
    for key in key_path:
        node = node[key]
    return deep_flatten(node)
@@ -64,16 +66,16 @@ def all_combinations(xs: Sequence, min_size = 1, max_size = None) -> list:
6466## Predicate/Property-Based Functions
6567
def partition(p: Callable[[_a], bool], xs: Iterable[_a]) -> tuple[Iterable[_a], Iterable[_a]]:
    '''Split xs in two by the boolean predicate p, in the manner of Haskell's partition:
    ``partition p xs == (filter p xs, filter (not . p) xs)``.

    :param p: Predicate evaluated once per element (only its truthiness is used)
    :param xs: Elements to split; traversed once, relative order is kept within each half
    :return: Tuple of two lists: (elements satisfying p, elements not satisfying p)'''
    accepted, rejected = [], []
    for x in xs:
        target = accepted if p(x) else rejected
        target.append(x)
    return accepted, rejected
7173
def group_by(f: Callable[[_a], _b], xs: Iterable[_a]) -> dict[_b, list[_a]]:
    '''Bucket the elements of xs by the key computed by f (a generalisation of partition to key-functions with any output).

    Notes:
        - 'Retrieval' functions from the operator package are typical f values (``itemgetter(...)``, ``attrgetter(...)`` or ``methodcaller(...)``)
        - This is NOT Haskell's groupBy function

    :param f: Key function; its outputs are used as dictionary keys
    :param xs: Elements to group; traversed once, encounter order is kept within each group
    :return: A defaultdict(list) mapping each key to the list of elements which produced it'''
    groups = defaultdict(list)
    for x in xs:
        key = f(x)
        groups[key].append(x)
    return groups
@@ -82,15 +84,19 @@ def first(c: Callable[[_a], bool], xs: Iterable[_a], default: _a = None) -> _a:
8284 '''Return the first value in xs which satisfies condition c.'''
8385 return next ((x for x in xs if c (x )), default )
8486
def first_i(c: Callable[[_a], bool], xs: Iterable[_a], default: _b = None) -> Union[int, _b]:
    '''Return the index of the first value in xs which satisfies condition c.

    :param c: Condition evaluated on the elements of xs, in order, until it first holds
    :param xs: Elements to search; any iterable is accepted, and it is only consumed up to the first match
    :param default: Value returned when no element satisfies c
    :return: The 0-based position of the first element satisfying c, or default if there is none'''
    return next((i for i, x in enumerate(xs) if c(x)), default)
90+
8591def foldq (f : Callable [[_b , _a ], _b ], g : Callable [[_b , _a , list [_a ]], list [_a ]], c : Callable [[_a ], bool ], xs : Sequence [_a ], acc : _b ) -> tuple [_b , list [_a ]]:
8692 '''
8793 Fold-like higher-order function where xs is traversed by consumption conditional on c, and remaining xs are updated by g
8894 (therefore consumption order is not known a priori):
8995 - the first/next item to be ingested is the first in the remaining xs to fulfil condition c
9096 - at every x ingestion, the item is removed from (a copy of) xs, and all the remaining ones are potentially modified by function g
91- - this function always returns a tuple of `(acc, remaining_xs)`, unlike the stricter `foldq_`, which raises an exception for leftover xs
97+ - this function always returns a tuple of `` (acc, remaining_xs)`` , unlike the stricter `` foldq_` `, which raises an exception for leftover xs
9298
93- Note: `fold(f, xs, acc) == foldq(f, lambda acc, x, xs: xs, lambda x: True, xs, acc)`.
99+ Note: `` fold(f, xs, acc) == foldq(f, lambda acc, x, xs: xs, lambda x: True, xs, acc)` `.
94100
95101 Sequence of suitable names leading to the current one: consumption_fold, condition_update_fold, cu_fold, q_fold, qfold or foldq
96102 :param f: 'Traditional' fold function :: acc -> x -> acc
@@ -119,7 +125,7 @@ def full_step(acc, xs): # Alternative implementation: move function content insi
119125 return acc , xs
120126
121127def foldq_ (f : Callable [[_b , _a ], _b ], g : Callable [[_b , _a , list [_a ]], list [_a ]], c : Callable [[_a ], bool ], xs : list [_a ], acc : _b ) -> _b :
122- r '''Stricter version of foldq (see its description for details); only returns the accumulator and raises an exception on leftover xs.
128+ '''Stricter version of foldq (see its description for details); only returns the accumulator and raises an exception on leftover xs.
123129 :raises ValueError on leftover xs'''
124130 acc , xs = foldq (f , g , c , xs , acc )
125131 if xs : raise ValueError ('No suitable next element found for given condition while elements remain' )
@@ -157,7 +163,7 @@ def unique_by(f: Callable[[_a], Any], xs: Iterable[_a]) -> list[_a]:
157163 return [x for x in xs if (fx := f (x ), ) if fx not in seen and not seen .append (fx )] # Neat true-tuple assignment and neat short-circuit 'and' trick
158164
159165def eq_elems (xs : Iterable [_a ], ys : Iterable [_a ]) -> bool :
160- '''Equality of iterables by their elements'''
166+ '''Equality of iterables by their elements. '''
161167 cys = list (ys ) # make a mutable copy
162168 try :
163169 for x in xs : cys .remove (x )
@@ -166,15 +172,84 @@ def eq_elems(xs: Iterable[_a], ys: Iterable[_a]) -> bool:
166172
def diff(xs: Iterable[_a], ys: Iterable[_a]) -> list[_a]:
    '''Difference of iterables.

    Notes:
        - not a set difference, so strictly removing as many xs duplicate entries as there are in ys
        - for each y it is the earliest remaining equal entry of xs which is removed
        - preserves order in xs
        - entries of ys which have no (remaining) counterpart in xs are ignored

    :param xs: Elements to start from (left untouched: a copy is edited)
    :param ys: Elements to remove, one occurrence each
    :return: A new list with the surviving elements of xs'''
    remaining = list(xs)  # make a mutable copy
    for y in ys:
        # Guard each removal individually: a y which is missing from xs must not stop the removal of the ys after it
        try: remaining.remove(y)
        except ValueError: pass
    return remaining
177183
@njit([b1A(i8A, i8A), b1A(f8A, f8A)])
def isin_sorted(xs: NDArray[_a], ys: NDArray[_a]) -> NDArray[b1_NP]:
    '''1D membership test in the spirit of np.isin, written as a single merge-style pass (described by its author as about 10x faster than np.isin).
    Assumes BOTH xs and ys are (ascendingly) sorted arrays of the same type (both int or both float).

    :param xs: Sorted values to look up
    :param ys: Sorted values to look them up in (duplicates are tolerated)
    :return: Boolean array with one entry per element of xs: True where that element is present in ys'''
    n_xs, n_ys = len(xs), len(ys)
    found = np.zeros_like(xs, dtype = b1_NP)
    ix = iy = 0
    while ix < n_xs and iy < n_ys:
        x, y = xs[ix], ys[iy]
        if y < x:
            iy += 1  # ys are behind: let them catch up to the xs
        else:
            # Either a hit or ys are ahead: in both cases move on to the next x.
            # iy stays put on a hit, so that repeated xs all match without assuming ys is strictly increasing.
            if y == x: found[ix] = True
            ix += 1
    return found
195+
@njit([i8A2(i8A, i8A, b1), i8A2(f8A, f8A, b1)])
def isin_sorted_intervals(xs: NDArray[_a], ys: NDArray[_a], strict = True) -> NDArray[i8_NP]:
    '''Assuming BOTH xs and ys are (ascendingly) sorted arrays of the same type (both int or both float):
    return a 2D array of the intervals (in terms of indices of xs) of subsequences shared with ys.

    Notes:
        - The interval-end indices are of the last matching value, not of the next (non-matching) one; run ``res[:,1] += 1`` to switch the behaviour
        - If the desired intervals are the opposite of the matching ones, call ``complement_intervals(intervals, len(xs), closed_interval_ends = True)``

    :param xs: Sorted values whose indices the intervals refer to
    :param ys: Sorted values to match xs against
    :param strict: Whether xs and ys subsequences need to match exactly (e.g. 1223 will not match 123 and vice-versa if strict).
        If the strictness of the assumption is not respected, the function will produce extra len-1 intervals for each duplicate.
    :return: (k,2)-array of [start, end] index pairs into xs, both inclusive, in ascending order'''
    n_xs, n_ys = len(xs), len(ys)
    # Worst case is one length-1 interval per x (e.g. xs = 1,3,5 against ys = 1,2,3,4,5), so reserve a row per x
    intervals = np.empty((n_xs, 2), dtype = i8_NP)
    i = j = k = 0
    while i < n_xs and j < n_ys:
        # This must be an if/elif/else chain: after j += 1 the loop bounds have to be re-checked before ys[j] is read again,
        # and xs[i] may only be given up on once ys has actually moved past it
        if ys[j] < xs[i]: j += 1        # Let the ys catch-up to the xs
        elif ys[j] == xs[i]:
            intervals[k, 0] = i
            while i < n_xs and j < n_ys:
                if xs[i] == ys[j]: i, j = i + 1, j + 1  # Guaranteed outcome in first iteration, hence the -1s below are safe
                elif strict: break                      # Exact matching: any mismatch closes the interval
                elif xs[i] == xs[i - 1]: i += 1         # Tolerated duplicate in xs: it stays inside the interval
                elif ys[j] == ys[j - 1]: j += 1         # Tolerated duplicate in ys: skip it
                else: break
            intervals[k, 1] = i - 1
            k += 1
        else: i += 1                    # Let the xs catch up to the ys
    return intervals[:k, ...]  # not k-1 since already ++ed
232+
def complement_intervals(intervals: NDArray[int], true_length: int, closed_interval_ends = True) -> NDArray[int]:
    '''Invert a series of index intervals (in the form of an (n,2)-array),
    i.e. return an (m,2)-array of intervals starting after the ends of the given ones and ending before their starts.
    The input array is left untouched, and an empty input yields the single interval covering the whole range (or no interval if true_length is 0).

    :param intervals: (n,2)-array of [start, end] index pairs, in ascending order
    :param true_length: 1 more than the final index for the overall range these intervals are within (which might be ``intervals[-1,1]``); if coming from isin_sorted_intervals, then simply ``len(xs)``
    :param closed_interval_ends: whether BOTH input and output intervals are (and will be) closed, i.e. their end-index is INCLUDED in the interval, rather than being the first index after it
    :return: (m,2)-array of the intervals of the overall range which are not covered by the given ones
    '''
    # Work on a private copy: the end-index shifts below are in-place, and the slices/reshapes are views
    bounds = np.array(intervals).reshape(-1, 2)
    if closed_interval_ends: bounds[:, 1] += 1  # Switch to half-open intervals, so that each end is the start of a gap

    # Drop to 1D and toggle the presence of initial and final indices
    edges = bounds.reshape(-1)
    edges = edges[1:] if edges.size and edges[0] == 0 else np.insert(edges, 0, 0)
    edges = edges[:-1] if edges.size and edges[-1] == true_length else np.append(edges, true_length)

    # Return to 2D (cardinality is even again, as both previous steps changed it by 1)
    flipped_intervals = edges.reshape(-1, 2)

    if closed_interval_ends: flipped_intervals[:, 1] -= 1  # Back to closed intervals
    return flipped_intervals
252+
178253
179254
180255## Interspersing Functions
0 commit comments