|
6 | 6 | from itertools import chain, combinations, islice |
7 | 7 | from functools import reduce |
8 | 8 | from collections import defaultdict |
| 9 | +from collections.abc import MutableMapping |
9 | 10 | from sortedcontainers import SortedKeyList |
10 | 11 | import operator as op |
11 | 12 | import numpy as np |
@@ -39,6 +40,18 @@ def deep_extract(xss_: Iterable[Iterable], *key_path) -> Generator: |
39 | 40 | for k in key_path: level = level[k] |
40 | 41 | return deep_flatten(level) |
41 | 42 |
|
| 43 | +def unnest(nested_dict, prefix = '', sep = '.'): |
| 44 | + '''Recursively unnest dictionaries and lists in a dictionary, concatenating keys with sep. All keys are assumed to be strings. |
| 45 | + Example: {'a': [{'b': 2}, {'c': 3}], 'd': 4, 'e': {'f': 6}} -> {'a.0.b': 2, 'a.1.c': 3, 'd': 4, 'e.f': 6}''' |
| 46 | + out = {} |
| 47 | + for k, v in nested_dict.items(): |
| 48 | + new_k = prefix + sep + k if prefix else k |
| 49 | + if isinstance(v, MutableMapping): out.update(unnest(v, new_k, sep=sep)) |
|      50 | +        elif isinstance(v, list):  # non-mapping list items are stored under their indexed key
|      51 | +            for i, vi in enumerate(v): out.update(unnest(vi, new_k + sep + str(i), sep=sep) if isinstance(vi, MutableMapping) else {new_k + sep + str(i): vi})
| 52 | + else: out[new_k] = v |
| 53 | + return out |
| 54 | + |
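
A quick usage sketch for the new `unnest` helper (assuming the function from this diff is in scope; the sample data is just the docstring example):

```python
nested = {'a': [{'b': 2}, {'c': 3}], 'd': 4, 'e': {'f': 6}}

assert unnest(nested) == {'a.0.b': 2, 'a.1.c': 3, 'd': 4, 'e.f': 6}
# The separator is threaded through the recursion:
assert unnest({'e': {'f': 6}}, sep='/') == {'e/f': 6}
```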
42 | 55 |
|
43 | 56 |
|
44 | 57 | ## Iterable-Combining (and Combinatory) Functions |
@@ -273,7 +286,10 @@ def intersperse_val(xs: Sequence[_a], y: _a, n: int, prepend = False, append = F |
273 | 286 | ## Batching functions |
274 | 287 |
|
275 | 288 | def batch_iter(n: int, xs: Iterable[_a]) -> Generator[_a, None, None]: |
276 | | - '''Batch an iterable in batches of size n (possibly except the last). If len(xs) is knowable use batch_seq instead.''' |
|     289 | +    '''*Soft-deprecated* on Python >= 3.12, where itertools provides batched(iterable, n).
| 290 | + Batch an iterable in batches of size n (possibly except the last). If len(xs) is knowable use batch_seq instead.''' |
| 291 | + from warnings import warn |
|     292 | +    warn('batch_iter() is deprecated in favour of itertools.batched(iterable, n) (Python >= 3.12) and will be removed in a future release.', DeprecationWarning, stacklevel=2)
277 | 293 | iterator = iter(xs) |
278 | 294 | while batch := list(islice(iterator, n)): yield batch |
279 | 295 |
|
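
For reference, the Python >= 3.12 replacement that the deprecation warning points to; note that `itertools.batched` yields tuples, whereas `batch_iter` yields lists:

```python
from itertools import batched  # Python >= 3.12 only

for b in batched(range(7), 3):
    print(b)
# (0, 1, 2)
# (3, 4, 5)
# (6,)
```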
@@ -342,4 +358,22 @@ def batch_seq_by_into(by: Callable[[_a], float], k: int, xs: Sequence[_a], keep_ |
342 | 358 | count += weight |
343 | 359 | if batch: yield batch |
344 | 360 |
|
|     361 | +def batch_by_group(by: Callable[[_a], Any], n: int, xs: Iterable[_a]) -> Generator[list[_a], None, None]:
|     362 | +    '''Batch an iterable into batches of length <= n in which no two elements come from the same group, as determined by the key function by.
|     363 | +    Note: the order of items within and across batches is reversed relative to the original iterable;
|     364 | +    you may want to reverse xs beforehand (or add a deque.popleft toggle?).'''
| 365 | + groups = defaultdict(list) |
| 366 | + for x in xs: groups[by(x)].append(x) |
| 367 | + |
| 368 | + while any(groups.values()): |
| 369 | + batch = [] |
|     370 | +        for grp in list(groups.keys()): # iterate over a copy of the keys so that `del groups[grp]` below cannot break the iteration
| 371 | + if len(batch) >= n: |
| 372 | + break |
| 373 | + if groups[grp]: |
| 374 | + batch.append(groups[grp].pop()) |
| 375 | + else: |
| 376 | + del groups[grp] |
| 377 | + yield batch |
| 378 | + |
345 | 379 |
|
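
A small usage sketch for `batch_by_group` (assuming it is in scope; the data and key function are purely illustrative):

```python
orders = [('alice', 1), ('alice', 2), ('bob', 3), ('bob', 4), ('carol', 5)]

for batch in batch_by_group(lambda o: o[0], 2, orders):
    print(batch)
# [('alice', 2), ('bob', 4)]   <- at most 2 items, at most one per group (note the reversed order)
# [('alice', 1), ('bob', 3)]
# [('carol', 5)]
```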