gh-116738: Make _heapq module thread-safe (#135036)

Use critical sections to make heapq methods that update the heap thread-safe when the GIL is disabled.

---------

Co-authored-by: mpage <mpage@meta.com>
This commit is contained in:
Alper 2025-06-09 10:57:29 -07:00 committed by GitHub
parent cc8e6d2703
commit a58026a5e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 303 additions and 15 deletions

View File

@ -0,0 +1,240 @@
import unittest
import heapq
from enum import Enum
from threading import Thread, Barrier
from random import shuffle, randint
from test.support import threading_helper
from test import test_heapq
NTHREADS = 10
OBJECT_COUNT = 5_000
class Heap(Enum):
MIN = 1
MAX = 2
@threading_helper.requires_working_threading()
class TestHeapq(unittest.TestCase):
def setUp(self):
self.test_heapq = test_heapq.TestHeapPython()
def test_racing_heapify(self):
heap = list(range(OBJECT_COUNT))
shuffle(heap)
self.run_concurrently(
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
def test_racing_heappush(self):
heap = []
def heappush_func(heap):
for item in reversed(range(OBJECT_COUNT)):
heapq.heappush(heap, item)
self.run_concurrently(
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
def test_racing_heappop(self):
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
# Each thread pops (OBJECT_COUNT / NTHREADS) items
self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
per_thread_pop_count = OBJECT_COUNT // NTHREADS
def heappop_func(heap, pop_count):
local_list = []
for _ in range(pop_count):
item = heapq.heappop(heap)
local_list.append(item)
# Each local list should be sorted
self.assertTrue(self.is_sorted_ascending(local_list))
self.run_concurrently(
worker_func=heappop_func,
args=(heap, per_thread_pop_count),
nthreads=NTHREADS,
)
self.assertEqual(len(heap), 0)
def test_racing_heappushpop(self):
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
def heappushpop_func(heap, pushpop_items):
for item in pushpop_items:
popped_item = heapq.heappushpop(heap, item)
self.assertTrue(popped_item <= item)
self.run_concurrently(
worker_func=heappushpop_func,
args=(heap, pushpop_items),
nthreads=NTHREADS,
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
def test_racing_heapreplace(self):
heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
def heapreplace_func(heap, replace_items):
for item in replace_items:
heapq.heapreplace(heap, item)
self.run_concurrently(
worker_func=heapreplace_func,
args=(heap, replace_items),
nthreads=NTHREADS,
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
def test_racing_heapify_max(self):
max_heap = list(range(OBJECT_COUNT))
shuffle(max_heap)
self.run_concurrently(
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
def test_racing_heappush_max(self):
max_heap = []
def heappush_max_func(max_heap):
for item in range(OBJECT_COUNT):
heapq.heappush_max(max_heap, item)
self.run_concurrently(
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
def test_racing_heappop_max(self):
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
# Each thread pops (OBJECT_COUNT / NTHREADS) items
self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
per_thread_pop_count = OBJECT_COUNT // NTHREADS
def heappop_max_func(max_heap, pop_count):
local_list = []
for _ in range(pop_count):
item = heapq.heappop_max(max_heap)
local_list.append(item)
# Each local list should be sorted
self.assertTrue(self.is_sorted_descending(local_list))
self.run_concurrently(
worker_func=heappop_max_func,
args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS,
)
self.assertEqual(len(max_heap), 0)
def test_racing_heappushpop_max(self):
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
def heappushpop_max_func(max_heap, pushpop_items):
for item in pushpop_items:
popped_item = heapq.heappushpop_max(max_heap, item)
self.assertTrue(popped_item >= item)
self.run_concurrently(
worker_func=heappushpop_max_func,
args=(max_heap, pushpop_items),
nthreads=NTHREADS,
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
def test_racing_heapreplace_max(self):
max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
def heapreplace_max_func(max_heap, replace_items):
for item in replace_items:
heapq.heapreplace_max(max_heap, item)
self.run_concurrently(
worker_func=heapreplace_max_func,
args=(max_heap, replace_items),
nthreads=NTHREADS,
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
@staticmethod
def is_sorted_ascending(lst):
"""
Check if the list is sorted in ascending order (non-decreasing).
"""
return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
@staticmethod
def is_sorted_descending(lst):
"""
Check if the list is sorted in descending order (non-increasing).
"""
return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
@staticmethod
def create_heap(size, heap_kind):
"""
Create a min/max heap where elements are in the range (0, size - 1) and
shuffled before heapify.
"""
heap = list(range(OBJECT_COUNT))
shuffle(heap)
if heap_kind == Heap.MIN:
heapq.heapify(heap)
else:
heapq.heapify_max(heap)
return heap
@staticmethod
def create_random_list(a, b, size):
"""
Create a list of random numbers between a and b (inclusive).
"""
return [randint(-a, b) for _ in range(size)]
def run_concurrently(self, worker_func, args, nthreads):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = Barrier(nthreads)
def wrapper_func(*args):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)
with threading_helper.catch_threading_exception() as cm:
workers = (
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
)
with threading_helper.start_threads(workers):
pass
# Worker threads should not raise any exceptions
self.assertIsNone(cm.exc_value)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1 @@
Make methods in :mod:`heapq` thread-safe on the :term:`free threaded <free threading>` build.

View File

@ -11,7 +11,7 @@ annotated by François Pinard, and converted to C by Raymond Hettinger.
#endif
#include "Python.h"
#include "pycore_list.h" // _PyList_ITEMS()
#include "pycore_list.h" // _PyList_ITEMS(), _PyList_AppendTakeRef()
#include "clinic/_heapqmodule.c.h"
@ -117,6 +117,7 @@ siftup(PyListObject *heap, Py_ssize_t pos)
}
/*[clinic input]
@critical_section heap
_heapq.heappush
heap: object(subclass_of='&PyList_Type')
@ -128,13 +129,22 @@ Push item onto heap, maintaining the heap invariant.
static PyObject *
_heapq_heappush_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=912c094f47663935 input=7c69611f3698aceb]*/
/*[clinic end generated code: output=912c094f47663935 input=f7a4f03ef8d52e67]*/
{
if (PyList_Append(heap, item))
if (item == NULL) {
PyErr_BadInternalCall();
return NULL;
}
if (siftdown((PyListObject *)heap, 0, PyList_GET_SIZE(heap)-1))
// In a free-threaded build, the heap is locked at this point.
// Therefore, calling _PyList_AppendTakeRef() is safe and no overhead.
if (_PyList_AppendTakeRef((PyListObject *)heap, Py_NewRef(item))) {
return NULL;
}
if (siftdown((PyListObject *)heap, 0, PyList_GET_SIZE(heap)-1)) {
return NULL;
}
Py_RETURN_NONE;
}
@ -171,6 +181,7 @@ heappop_internal(PyObject *heap, int siftup_func(PyListObject *, Py_ssize_t))
}
/*[clinic input]
@critical_section heap
_heapq.heappop
heap: object(subclass_of='&PyList_Type')
@ -181,7 +192,7 @@ Pop the smallest item off the heap, maintaining the heap invariant.
static PyObject *
_heapq_heappop_impl(PyObject *module, PyObject *heap)
/*[clinic end generated code: output=96dfe82d37d9af76 input=91487987a583c856]*/
/*[clinic end generated code: output=96dfe82d37d9af76 input=ed396461b153dd51]*/
{
return heappop_internal(heap, siftup);
}
@ -207,6 +218,7 @@ heapreplace_internal(PyObject *heap, PyObject *item, int siftup_func(PyListObjec
/*[clinic input]
@critical_section heap
_heapq.heapreplace
heap: object(subclass_of='&PyList_Type')
@ -226,12 +238,13 @@ this routine unless written as part of a conditional replacement:
static PyObject *
_heapq_heapreplace_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=82ea55be8fbe24b4 input=719202ac02ba10c8]*/
/*[clinic end generated code: output=82ea55be8fbe24b4 input=9be1678b817ef1a9]*/
{
return heapreplace_internal(heap, item, siftup);
}
/*[clinic input]
@critical_section heap
_heapq.heappushpop
heap: object(subclass_of='&PyList_Type')
@ -246,7 +259,7 @@ a separate call to heappop().
static PyObject *
_heapq_heappushpop_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=67231dc98ed5774f input=5dc701f1eb4a4aa7]*/
/*[clinic end generated code: output=67231dc98ed5774f input=db05c81b1dd92c44]*/
{
PyObject *returnitem;
int cmp;
@ -371,6 +384,7 @@ heapify_internal(PyObject *heap, int siftup_func(PyListObject *, Py_ssize_t))
}
/*[clinic input]
@critical_section heap
_heapq.heapify
heap: object(subclass_of='&PyList_Type')
@ -381,7 +395,7 @@ Transform list into a heap, in-place, in O(len(heap)) time.
static PyObject *
_heapq_heapify_impl(PyObject *module, PyObject *heap)
/*[clinic end generated code: output=e63a636fcf83d6d0 input=53bb7a2166febb73]*/
/*[clinic end generated code: output=e63a636fcf83d6d0 input=aaaaa028b9b6af08]*/
{
return heapify_internal(heap, siftup);
}
@ -481,6 +495,7 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
}
/*[clinic input]
@critical_section heap
_heapq.heappush_max
heap: object(subclass_of='&PyList_Type')
@ -492,9 +507,16 @@ Push item onto max heap, maintaining the heap invariant.
static PyObject *
_heapq_heappush_max_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=c869d5f9deb08277 input=4743d7db137b6e2b]*/
/*[clinic end generated code: output=c869d5f9deb08277 input=c437e3d1ff8dcb70]*/
{
if (PyList_Append(heap, item)) {
if (item == NULL) {
PyErr_BadInternalCall();
return NULL;
}
// In a free-threaded build, the heap is locked at this point.
// Therefore, calling _PyList_AppendTakeRef() is safe and no overhead.
if (_PyList_AppendTakeRef((PyListObject *)heap, Py_NewRef(item))) {
return NULL;
}
@ -506,6 +528,7 @@ _heapq_heappush_max_impl(PyObject *module, PyObject *heap, PyObject *item)
}
/*[clinic input]
@critical_section heap
_heapq.heappop_max
heap: object(subclass_of='&PyList_Type')
@ -516,12 +539,13 @@ Maxheap variant of heappop.
static PyObject *
_heapq_heappop_max_impl(PyObject *module, PyObject *heap)
/*[clinic end generated code: output=2f051195ab404b77 input=e62b14016a5a26de]*/
/*[clinic end generated code: output=2f051195ab404b77 input=5d70c997798aec64]*/
{
return heappop_internal(heap, siftup_max);
}
/*[clinic input]
@critical_section heap
_heapq.heapreplace_max
heap: object(subclass_of='&PyList_Type')
@ -533,12 +557,13 @@ Maxheap variant of heapreplace.
static PyObject *
_heapq_heapreplace_max_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=8770778b5a9cbe9b input=21a3d28d757c881c]*/
/*[clinic end generated code: output=8770778b5a9cbe9b input=fe70175356e4a649]*/
{
return heapreplace_internal(heap, item, siftup_max);
}
/*[clinic input]
@critical_section heap
_heapq.heapify_max
heap: object(subclass_of='&PyList_Type')
@ -549,12 +574,13 @@ Maxheap variant of heapify.
static PyObject *
_heapq_heapify_max_impl(PyObject *module, PyObject *heap)
/*[clinic end generated code: output=8401af3856529807 input=edda4255728c431e]*/
/*[clinic end generated code: output=8401af3856529807 input=4eee63231e7d1573]*/
{
return heapify_internal(heap, siftup_max);
}
/*[clinic input]
@critical_section heap
_heapq.heappushpop_max
heap: object(subclass_of='&PyList_Type')
@ -569,7 +595,7 @@ a separate call to heappop_max().
static PyObject *
_heapq_heappushpop_max_impl(PyObject *module, PyObject *heap, PyObject *item)
/*[clinic end generated code: output=ff0019f0941aca0d input=525a843013cbd6c0]*/
/*[clinic end generated code: output=ff0019f0941aca0d input=24d0defa6fd6df4a]*/
{
PyObject *returnitem;
int cmp;

View File

@ -2,6 +2,7 @@
preserve
[clinic start generated code]*/
#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_CheckPositional()
PyDoc_STRVAR(_heapq_heappush__doc__,
@ -32,7 +33,9 @@ _heapq_heappush(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappush_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -61,7 +64,9 @@ _heapq_heappop(PyObject *module, PyObject *arg)
goto exit;
}
heap = arg;
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappop_impl(module, heap);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -103,7 +108,9 @@ _heapq_heapreplace(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heapreplace_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -140,7 +147,9 @@ _heapq_heappushpop(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappushpop_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -169,7 +178,9 @@ _heapq_heapify(PyObject *module, PyObject *arg)
goto exit;
}
heap = arg;
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heapify_impl(module, heap);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -203,7 +214,9 @@ _heapq_heappush_max(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappush_max_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -232,7 +245,9 @@ _heapq_heappop_max(PyObject *module, PyObject *arg)
goto exit;
}
heap = arg;
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappop_max_impl(module, heap);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -266,7 +281,9 @@ _heapq_heapreplace_max(PyObject *module, PyObject *const *args, Py_ssize_t nargs
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heapreplace_max_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -295,7 +312,9 @@ _heapq_heapify_max(PyObject *module, PyObject *arg)
goto exit;
}
heap = arg;
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heapify_max_impl(module, heap);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
@ -332,9 +351,11 @@ _heapq_heappushpop_max(PyObject *module, PyObject *const *args, Py_ssize_t nargs
}
heap = args[0];
item = args[1];
Py_BEGIN_CRITICAL_SECTION(heap);
return_value = _heapq_heappushpop_max_impl(module, heap, item);
Py_END_CRITICAL_SECTION();
exit:
return return_value;
}
/*[clinic end generated code: output=f55d8595ce150c76 input=a9049054013a1b77]*/
/*[clinic end generated code: output=e83d50002c29a96d input=a9049054013a1b77]*/