bpo-40282: Allow random.getrandbits(0) (GH-19539)
This commit is contained in:
parent
d7c657d4b1
commit
75a3378810
@ -111,6 +111,9 @@ Bookkeeping functions
|
|||||||
as an optional part of the API. When available, :meth:`getrandbits` enables
|
as an optional part of the API. When available, :meth:`getrandbits` enables
|
||||||
:meth:`randrange` to handle arbitrarily large ranges.
|
:meth:`randrange` to handle arbitrarily large ranges.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.9
|
||||||
|
This method now accepts zero for *k*.
|
||||||
|
|
||||||
|
|
||||||
.. function:: randbytes(n)
|
.. function:: randbytes(n)
|
||||||
|
|
||||||
|
@ -261,6 +261,8 @@ class Random(_random.Random):
|
|||||||
def _randbelow_with_getrandbits(self, n):
|
def _randbelow_with_getrandbits(self, n):
|
||||||
"Return a random int in the range [0,n). Raises ValueError if n==0."
|
"Return a random int in the range [0,n). Raises ValueError if n==0."
|
||||||
|
|
||||||
|
if not n:
|
||||||
|
raise ValueError("Boundary cannot be zero")
|
||||||
getrandbits = self.getrandbits
|
getrandbits = self.getrandbits
|
||||||
k = n.bit_length() # don't use (n-1) here because n can be 1
|
k = n.bit_length() # don't use (n-1) here because n can be 1
|
||||||
r = getrandbits(k) # 0 <= r < 2**k
|
r = getrandbits(k) # 0 <= r < 2**k
|
||||||
@ -733,8 +735,8 @@ class SystemRandom(Random):
|
|||||||
|
|
||||||
def getrandbits(self, k):
|
def getrandbits(self, k):
|
||||||
"""getrandbits(k) -> x. Generates an int with k random bits."""
|
"""getrandbits(k) -> x. Generates an int with k random bits."""
|
||||||
if k <= 0:
|
if k < 0:
|
||||||
raise ValueError('number of bits must be greater than zero')
|
raise ValueError('number of bits must be non-negative')
|
||||||
numbytes = (k + 7) // 8 # bits / 8 and rounded up
|
numbytes = (k + 7) // 8 # bits / 8 and rounded up
|
||||||
x = int.from_bytes(_urandom(numbytes), 'big')
|
x = int.from_bytes(_urandom(numbytes), 'big')
|
||||||
return x >> (numbytes * 8 - k) # trim excess bits
|
return x >> (numbytes * 8 - k) # trim excess bits
|
||||||
|
@ -263,6 +263,31 @@ class TestBasicOps:
|
|||||||
self.assertEqual(x1, x2)
|
self.assertEqual(x1, x2)
|
||||||
self.assertEqual(y1, y2)
|
self.assertEqual(y1, y2)
|
||||||
|
|
||||||
|
def test_getrandbits(self):
|
||||||
|
# Verify ranges
|
||||||
|
for k in range(1, 1000):
|
||||||
|
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
|
||||||
|
self.assertEqual(self.gen.getrandbits(0), 0)
|
||||||
|
|
||||||
|
# Verify all bits active
|
||||||
|
getbits = self.gen.getrandbits
|
||||||
|
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
|
||||||
|
all_bits = 2**span-1
|
||||||
|
cum = 0
|
||||||
|
cpl_cum = 0
|
||||||
|
for i in range(100):
|
||||||
|
v = getbits(span)
|
||||||
|
cum |= v
|
||||||
|
cpl_cum |= all_bits ^ v
|
||||||
|
self.assertEqual(cum, all_bits)
|
||||||
|
self.assertEqual(cpl_cum, all_bits)
|
||||||
|
|
||||||
|
# Verify argument checking
|
||||||
|
self.assertRaises(TypeError, self.gen.getrandbits)
|
||||||
|
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
|
||||||
|
self.assertRaises(ValueError, self.gen.getrandbits, -1)
|
||||||
|
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
|
||||||
|
|
||||||
def test_pickling(self):
|
def test_pickling(self):
|
||||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||||
state = pickle.dumps(self.gen, proto)
|
state = pickle.dumps(self.gen, proto)
|
||||||
@ -390,26 +415,6 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
|
|||||||
raises(0, 42, 0)
|
raises(0, 42, 0)
|
||||||
raises(0, 42, 3.14159)
|
raises(0, 42, 3.14159)
|
||||||
|
|
||||||
def test_genrandbits(self):
|
|
||||||
# Verify ranges
|
|
||||||
for k in range(1, 1000):
|
|
||||||
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
|
|
||||||
|
|
||||||
# Verify all bits active
|
|
||||||
getbits = self.gen.getrandbits
|
|
||||||
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
|
|
||||||
cum = 0
|
|
||||||
for i in range(100):
|
|
||||||
cum |= getbits(span)
|
|
||||||
self.assertEqual(cum, 2**span-1)
|
|
||||||
|
|
||||||
# Verify argument checking
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits)
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
|
|
||||||
self.assertRaises(ValueError, self.gen.getrandbits, 0)
|
|
||||||
self.assertRaises(ValueError, self.gen.getrandbits, -1)
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
|
|
||||||
|
|
||||||
def test_randbelow_logic(self, _log=log, int=int):
|
def test_randbelow_logic(self, _log=log, int=int):
|
||||||
# check bitcount transition points: 2**i and 2**(i+1)-1
|
# check bitcount transition points: 2**i and 2**(i+1)-1
|
||||||
# show that: k = int(1.001 + _log(n, 2))
|
# show that: k = int(1.001 + _log(n, 2))
|
||||||
@ -629,34 +634,18 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
|
|||||||
self.assertEqual(set(range(start,stop)),
|
self.assertEqual(set(range(start,stop)),
|
||||||
set([self.gen.randrange(start,stop) for i in range(100)]))
|
set([self.gen.randrange(start,stop) for i in range(100)]))
|
||||||
|
|
||||||
def test_genrandbits(self):
|
def test_getrandbits(self):
|
||||||
|
super().test_getrandbits()
|
||||||
|
|
||||||
# Verify cross-platform repeatability
|
# Verify cross-platform repeatability
|
||||||
self.gen.seed(1234567)
|
self.gen.seed(1234567)
|
||||||
self.assertEqual(self.gen.getrandbits(100),
|
self.assertEqual(self.gen.getrandbits(100),
|
||||||
97904845777343510404718956115)
|
97904845777343510404718956115)
|
||||||
# Verify ranges
|
|
||||||
for k in range(1, 1000):
|
|
||||||
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
|
|
||||||
|
|
||||||
# Verify all bits active
|
|
||||||
getbits = self.gen.getrandbits
|
|
||||||
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
|
|
||||||
cum = 0
|
|
||||||
for i in range(100):
|
|
||||||
cum |= getbits(span)
|
|
||||||
self.assertEqual(cum, 2**span-1)
|
|
||||||
|
|
||||||
# Verify argument checking
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits)
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits, 'a')
|
|
||||||
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
|
|
||||||
self.assertRaises(ValueError, self.gen.getrandbits, 0)
|
|
||||||
self.assertRaises(ValueError, self.gen.getrandbits, -1)
|
|
||||||
|
|
||||||
def test_randrange_uses_getrandbits(self):
|
def test_randrange_uses_getrandbits(self):
|
||||||
# Verify use of getrandbits by randrange
|
# Verify use of getrandbits by randrange
|
||||||
# Use same seed as in the cross-platform repeatability test
|
# Use same seed as in the cross-platform repeatability test
|
||||||
# in test_genrandbits above.
|
# in test_getrandbits above.
|
||||||
self.gen.seed(1234567)
|
self.gen.seed(1234567)
|
||||||
# If randrange uses getrandbits, it should pick getrandbits(100)
|
# If randrange uses getrandbits, it should pick getrandbits(100)
|
||||||
# when called with a 100-bits stop argument.
|
# when called with a 100-bits stop argument.
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
Allow ``random.getrandbits(0)`` to succeed and to return 0.
|
@ -474,12 +474,15 @@ _random_Random_getrandbits_impl(RandomObject *self, int k)
|
|||||||
uint32_t *wordarray;
|
uint32_t *wordarray;
|
||||||
PyObject *result;
|
PyObject *result;
|
||||||
|
|
||||||
if (k <= 0) {
|
if (k < 0) {
|
||||||
PyErr_SetString(PyExc_ValueError,
|
PyErr_SetString(PyExc_ValueError,
|
||||||
"number of bits must be greater than zero");
|
"number of bits must be non-negative");
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (k == 0)
|
||||||
|
return PyLong_FromLong(0);
|
||||||
|
|
||||||
if (k <= 32) /* Fast path */
|
if (k <= 32) /* Fast path */
|
||||||
return PyLong_FromUnsignedLong(genrand_uint32(self) >> (32 - k));
|
return PyLong_FromUnsignedLong(genrand_uint32(self) >> (32 - k));
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user