GH-124639: add back loop param to staggered_race (#124700)

This commit is contained in:
Kumar Aditya 2024-09-29 08:42:46 +05:30 committed by GitHub
parent c00964ecd5
commit e0a41a5dd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 2 deletions

View File

@ -11,7 +11,7 @@ from . import taskgroups
class _Done(Exception): class _Done(Exception):
pass pass
async def staggered_race(coro_fns, delay): async def staggered_race(coro_fns, delay, *, loop=None):
"""Run coroutines with staggered start times and take the first to finish. """Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is This method takes an iterable of coroutine functions. The first one is
@ -82,7 +82,13 @@ async def staggered_race(coro_fns, delay):
raise _Done raise _Done
try: try:
async with taskgroups.TaskGroup() as tg: tg = taskgroups.TaskGroup()
# Intentionally override the loop in the TaskGroup to avoid
# using the running loop, preserving backwards compatibility
# TaskGroup only starts using `_loop` after `__aenter__`
# so overriding it here is safe.
tg._loop = loop
async with tg:
for this_index, coro_fn in enumerate(coro_fns): for this_index, coro_fn in enumerate(coro_fns):
this_failed = locks.Event() this_failed = locks.Event()
exceptions.append(None) exceptions.append(None)

View File

@ -121,6 +121,25 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
self.assertIsInstance(excs[0], ValueError) self.assertIsInstance(excs[0], ValueError)
self.assertIsNone(excs[1]) self.assertIsNone(excs[1])
def test_loop_argument(self):
loop = asyncio.new_event_loop()
async def coro():
self.assertEqual(loop, asyncio.get_running_loop())
return 'coro'
async def main():
winner, index, excs = await staggered_race(
[coro],
delay=0.1,
loop=loop
)
self.assertEqual(winner, 'coro')
self.assertEqual(index, 0)
loop.run_until_complete(main())
loop.close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()