Skip to content

Commit d9dde84

Browse files
committed
Support send(), throw(), and close() for generators.
1 parent 4c2bad0 commit d9dde84

3 files changed

Lines changed: 100 additions & 3 deletions

File tree

Doc/library/threading.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,9 @@ of the derived iterators.
14611461
.. class:: serialize(iterable)
14621462

14631463
Return an iterator wrapper that serializes concurrent calls to
1464-
:meth:`~iterator.__next__` using a lock.
1464+
:meth:`~iterator.__next__` using a lock. For generators, will also
1465+
serialize calls to :meth:`~generator.send`, :meth:`~generator.throw`,
1466+
and :meth:`~generator.close`.
14651467

14661468
This makes it possible to share a single iterator, including a generator
14671469
iterator, between multiple threads. A lock assures that calls are handled

Lib/test/test_threading.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,6 +2405,77 @@ def consumer(iterator):
24052405

24062406
self.assertEqual(result, limit * (limit - 1) // 2)
24072407

2408+
def test_serialize_generator_methods(self):
2409+
# A generator that yields and receives
2410+
def echo():
2411+
try:
2412+
while True:
2413+
val = yield "ready"
2414+
yield f"received {val}"
2415+
except ValueError:
2416+
yield "caught"
2417+
2418+
it = threading.serialize(echo())
2419+
2420+
# Test __next__
2421+
self.assertEqual(next(it), "ready")
2422+
2423+
# Test send()
2424+
self.assertEqual(it.send("hello"), "received hello")
2425+
self.assertEqual(next(it), "ready")
2426+
2427+
# Test throw()
2428+
self.assertEqual(it.throw(ValueError), "caught")
2429+
2430+
# Test close()
2431+
it.close()
2432+
with self.assertRaises(StopIteration):
2433+
next(it)
2434+
2435+
def test_serialize_methods_attribute_error(self):
2436+
# A standard iterator that does not have send/throw/close
2437+
# should raise AttributeError when called.
2438+
standard_it = threading.serialize([1, 2, 3])
2439+
2440+
with self.assertRaises(AttributeError):
2441+
standard_it.send("foo")
2442+
2443+
with self.assertRaises(AttributeError):
2444+
standard_it.throw(ValueError)
2445+
2446+
with self.assertRaises(AttributeError):
2447+
standard_it.close()
2448+
2449+
def test_serialize_generator_methods_locking(self):
2450+
# Verifies that generator methods also acquire the lock.
2451+
# We can test this by checking if the lock is held during the call.
2452+
2453+
class LockCheckingGenerator:
2454+
def __init__(self, lock):
2455+
self.lock = lock
2456+
def __iter__(self):
2457+
return self
2458+
def send(self, value):
2459+
if not self.lock.locked():
2460+
raise RuntimeError("Lock not held during send()")
2461+
return value
2462+
def throw(self, *args):
2463+
if not self.lock.locked():
2464+
raise RuntimeError("Lock not held during throw()")
2465+
def close(self):
2466+
if not self.lock.locked():
2467+
raise RuntimeError("Lock not held during close()")
2468+
2469+
# Manually create the serialize object to inspect the lock
2470+
it = threading.serialize([])
2471+
mock_gen = LockCheckingGenerator(it.lock)
2472+
it.iterator = mock_gen
2473+
2474+
# These should not raise RuntimeError
2475+
it.send(1)
2476+
it.throw(ValueError)
2477+
it.close()
2478+
24082479
def test_synchronized_serializes_generator_instances(self):
24092480
unique = 10
24102481
repetitions = 5

Lib/threading.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,8 @@ class BrokenBarrierError(RuntimeError):
849849
class serialize:
850850
"""Wrap a non-concurrent iterator with a lock to enforce sequential access.
851851
852-
Applies a non-reentrant lock around calls to __next__, allowing
853-
iterator and generator instances to be shared by multiple consumer
852+
Applies a non-reentrant lock around calls to __next__, send, throw, and close.
853+
Allows iterator and generator instances to be shared by multiple consumer
854854
threads.
855855
"""
856856

@@ -867,6 +867,30 @@ def __next__(self):
867867
with self.lock:
868868
return next(self.iterator)
869869

870+
def send(self, value, /):
871+
"""Send a value to a generator.
872+
873+
Raises AttributeError if not a generator.
874+
"""
875+
with self.lock:
876+
return self.iterator.send(value)
877+
878+
def throw(self, *args):
879+
"""Call throw() on a generator.
880+
881+
Raises AttributeError if not a generator.
882+
"""
883+
with self.lock:
884+
return self.iterator.throw(*args)
885+
886+
def close(self):
887+
"""Call close() on a generator.
888+
889+
Raises AttributeError if not a generator.
890+
"""
891+
with self.lock:
892+
return self.iterator.close()
893+
870894

871895
def synchronized(func):
872896
"""Wrap an iterator-returning callable to make its iterators thread-safe.

0 commit comments

Comments
 (0)