from functools import reduce
from itertools import chain, groupby, islice
from operator import itemgetter
from typing import Any, Callable, Iterable, Sequence
# Public API re-exported by ``from <module> import *``.
__all__ = [
    'all_equal',
    'butlast',
    'concat',
    'cons',
    'dedupe',
    'first',
    'flatten',
    'head',
    'init',
    'last',
    'nth',
    'partial_flatten',
    'quantify',
    'rest',
    'reverse',
    'second',
    'sorted_index',
    'tail',
    'take',
]
def first(seq: Sequence) -> Any:
    """
    Return the first item produced by iterating over seq.

    Any iterable is accepted, not only sequences. An empty input raises
    StopIteration (not IndexError), so avoid calling this where a
    StopIteration could be mistaken for normal exhaustion, e.g. inside map().
    >>> first([1, 2, 3])
    1
    """
    for item in seq:
        return item
    raise StopIteration
def second(seq: Sequence) -> Any:
    """
    Return the item at index 1 of seq.

    Built-in sequences with fewer than two items raise IndexError.
    >>> second([1, 2, 3])
    2
    """
    get_second = itemgetter(1)
    return get_second(seq)
def last(seq: Sequence) -> Any:
    """
    Return the final item of seq (equivalent to seq[-1]).

    An empty built-in sequence raises IndexError.
    >>> last([1, 2, 3])
    3
    """
    final_position = -1
    return seq[final_position]
def butlast(seq: Sequence) -> Sequence:
    """
    Return everything in seq except its final item.

    This is the slice seq[:-1], so built-in sequences keep their type and an
    empty input yields an empty result.
    >>> butlast([1, 2, 3])
    [1, 2]
    """
    all_but_final = slice(None, -1)
    return seq[all_but_final]
def rest(seq: Sequence) -> Sequence:
    """
    Returns remaining elements in a sequence: everything after the first item.

    This is the slice seq[1:], so built-in sequences keep their type and an
    empty input yields an empty result.
    >>> rest([1, 2, 3])
    [2, 3]
    """
    return seq[1:]
def reverse(seq: Sequence) -> Sequence:
    """
    Return a reversed copy of seq.

    Implemented as the slice seq[::-1], so built-in sequences keep their type.
    >>> reverse([1, 2, 3])
    [3, 2, 1]
    """
    backwards = slice(None, None, -1)
    return seq[backwards]
def cons(item: Any, seq: Sequence) -> chain:
    """
    Lazily prepend item to seq.

    Returns an itertools.chain iterator that yields item followed by the
    items of seq; seq itself is neither copied nor modified.
    >>> list(cons(1, [2, 3]))
    [1, 2, 3]
    """
    head_only = (item,)
    return chain(head_only, seq)
def lazy_flatten(seq: Iterable) -> Iterable:
    """
    Returns a generator which yields items from a flattened version
    of the sequence.

    Nested iterables are expanded recursively, at any depth. ``str`` and
    ``bytes`` values are treated as atoms and yielded whole.
    >>> list(lazy_flatten([1, [2, (3, range(4, 6))]]))
    [1, 2, 3, 4, 5]
    """
    for item in seq:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            # Recurse lazily on the generator itself. Going through flatten()
            # would rebuild each nested container via type(item)(...), which
            # raises TypeError for generators, range and map objects.
            yield from lazy_flatten(item)
        else:
            yield item
def flatten(seq: Iterable) -> Iterable:
    """
    Returns a flattened version of sequence, rebuilt with the same type as seq.

    Nesting is removed at every depth (via lazy_flatten); strings and bytes
    nested inside seq are kept whole.
    >>> flatten([1, [2, [3, [4, 5], 6], 7]])
    [1, 2, 3, 4, 5, 6, 7]
    >>> flatten("abc")
    'abc'
    """
    items = lazy_flatten(seq)
    if isinstance(seq, str):
        # str(iterable) returns the iterable's repr rather than joining its
        # items, so a string input has to be reassembled with join.
        return type(seq)(''.join(items))
    return type(seq)(items)  # type: ignore
def partial_flatten(seq: Iterable) -> Iterable:
    """
    Returns partially flattened version of sequence.

    Exactly one level of nesting is removed: the members of seq are
    concatenated, and the result is rebuilt with the same type as seq. An
    empty seq yields an empty result. Every member of seq must itself be
    iterable; otherwise TypeError is raised.
    >>> partial_flatten(((1,), [2, 3], (4, [5, 6])))
    (1, 2, 3, 4, [5, 6])
    >>> partial_flatten([])
    []
    """
    # chain.from_iterable concatenates the members in a single O(n) pass.
    # A pairwise reduce would copy the growing accumulator at every step and
    # fail on empty input.
    items = chain.from_iterable(seq)
    if isinstance(seq, str):
        # str(iterable) returns the iterable's repr, so join the characters.
        return type(seq)(''.join(items))
    return type(seq)(items)  # type: ignore
def lazy_dedupe(seq: Sequence, key: Callable=None) -> Iterable:
    """
    Returns a generator which yields items in the sequence, skipping
    duplicates; the first occurrence of each item is kept, in order.

    If key is given, items are compared by key(item) rather than by the item
    itself. Unhashable values (e.g. lists) are supported through a slower
    linear-scan fallback instead of raising TypeError.
    """
    seen = set()  # type: set
    seen_unhashable = []  # type: list
    for item in seq:
        val = item if key is None else key(item)
        try:
            if val in seen:
                continue
            seen.add(val)
        except TypeError:
            # val is unhashable: fall back to an equality scan over a list.
            if val in seen_unhashable:
                continue
            seen_unhashable.append(val)
        yield item
def sorted_index(seq: Sequence, item: Any, key: Any=None) -> int:
    """
    Returns the position item would take in a sorted copy of seq that also
    contains item, placed ahead of items whose sort key compares equal.

    key may be None (items are compared directly), a callable applied to each
    item, or any other value, which is looked up on each item with
    operator.itemgetter (e.g. a dict key or a tuple index).
    >>> sorted_index([10, 20, 30, 40, 50], 35)
    3
    >>> sorted_index(['bb', 'a', 'ccc'], 'dd', key=len)
    1
    """
    if key is None:
        keyfn = None
    elif callable(key):
        keyfn = key
    else:
        keyfn = itemgetter(key)
    # item goes first so the stable sort keeps it ahead of equal-keyed items.
    cp = sorted(chain([item], seq), key=keyfn)
    return cp.index(item)
def dedupe(seq: Sequence, key: Callable=None) -> Iterable:
    """
    Removes duplicates from a sequence while maintaining order.

    The result is rebuilt with the same type as seq. If key is given, two
    items count as duplicates when their key(item) values compare equal.
    >>> dedupe([1, 5, 2, 1, 9, 1, 5, 10])
    [1, 5, 2, 9, 10]
    >>> dedupe("mississippi")
    'misp'
    """
    unique = lazy_dedupe(seq, key)
    if isinstance(seq, str):
        # str(generator) would produce the generator's repr, not its items.
        return type(seq)(''.join(unique))
    return type(seq)(unique)  # type: ignore
def concat(seqX: Sequence, seqY: Sequence) -> Sequence:
    """
    Joins two sequences together, returning a single combined sequence.

    If seqX is an instance of seqY's type, the result is built with seqX's
    type; otherwise the combined items are returned as a list.
    >>> concat((1, 2, 3), (4, 5, 6))
    (1, 2, 3, 4, 5, 6)
    >>> concat("ab", "cd")
    'abcd'
    >>> concat([1], (2,))
    [1, 2]
    """
    chained = chain(seqX, seqY)
    if not isinstance(seqX, type(seqY)):
        return list(chained)
    if isinstance(seqX, str):
        # str(iterable) returns the iterable's repr, so join the characters.
        return type(seqX)(''.join(chained))
    return type(seqX)(chained)  # type: ignore
def take(n: int, iterable: Iterable) -> Iterable:
    """
    Return the first n items of iterable as a new list.

    If iterable has fewer than n items, all of them are returned.
    >>> take(2, range(1, 10))
    [1, 2]
    """
    leading = islice(iterable, n)
    return list(leading)
def nth(iterable: Iterable, n: int, default: Any=None) -> Any:
    """
    Return the item at zero-based position n of iterable.

    If iterable has n or fewer items, default is returned instead.
    >>> nth([1, 2, 3], 1)
    2
    """
    after_skip = islice(iterable, n, None)
    return next(after_skip, default)
def all_equal(iterable: Iterable) -> bool:
    """
    Return True if every element compares equal to its neighbour.

    An empty or single-element iterable is trivially all-equal.
    >>> all_equal([True, True])
    True
    """
    runs = groupby(iterable)
    # groupby yields one (key, group) pair per run of equal neighbours;
    # seeing a second run means the elements are not all equal.
    if next(runs, None) is None:
        return True
    return next(runs, None) is None
def quantify(iterable: Iterable, pred: Callable=bool) -> int:
    """
    Count how many elements of iterable satisfy pred.

    The values returned by pred are summed, so pred should return a bool
    (or an int, in which case those ints are added up).
    >>> quantify([True, False, True])
    2
    """
    return sum(pred(element) for element in iterable)
# Short-hand aliases for first / rest / butlast.
head, tail, init = first, rest, butlast