bpo-46771: Implement asyncio context managers for handling timeouts (GH-31394)

Example:

async with asyncio.timeout(5):
    await some_task()

Will interrupt the await and raise TimeoutError if some_task() takes longer than 5 seconds.

Co-authored-by: Guido van Rossum <guido@python.org>
This commit is contained in:
Andrew Svetlov 2022-03-10 18:05:20 +02:00 committed by GitHub
parent 32bf359792
commit f537b2a4fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 384 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from .streams import *
from .subprocess import *
from .tasks import *
from .taskgroups import *
from .timeouts import *
from .threads import *
from .transports import *
@ -34,6 +35,7 @@ __all__ = (base_events.__all__ +
subprocess.__all__ +
tasks.__all__ +
threads.__all__ +
timeouts.__all__ +
transports.__all__)
if sys.platform == 'win32': # pragma: no cover

151
Lib/asyncio/timeouts.py Normal file
View File

@ -0,0 +1,151 @@
import enum
from types import TracebackType
from typing import final, Optional, Type
from . import events
from . import exceptions
from . import tasks
__all__ = (
"Timeout",
"timeout",
"timeout_at",
)
class _State(enum.Enum):
CREATED = "created"
ENTERED = "active"
EXPIRING = "expiring"
EXPIRED = "expired"
EXITED = "finished"
@final
class Timeout:
def __init__(self, when: Optional[float]) -> None:
self._state = _State.CREATED
self._timeout_handler: Optional[events.TimerHandle] = None
self._task: Optional[tasks.Task] = None
self._when = when
def when(self) -> Optional[float]:
return self._when
def reschedule(self, when: Optional[float]) -> None:
assert self._state is not _State.CREATED
if self._state is not _State.ENTERED:
raise RuntimeError(
f"Cannot change state of {self._state.value} Timeout",
)
self._when = when
if self._timeout_handler is not None:
self._timeout_handler.cancel()
if when is None:
self._timeout_handler = None
else:
loop = events.get_running_loop()
self._timeout_handler = loop.call_at(
when,
self._on_timeout,
)
def expired(self) -> bool:
"""Is timeout expired during execution?"""
return self._state in (_State.EXPIRING, _State.EXPIRED)
def __repr__(self) -> str:
info = ['']
if self._state is _State.ENTERED:
when = round(self._when, 3) if self._when is not None else None
info.append(f"when={when}")
info_str = ' '.join(info)
return f"<Timeout [{self._state.value}]{info_str}>"
async def __aenter__(self) -> "Timeout":
self._state = _State.ENTERED
self._task = tasks.current_task()
if self._task is None:
raise RuntimeError("Timeout should be used inside a task")
self.reschedule(self._when)
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._state in (_State.ENTERED, _State.EXPIRING)
if self._timeout_handler is not None:
self._timeout_handler.cancel()
self._timeout_handler = None
if self._state is _State.EXPIRING:
self._state = _State.EXPIRED
if self._task.uncancel() == 0 and exc_type is exceptions.CancelledError:
# Since there are no outstanding cancel requests, we're
# handling this.
raise TimeoutError
elif self._state is _State.ENTERED:
self._state = _State.EXITED
return None
def _on_timeout(self) -> None:
assert self._state is _State.ENTERED
self._task.cancel()
self._state = _State.EXPIRING
# drop the reference early
self._timeout_handler = None
def timeout(delay: Optional[float]) -> Timeout:
"""Timeout async context manager.
Useful in cases when you want to apply timeout logic around block
of code or in cases when asyncio.wait_for is not suitable. For example:
>>> async with asyncio.timeout(10): # 10 seconds timeout
... await long_running_task()
delay - value in seconds or None to disable timeout logic
long_running_task() is interrupted by raising asyncio.CancelledError,
the top-most affected timeout() context manager converts CancelledError
into TimeoutError.
"""
loop = events.get_running_loop()
return Timeout(loop.time() + delay if delay is not None else None)
def timeout_at(when: Optional[float]) -> Timeout:
"""Schedule the timeout at absolute time.
Like timeout() but argument gives absolute time in the same clock system
as loop.time().
Please note: it is not POSIX time but a time with
undefined starting base, e.g. the time of the system power on.
>>> async with asyncio.timeout_at(loop.time() + 10):
... await long_running_task()
when - a deadline when timeout occurs or None to disable timeout logic
long_running_task() is interrupted by raising asyncio.CancelledError,
the top-most affected timeout() context manager converts CancelledError
into TimeoutError.
"""
return Timeout(when)

View File

@ -0,0 +1,229 @@
"""Tests for asyncio/timeouts.py"""
import unittest
import time
import asyncio
from asyncio import tasks
def tearDownModule():
asyncio.set_event_loop_policy(None)
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
async def test_timeout_basic(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(10)
self.assertTrue(cm.expired())
async def test_timeout_at_basic(self):
loop = asyncio.get_running_loop()
with self.assertRaises(TimeoutError):
deadline = loop.time() + 0.01
async with asyncio.timeout_at(deadline) as cm:
await asyncio.sleep(10)
self.assertTrue(cm.expired())
self.assertEqual(deadline, cm.when())
async def test_nested_timeouts(self):
loop = asyncio.get_running_loop()
cancelled = False
with self.assertRaises(TimeoutError):
deadline = loop.time() + 0.01
async with asyncio.timeout_at(deadline) as cm1:
# Only the topmost context manager should raise TimeoutError
try:
async with asyncio.timeout_at(deadline) as cm2:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancelled = True
raise
self.assertTrue(cancelled)
self.assertTrue(cm1.expired())
self.assertTrue(cm2.expired())
async def test_waiter_cancelled(self):
loop = asyncio.get_running_loop()
cancelled = False
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01):
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancelled = True
raise
self.assertTrue(cancelled)
async def test_timeout_not_called(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
async with asyncio.timeout(10) as cm:
await asyncio.sleep(0.01)
t1 = loop.time()
self.assertFalse(cm.expired())
# 2 sec for slow CI boxes
self.assertLess(t1-t0, 2)
self.assertGreater(cm.when(), t1)
async def test_timeout_disabled(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
async with asyncio.timeout(None) as cm:
await asyncio.sleep(0.01)
t1 = loop.time()
self.assertFalse(cm.expired())
self.assertIsNone(cm.when())
# 2 sec for slow CI boxes
self.assertLess(t1-t0, 2)
async def test_timeout_at_disabled(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
async with asyncio.timeout_at(None) as cm:
await asyncio.sleep(0.01)
t1 = loop.time()
self.assertFalse(cm.expired())
self.assertIsNone(cm.when())
# 2 sec for slow CI boxes
self.assertLess(t1-t0, 2)
async def test_timeout_zero(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0) as cm:
await asyncio.sleep(10)
t1 = loop.time()
self.assertTrue(cm.expired())
# 2 sec for slow CI boxes
self.assertLess(t1-t0, 2)
self.assertTrue(t0 <= cm.when() <= t1)
async def test_foreign_exception_passed(self):
with self.assertRaises(KeyError):
async with asyncio.timeout(0.01) as cm:
raise KeyError
self.assertFalse(cm.expired())
async def test_foreign_exception_on_timeout(self):
async def crash():
try:
await asyncio.sleep(1)
finally:
1/0
with self.assertRaises(ZeroDivisionError):
async with asyncio.timeout(0.01):
await crash()
async def test_foreign_cancel_doesnt_timeout_if_not_expired(self):
with self.assertRaises(asyncio.CancelledError):
async with asyncio.timeout(10) as cm:
asyncio.current_task().cancel()
await asyncio.sleep(10)
self.assertFalse(cm.expired())
async def test_outer_task_is_not_cancelled(self):
async def outer() -> None:
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.001):
await asyncio.sleep(10)
task = asyncio.create_task(outer())
await task
self.assertFalse(task.cancelled())
self.assertTrue(task.done())
async def test_nested_timeouts_concurrent(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.002):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.1):
# Pretend we crunch some numbers.
time.sleep(0.01)
await asyncio.sleep(1)
async def test_nested_timeouts_loop_busy(self):
# After the inner timeout is an expensive operation which should
# be stopped by the outer timeout.
loop = asyncio.get_running_loop()
# Disable a message about long running task
loop.slow_callback_duration = 10
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.1): # (1)
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01): # (2)
# Pretend the loop is busy for a while.
time.sleep(0.1)
await asyncio.sleep(1)
# TimeoutError was cought by (2)
await asyncio.sleep(10) # This sleep should be interrupted by (1)
t1 = loop.time()
self.assertTrue(t0 <= t1 <= t0 + 1)
async def test_reschedule(self):
loop = asyncio.get_running_loop()
fut = loop.create_future()
deadline1 = loop.time() + 10
deadline2 = deadline1 + 20
async def f():
async with asyncio.timeout_at(deadline1) as cm:
fut.set_result(cm)
await asyncio.sleep(50)
task = asyncio.create_task(f())
cm = await fut
self.assertEqual(cm.when(), deadline1)
cm.reschedule(deadline2)
self.assertEqual(cm.when(), deadline2)
cm.reschedule(None)
self.assertIsNone(cm.when())
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertFalse(cm.expired())
async def test_repr_active(self):
async with asyncio.timeout(10) as cm:
self.assertRegex(repr(cm), r"<Timeout \[active\] when=\d+\.\d*>")
async def test_repr_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(10)
self.assertEqual(repr(cm), "<Timeout [expired]>")
async def test_repr_finished(self):
async with asyncio.timeout(10) as cm:
await asyncio.sleep(0)
self.assertEqual(repr(cm), "<Timeout [finished]>")
async def test_repr_disabled(self):
async with asyncio.timeout(None) as cm:
self.assertEqual(repr(cm), r"<Timeout [active] when=None>")
async def test_nested_timeout_in_finally(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01):
try:
await asyncio.sleep(1)
finally:
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01):
await asyncio.sleep(10)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,2 @@
:func:`asyncio.timeout` and :func:`asyncio.timeout_at` context managers
added. Patch by Tin Tvrtković and Andrew Svetlov.