Use critical sections to make heapq methods that update the heap thread-safe when the GIL is disabled. --------- Co-authored-by: mpage <mpage@meta.com>
241 lines
7.4 KiB
Python
241 lines
7.4 KiB
Python
import unittest
|
|
|
|
import heapq
|
|
|
|
from enum import Enum
|
|
from threading import Thread, Barrier
|
|
from random import shuffle, randint
|
|
|
|
from test.support import threading_helper
|
|
from test import test_heapq
|
|
|
|
|
|
# Number of worker threads each racing test starts (passed as `nthreads`).
NTHREADS = 10
# Size of the heaps under test and the per-thread iteration count; the pop
# tests assert that it divides evenly by NTHREADS.
OBJECT_COUNT = 5_000
|
|
|
|
|
|
class Heap(Enum):
    """Selects which kind of heap a test helper should build."""

    MIN = 1  # heapified with heapq.heapify
    MAX = 2  # heapified with heapq.heapify_max
|
|
|
|
|
|
@threading_helper.requires_working_threading()
|
|
class TestHeapq(unittest.TestCase):
|
|
def setUp(self):
    """
    Build a test_heapq.TestHeapPython instance for every test.

    The racing tests call its check_invariant() / check_max_invariant()
    methods on the heaps they produce.
    """
    invariant_checker = test_heapq.TestHeapPython()
    self.test_heapq = invariant_checker
|
|
|
|
def test_racing_heapify(self):
    """
    Call heapq.heapify() on one shared, shuffled list from NTHREADS
    threads at once and verify the min-heap invariant afterwards.
    """
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)

    self.run_concurrently(
        worker_func=heapq.heapify,
        args=(shared,),
        nthreads=NTHREADS,
    )

    self.test_heapq.check_invariant(shared)
|
|
|
|
def test_racing_heappush(self):
    """
    Push onto one shared, initially empty list from NTHREADS threads and
    verify the min-heap invariant afterwards.
    """
    shared = []

    def push_descending(target):
        # Every thread pushes OBJECT_COUNT - 1 down to 0, in that order.
        for value in range(OBJECT_COUNT - 1, -1, -1):
            heapq.heappush(target, value)

    self.run_concurrently(
        worker_func=push_descending,
        args=(shared,),
        nthreads=NTHREADS,
    )

    self.test_heapq.check_invariant(shared)
|
|
|
|
def test_racing_heappop(self):
    """
    Drain one shared min-heap from NTHREADS threads.

    Each thread pops an equal share and checks that the items it received
    came out in non-decreasing order; the heap must be empty at the end.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MIN)

    # The work must split evenly so the threads together pop every item.
    self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
    share = OBJECT_COUNT // NTHREADS

    def pop_share(source, count):
        popped = [heapq.heappop(source) for _ in range(count)]

        # The items popped by a single thread must already be in order.
        self.assertTrue(self.is_sorted_ascending(popped))

    self.run_concurrently(
        worker_func=pop_share,
        args=(shared, share),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), 0)
|
|
|
|
def test_racing_heappushpop(self):
    """
    Run heapq.heappushpop() against one shared min-heap from NTHREADS
    threads; the heap must keep its size and its invariant.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MIN)
    candidates = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

    def pushpop_all(target, values):
        for value in values:
            # The returned item can never be larger than the one pushed.
            self.assertTrue(heapq.heappushpop(target, value) <= value)

    self.run_concurrently(
        worker_func=pushpop_all,
        args=(shared, candidates),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), OBJECT_COUNT)
    self.test_heapq.check_invariant(shared)
|
|
|
|
def test_racing_heapreplace(self):
    """
    Run heapq.heapreplace() against one shared min-heap from NTHREADS
    threads; the heap must keep its size and its invariant.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MIN)
    replacements = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

    def replace_all(target, values):
        for value in values:
            heapq.heapreplace(target, value)

    self.run_concurrently(
        worker_func=replace_all,
        args=(shared, replacements),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), OBJECT_COUNT)
    self.test_heapq.check_invariant(shared)
|
|
|
|
def test_racing_heapify_max(self):
    """
    Call heapq.heapify_max() on one shared, shuffled list from NTHREADS
    threads at once and verify the max-heap invariant afterwards.
    """
    shared = list(range(OBJECT_COUNT))
    shuffle(shared)

    self.run_concurrently(
        worker_func=heapq.heapify_max,
        args=(shared,),
        nthreads=NTHREADS,
    )

    self.test_heapq.check_max_invariant(shared)
|
|
|
|
def test_racing_heappush_max(self):
    """
    Push onto one shared, initially empty list with heapq.heappush_max()
    from NTHREADS threads and verify the max-heap invariant afterwards.
    """
    shared = []

    def push_ascending(target):
        # Every thread pushes 0 up to OBJECT_COUNT - 1, in that order.
        for value in range(0, OBJECT_COUNT):
            heapq.heappush_max(target, value)

    self.run_concurrently(
        worker_func=push_ascending,
        args=(shared,),
        nthreads=NTHREADS,
    )

    self.test_heapq.check_max_invariant(shared)
|
|
|
|
def test_racing_heappop_max(self):
    """
    Drain one shared max-heap from NTHREADS threads.

    Each thread pops an equal share and checks that the items it received
    came out in non-increasing order; the heap must be empty at the end.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MAX)

    # The work must split evenly so the threads together pop every item.
    self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
    share = OBJECT_COUNT // NTHREADS

    def pop_share(source, count):
        popped = [heapq.heappop_max(source) for _ in range(count)]

        # The items popped by a single thread must already be in order.
        self.assertTrue(self.is_sorted_descending(popped))

    self.run_concurrently(
        worker_func=pop_share,
        args=(shared, share),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), 0)
|
|
|
|
def test_racing_heappushpop_max(self):
    """
    Run heapq.heappushpop_max() against one shared max-heap from NTHREADS
    threads; the heap must keep its size and its invariant.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MAX)
    candidates = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

    def pushpop_all(target, values):
        for value in values:
            # The returned item can never be smaller than the one pushed.
            self.assertTrue(heapq.heappushpop_max(target, value) >= value)

    self.run_concurrently(
        worker_func=pushpop_all,
        args=(shared, candidates),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), OBJECT_COUNT)
    self.test_heapq.check_max_invariant(shared)
|
|
|
|
def test_racing_heapreplace_max(self):
    """
    Run heapq.heapreplace_max() against one shared max-heap from NTHREADS
    threads; the heap must keep its size and its invariant.
    """
    shared = self.create_heap(OBJECT_COUNT, Heap.MAX)
    replacements = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

    def replace_all(target, values):
        for value in values:
            heapq.heapreplace_max(target, value)

    self.run_concurrently(
        worker_func=replace_all,
        args=(shared, replacements),
        nthreads=NTHREADS,
    )

    self.assertEqual(len(shared), OBJECT_COUNT)
    self.test_heapq.check_max_invariant(shared)
|
|
|
|
@staticmethod
def is_sorted_ascending(lst):
    """
    Return True when every element is <= its successor (non-decreasing).

    Empty and single-element lists count as sorted.
    """
    # Pair each element with the one that follows it.
    return all(left <= right for left, right in zip(lst, lst[1:]))
|
|
|
|
@staticmethod
def is_sorted_descending(lst):
    """
    Return True when every element is >= its successor (non-increasing).

    Empty and single-element lists count as sorted.
    """
    # Pair each element with the one that follows it.
    return all(left >= right for left, right in zip(lst, lst[1:]))
|
|
|
|
@staticmethod
def create_heap(size, heap_kind):
    """
    Create a min or max heap holding the integers 0 through size - 1.

    The values are shuffled before being heapified.

    size: number of elements the heap should contain.
    heap_kind: Heap.MIN builds the heap with heapq.heapify(); any other
        value builds it with heapq.heapify_max().

    Returns the heapified list.
    """
    # Honor the requested size. The body used to build
    # range(OBJECT_COUNT) no matter what `size` the caller passed.
    heap = list(range(size))
    shuffle(heap)
    if heap_kind == Heap.MIN:
        heapq.heapify(heap)
    else:
        heapq.heapify_max(heap)

    return heap
|
|
|
|
@staticmethod
def create_random_list(a, b, size):
    """
    Create a list of `size` random integers N such that a <= N <= b.

    a: lower bound (inclusive).
    b: upper bound (inclusive).
    size: number of integers to generate.
    """
    # randint() includes both endpoints. Pass `a` through unchanged: the
    # body used to negate it, so a call such as (-5_000, 10_000) drew from
    # 5000..10000 and never produced a value below the requested range's
    # midpoint, let alone a negative one.
    return [randint(a, b) for _ in range(size)]
|
|
|
|
def run_concurrently(self, worker_func, args, nthreads):
    """
    Run worker_func(*args) in `nthreads` threads that start together.

    worker_func: callable executed by every thread.
    args: tuple of positional arguments handed to each worker_func call.
    nthreads: number of threads to start.

    Fails the test if an exception raised in a worker thread was captured.
    """
    start_gate = Barrier(nthreads)

    def gated_worker(*call_args):
        # Hold every thread here until all of them have been started.
        start_gate.wait()
        worker_func(*call_args)

    with threading_helper.catch_threading_exception() as cm:
        workers = (
            Thread(target=gated_worker, args=args)
            for _ in range(nthreads)
        )
        # Nothing to do in the body: the helper is used only for what it
        # does on entry and exit.
        with threading_helper.start_threads(workers):
            pass

        # Worker threads should not raise any exceptions
        self.assertIsNone(cm.exc_value)
|
|
|
|
|
|
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
|