GH-98363: Have batched() return tuples (GH-100118)

This commit is contained in:
Raymond Hettinger 2022-12-08 15:08:16 -06:00 committed by GitHub
parent 41d4ac9da3
commit 35cc0ea736
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 38 deletions

View File

@ -52,7 +52,7 @@ Iterator Arguments Results
Iterator Arguments Results Example Iterator Arguments Results Example
============================ ============================ ================================================= ============================================================= ============================ ============================ ================================================= =============================================================
:func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15`` :func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
:func:`batched` p, n [p0, p1, ..., p_n-1], ... ``batched('ABCDEFG', n=3) --> ABC DEF G`` :func:`batched` p, n (p0, p1, ..., p_n-1), ... ``batched('ABCDEFG', n=3) --> ABC DEF G``
:func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F`` :func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F``
:func:`chain.from_iterable` iterable p0, p1, ... plast, q0, q1, ... ``chain.from_iterable(['ABC', 'DEF']) --> A B C D E F`` :func:`chain.from_iterable` iterable p0, p1, ... plast, q0, q1, ... ``chain.from_iterable(['ABC', 'DEF']) --> A B C D E F``
:func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F`` :func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
@ -166,11 +166,11 @@ loops that truncate the stream.
.. function:: batched(iterable, n) .. function:: batched(iterable, n)
Batch data from the *iterable* into lists of length *n*. The last Batch data from the *iterable* into tuples of length *n*. The last
batch may be shorter than *n*. batch may be shorter than *n*.
Loops over the input iterable and accumulates data into lists up to Loops over the input iterable and accumulates data into tuples up to
size *n*. The input is consumed lazily, just enough to fill a list. size *n*. The input is consumed lazily, just enough to fill a batch.
The result is yielded as soon as the batch is full or when the input The result is yielded as soon as the batch is full or when the input
iterable is exhausted: iterable is exhausted:
@ -179,14 +179,14 @@ loops that truncate the stream.
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2)) >>> unflattened = list(batched(flattened_data, 2))
>>> unflattened >>> unflattened
[['roses', 'red'], ['violets', 'blue'], ['sugar', 'sweet']] [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
>>> for batch in batched('ABCDEFG', 3): >>> for batch in batched('ABCDEFG', 3):
... print(batch) ... print(batch)
... ...
['A', 'B', 'C'] ('A', 'B', 'C')
['D', 'E', 'F'] ('D', 'E', 'F')
['G'] ('G',)
Roughly equivalent to:: Roughly equivalent to::
@ -195,7 +195,7 @@ loops that truncate the stream.
if n < 1: if n < 1:
raise ValueError('n must be at least one') raise ValueError('n must be at least one')
it = iter(iterable) it = iter(iterable)
while (batch := list(islice(it, n))): while (batch := tuple(islice(it, n))):
yield batch yield batch
.. versionadded:: 3.12 .. versionadded:: 3.12

View File

@ -161,11 +161,11 @@ class TestBasicOps(unittest.TestCase):
def test_batched(self): def test_batched(self):
self.assertEqual(list(batched('ABCDEFG', 3)), self.assertEqual(list(batched('ABCDEFG', 3)),
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]) [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)])
self.assertEqual(list(batched('ABCDEFG', 2)), self.assertEqual(list(batched('ABCDEFG', 2)),
[['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]) [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
self.assertEqual(list(batched('ABCDEFG', 1)), self.assertEqual(list(batched('ABCDEFG', 1)),
[['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]) [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
with self.assertRaises(TypeError): # Too few arguments with self.assertRaises(TypeError): # Too few arguments
list(batched('ABCDEFG')) list(batched('ABCDEFG'))
@ -188,8 +188,8 @@ class TestBasicOps(unittest.TestCase):
with self.subTest(s=s, n=n, batches=batches): with self.subTest(s=s, n=n, batches=batches):
# Order is preserved and no data is lost # Order is preserved and no data is lost
self.assertEqual(''.join(chain(*batches)), s) self.assertEqual(''.join(chain(*batches)), s)
# Each batch is an exact list # Each batch is an exact tuple
self.assertTrue(all(type(batch) is list for batch in batches)) self.assertTrue(all(type(batch) is tuple for batch in batches))
# All but the last batch is of size n # All but the last batch is of size n
if batches: if batches:
last_batch = batches.pop() last_batch = batches.pop()
@ -1809,12 +1809,12 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
def test_batched_recipe(self): def test_batched_recipe(self):
def batched_recipe(iterable, n): def batched_recipe(iterable, n):
"Batch data into lists of length n. The last batch may be shorter." "Batch data into tuples of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G # batched('ABCDEFG', 3) --> ABC DEF G
if n < 1: if n < 1:
raise ValueError('n must be at least one') raise ValueError('n must be at least one')
it = iter(iterable) it = iter(iterable)
while (batch := list(islice(it, n))): while (batch := tuple(islice(it, n))):
yield batch yield batch
for iterable, n in product( for iterable, n in product(
@ -2087,7 +2087,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
def test_batched(self): def test_batched(self):
s = 'abcde' s = 'abcde'
r = [['a', 'b'], ['c', 'd'], ['e']] r = [('a', 'b'), ('c', 'd'), ('e',)]
n = 2 n = 2
for g in (G, I, Ig, L, R): for g in (G, I, Ig, L, R):
with self.subTest(g=g): with self.subTest(g=g):

View File

@ -12,19 +12,19 @@ PyDoc_STRVAR(batched_new__doc__,
"batched(iterable, n)\n" "batched(iterable, n)\n"
"--\n" "--\n"
"\n" "\n"
"Batch data into lists of length n. The last batch may be shorter than n.\n" "Batch data into tuples of length n. The last batch may be shorter than n.\n"
"\n" "\n"
"Loops over the input iterable and accumulates data into lists\n" "Loops over the input iterable and accumulates data into tuples\n"
"up to size n. The input is consumed lazily, just enough to\n" "up to size n. The input is consumed lazily, just enough to\n"
"fill a list. The result is yielded as soon as a batch is full\n" "fill a batch. The result is yielded as soon as a batch is full\n"
"or when the input iterable is exhausted.\n" "or when the input iterable is exhausted.\n"
"\n" "\n"
" >>> for batch in batched(\'ABCDEFG\', 3):\n" " >>> for batch in batched(\'ABCDEFG\', 3):\n"
" ... print(batch)\n" " ... print(batch)\n"
" ...\n" " ...\n"
" [\'A\', \'B\', \'C\']\n" " (\'A\', \'B\', \'C\')\n"
" [\'D\', \'E\', \'F\']\n" " (\'D\', \'E\', \'F\')\n"
" [\'G\']"); " (\'G\',)");
static PyObject * static PyObject *
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n); batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);
@ -913,4 +913,4 @@ skip_optional_pos:
exit: exit:
return return_value; return return_value;
} }
/*[clinic end generated code: output=efea8cd1e647bd17 input=a9049054013a1b77]*/ /*[clinic end generated code: output=0229ebd72962f130 input=a9049054013a1b77]*/

View File

@ -56,11 +56,13 @@ static PyTypeObject pairwise_type;
/* batched object ************************************************************/ /* batched object ************************************************************/
/* Note: The built-in zip() function includes a "strict" argument /* Note: The built-in zip() function includes a "strict" argument
that is needed because that function can silently truncate data that was needed because that function would silently truncate data,
and there is no easy way for a user to detect that condition. and there was no easy way for a user to detect the data loss.
The same reasoning does not apply to batched() which never drops The same reasoning does not apply to batched() which never drops data.
data. Instead, it produces a shorter list which can be handled Instead, batched() produces a shorter tuple which can be handled
as the user sees fit. as the user sees fit. If requested, it would be reasonable to add
"fillvalue" support which had demonstrated value in zip_longest().
For now, the API is kept simple and clean.
*/ */
typedef struct { typedef struct {
@ -74,25 +76,25 @@ typedef struct {
itertools.batched.__new__ as batched_new itertools.batched.__new__ as batched_new
iterable: object iterable: object
n: Py_ssize_t n: Py_ssize_t
Batch data into lists of length n. The last batch may be shorter than n. Batch data into tuples of length n. The last batch may be shorter than n.
Loops over the input iterable and accumulates data into lists Loops over the input iterable and accumulates data into tuples
up to size n. The input is consumed lazily, just enough to up to size n. The input is consumed lazily, just enough to
fill a list. The result is yielded as soon as a batch is full fill a batch. The result is yielded as soon as a batch is full
or when the input iterable is exhausted. or when the input iterable is exhausted.
>>> for batch in batched('ABCDEFG', 3): >>> for batch in batched('ABCDEFG', 3):
... print(batch) ... print(batch)
... ...
['A', 'B', 'C'] ('A', 'B', 'C')
['D', 'E', 'F'] ('D', 'E', 'F')
['G'] ('G',)
[clinic start generated code]*/ [clinic start generated code]*/
static PyObject * static PyObject *
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n) batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=7ebc954d655371b6 input=f28fd12cb52365f0]*/ /*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
{ {
PyObject *it; PyObject *it;
batchedobject *bo; batchedobject *bo;
@ -150,12 +152,12 @@ batched_next(batchedobject *bo)
if (it == NULL) { if (it == NULL) {
return NULL; return NULL;
} }
result = PyList_New(n); result = PyTuple_New(n);
if (result == NULL) { if (result == NULL) {
return NULL; return NULL;
} }
iternextfunc iternext = *Py_TYPE(it)->tp_iternext; iternextfunc iternext = *Py_TYPE(it)->tp_iternext;
PyObject **items = _PyList_ITEMS(result); PyObject **items = _PyTuple_ITEMS(result);
for (i=0 ; i < n ; i++) { for (i=0 ; i < n ; i++) {
item = iternext(it); item = iternext(it);
if (item == NULL) { if (item == NULL) {