Skip to content

Commit 4ac3c8c

Browse files
Merge pull request #182 from maxfischer2781/consistency/tee-shared
Share Tee buffer
2 parents 5c0e997 + e64fe3c commit 4ac3c8c

3 files changed

Lines changed: 178 additions & 63 deletions

File tree

asyncstdlib/itertools.py

Lines changed: 91 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Union,
99
Callable,
1010
Optional,
11-
Deque,
1211
Generic,
1312
Iterable,
1413
Iterator,
@@ -17,7 +16,7 @@
1716
overload,
1817
AsyncGenerator,
1918
)
20-
from collections import deque
19+
from typing_extensions import TypeAlias
2120

2221
from ._typing import ACloseable, R, T, AnyIterable, ADD
2322
from ._utility import public_module
@@ -32,6 +31,7 @@
3231
enumerate as aenumerate,
3332
iter as aiter,
3433
)
34+
from itertools import count as _counter
3535

3636
S = TypeVar("S")
3737
T_co = TypeVar("T_co", covariant=True)
@@ -346,57 +346,79 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
346346
return None
347347

348348

349-
async def tee_peer(
350-
iterator: AsyncIterator[T],
351-
# the buffer specific to this peer
352-
buffer: Deque[T],
353-
# the buffers of all peers, including our own
354-
peers: List[Deque[T]],
355-
lock: AsyncContextManager[Any],
356-
) -> AsyncGenerator[T, None]:
357-
"""An individual iterator of a :py:func:`~.tee`"""
358-
try:
359-
while True:
360-
if not buffer:
361-
async with lock:
362-
# Another peer produced an item while we were waiting for the lock.
363-
# Proceed with the next loop iteration to yield the item.
364-
if buffer:
365-
continue
366-
try:
367-
item = await iterator.__anext__()
368-
except StopAsyncIteration:
369-
break
370-
else:
371-
# Append to all buffers, including our own. We'll fetch our
372-
# item from the buffer again, instead of yielding it directly.
373-
# This ensures the proper item ordering if any of our peers
374-
# are fetching items concurrently. They may have buffered their
375-
# item already.
376-
for peer_buffer in peers:
377-
peer_buffer.append(item)
378-
yield buffer.popleft()
379-
finally:
380-
# this peer is done – remove its buffer
381-
for idx, peer_buffer in enumerate(peers): # pragma: no branch
382-
if peer_buffer is buffer:
383-
peers.pop(idx)
384-
break
385-
# if we are the last peer, try and close the iterator
386-
if not peers and isinstance(iterator, ACloseable):
387-
await iterator.aclose()
349+
_get_tee_index = _counter().__next__
350+
351+
352+
_TeeNode: TypeAlias = "list[T | _TeeNode[T]]"
353+
354+
355+
class TeePeer(Generic[T]):
356+
def __init__(
357+
self,
358+
iterator: AsyncIterator[T],
359+
buffer: "_TeeNode[T]",
360+
lock: AsyncContextManager[Any],
361+
tee_peers: "set[int]",
362+
) -> None:
363+
self._iterator = iterator
364+
self._lock = lock
365+
self._buffer: _TeeNode[T] = buffer
366+
self._tee_peers = tee_peers
367+
self._tee_idx = _get_tee_index()
368+
self._tee_peers.add(self._tee_idx)
369+
370+
def __aiter__(self):
371+
return self
372+
373+
async def __anext__(self) -> T:
374+
# the buffer is a singly-linked list as [value, [value, [...]]] | []
375+
next_node = self._buffer
376+
value: T
377+
# for any most advanced TeePeer, the node is just []
378+
# fetch the next value so we can mutate the node to [value, [...]]
379+
if not next_node:
380+
async with self._lock:
381+
# Check if another peer produced an item while we were waiting for the lock
382+
if not next_node:
383+
await self._extend_buffer(next_node)
384+
# for any other TeePeer, the node is already some [value, [...]]
385+
value, self._buffer = next_node # type: ignore
386+
return value
387+
388+
async def _extend_buffer(self, next_node: "_TeeNode[T]") -> None:
389+
"""Extend the buffer by fetching a new item from the iterable"""
390+
try:
391+
# another peer may fill the buffer while we wait here
392+
next_value = await self._iterator.__anext__()
393+
except StopAsyncIteration:
394+
# no one else managed to fetch a value either
395+
if not next_node:
396+
raise
397+
else:
398+
# skip nodes that were filled in the meantime
399+
while next_node:
400+
_, next_node = next_node # type: ignore
401+
next_node[:] = next_value, []
402+
403+
async def aclose(self) -> None:
404+
self._tee_peers.discard(self._tee_idx)
405+
if not self._tee_peers and isinstance(self._iterator, ACloseable):
406+
await self._iterator.aclose()
407+
408+
def __del__(self) -> None:
409+
self._tee_peers.discard(self._tee_idx)
388410

389411

390412
@public_module(__name__, "tee")
391413
class Tee(Generic[T]):
392-
"""
414+
r"""
393415
Create ``n`` separate asynchronous iterators over ``iterable``
394416
395417
This splits a single ``iterable`` into multiple iterators, each providing
396418
the same items in the same order.
397419
All child iterators may advance separately but share the same items
398420
from ``iterable`` -- when the most advanced iterator retrieves an item,
399-
it is buffered until the least advanced iterator has yielded it as well.
421+
it is buffered until all other iterators have yielded it as well.
400422
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
401423
that all iterators advance.
402424
@@ -407,26 +429,25 @@ async def derivative(sensor_data):
407429
await a.anext(previous) # advance one iterator
408430
return a.map(operator.sub, previous, current)
409431
410-
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
411-
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
412-
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
413-
immediately closes all children, and it can be used in an ``async with`` context
414-
for the same effect.
415-
416-
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
417-
provide these items. Also, ``tee`` must internally buffer each item until the
418-
last iterator has yielded it; if the most and least advanced iterator differ
419-
by most data, using a :py:class:`list` is more efficient (but not lazy).
432+
If ``iterable`` is an iterator and read elsewhere, ``tee`` will generally *not*
433+
provide these items. However, a ``tee`` of a ``tee`` shares its buffer with parent,
434+
sibling and child ``tee``\ s so that each sees the same items.
420435
421436
If the underlying iterable is concurrency safe (``anext`` may be awaited
422437
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
423438
the iterators are safe if there is only ever one single "most advanced" iterator.
424439
To enforce sequential use of ``anext``, provide a ``lock``
425440
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
426441
and access is automatically synchronised.
442+
443+
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
444+
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
445+
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
446+
immediately closes all children, and it can be used in an ``async with`` context
447+
for the same effect.
427448
"""
428449

429-
__slots__ = ("_iterator", "_buffers", "_children")
450+
__slots__ = ("_children",)
430451

431452
def __init__(
432453
self,
@@ -435,16 +456,24 @@ def __init__(
435456
*,
436457
lock: Optional[AsyncContextManager[Any]] = None,
437458
):
438-
self._iterator = aiter(iterable)
439-
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
459+
buffer: _TeeNode[T]
460+
peers: set[int]
461+
if not isinstance(iterable, TeePeer):
462+
iterator = aiter(iterable)
463+
buffer = []
464+
peers = set()
465+
else:
466+
iterator = iterable._iterator # pyright: ignore[reportPrivateUsage]
467+
buffer = iterable._buffer # pyright: ignore[reportPrivateUsage]
468+
peers = iterable._tee_peers # pyright: ignore[reportPrivateUsage]
440469
self._children = tuple(
441-
tee_peer(
442-
iterator=self._iterator,
443-
buffer=buffer,
444-
peers=self._buffers,
445-
lock=lock if lock is not None else NoLock(),
470+
TeePeer(
471+
iterator,
472+
buffer,
473+
lock if lock is not None else NoLock(),
474+
peers,
446475
)
447-
for buffer in self._buffers
476+
for _ in range(n)
448477
)
449478

450479
def __len__(self) -> int:

docs/source/api/itertools.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ Iterator splitting
8585

8686
The ``lock`` keyword parameter.
8787

88+
.. versionchanged:: 3.13.2
89+
90+
``tee``\ s share their buffer with parents, siblings and children.
91+
8892
.. autofunction:: pairwise(iterable: (async) iter T)
8993
:async-for: :(T, T)
9094

unittests/test_itertools.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import AsyncIterator
12
import itertools
23
import sys
34
import platform
@@ -341,7 +342,7 @@ async def test_tee():
341342

342343
@sync
343344
async def test_tee_concurrent_locked():
344-
"""Test that properly uses a lock for synchronisation"""
345+
"""Test that tee properly uses a lock for synchronisation"""
345346
items = [1, 2, 3, -5, 12, 78, -1, 111]
346347

347348
async def iter_values():
@@ -360,6 +361,52 @@ async def test_peer(peer_tee):
360361
assert results == items
361362

362363

364+
@pytest.mark.parametrize("concurrency", (1, 2, 4, 7))
365+
@sync
366+
async def test_tee_share(concurrency: int) -> None:
367+
"""Test that related tees share their buffer and see all items"""
368+
items = [1, 2, 3, -5, 12, 78, -1, 111]
369+
370+
async def tee_test(tee_state: AsyncIterator[int]) -> None:
371+
"""Asynchronously check that `tee_state` includes all `items`"""
372+
for expected in items:
373+
assert expected == await a.anext(tee_state)
374+
await Switch(0, concurrency)
375+
376+
# create tees that are multiple times removed from an initial iterator
377+
item_iter = a.iter(items)
378+
for tee_peer in a.tee(item_iter, n=concurrency):
379+
await Schedule(tee_test(a.tee(tee_peer)[0]))
380+
381+
382+
@sync
383+
async def test_tee_share_deep() -> None:
384+
"""Test that related tees share their buffer and see all items no matter when spawned"""
385+
items = [1, 2, 3, -5, 12, 78, -1, 111]
386+
387+
async def tee_spawn_walker(
388+
tee_state: AsyncIterator[int], start_idx: int = 0
389+
) -> None:
390+
"""Walk and check `tee_state` elements and spawn new walkers on every step"""
391+
for idx in range(start_idx, len(items)):
392+
await Switch(0, 3)
393+
assert await a.anext(tee_state) == items[idx]
394+
tee_state, *child_states = a.tee(tee_state, n=3)
395+
await Schedule(
396+
*(
397+
tee_spawn_walker(child_state, idx + 1)
398+
for child_state in child_states
399+
)
400+
)
401+
await Switch()
402+
403+
head_peer, *child_peers = a.tee(items, n=3)
404+
await Schedule(*(tee_spawn_walker(child, 0) for child in child_peers))
405+
await Switch(len(items) // 2)
406+
results = [item async for item in head_peer]
407+
assert results == items
408+
409+
363410
# see https://github.com/python/cpython/issues/74956
364411
@pytest.mark.skipif(
365412
sys.version_info < (3, 8),
@@ -393,6 +440,41 @@ async def test_peer(peer_tee):
393440
await test_peer(this)
394441

395442

443+
@pytest.mark.parametrize("size", [2, 3, 5, 9, 12])
444+
@sync
445+
async def test_tee_concurrent_ordering(size: int):
446+
"""Test that tee respects concurrent ordering for all peers"""
447+
448+
class ConcurrentInvertedIterable:
449+
"""Helper that concurrently iterates with earlier items taking longer"""
450+
451+
def __init__(self, count: int) -> None:
452+
self.count = count
453+
self._counter = itertools.count()
454+
455+
def __aiter__(self):
456+
return self
457+
458+
async def __anext__(self):
459+
value = next(self._counter)
460+
if value >= self.count:
461+
raise StopAsyncIteration()
462+
await Switch(self.count - value)
463+
return value
464+
465+
async def test_peer(peer_tee: AsyncIterator[int]):
466+
# consume items from the tee with a delay so that slower items can arrive
467+
seen_items: list[int] = []
468+
async for item in peer_tee:
469+
seen_items.append(item)
470+
await Switch()
471+
assert seen_items == expected_items
472+
473+
expected_items = list(range(size)[::-1])
474+
peers = a.tee(ConcurrentInvertedIterable(size), n=size)
475+
await Schedule(*map(test_peer, peers))
476+
477+
396478
@sync
397479
async def test_pairwise():
398480
assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]

0 commit comments

Comments
 (0)