Skip to content

Commit

Permalink
pythongh-119793: Add optional length-checking to map() (pythonGH-12…
Browse files Browse the repository at this point in the history
…0471)


Co-authored-by: Bénédikt Tran <[email protected]>
Co-authored-by: Pieter Eendebak <[email protected]>
Co-authored-by: Erlend E. Aasland <[email protected]>
Co-authored-by: Raymond Hettinger <[email protected]>
  • Loading branch information
5 people authored Nov 4, 2024
1 parent bfc1d25 commit 3032fcd
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 17 deletions.
11 changes: 8 additions & 3 deletions Doc/library/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1205,14 +1205,19 @@ are always available. They are listed here in alphabetical order.
unchanged from previous versions.


.. function:: map(function, iterable, *iterables)
.. function:: map(function, iterable, /, *iterables, strict=False)

Return an iterator that applies *function* to every item of *iterable*,
yielding the results. If additional *iterables* arguments are passed,
*function* must take that many arguments and is applied to the items from all
iterables in parallel. With multiple iterables, the iterator stops when the
shortest iterable is exhausted. For cases where the function inputs are
already arranged into argument tuples, see :func:`itertools.starmap`\.
shortest iterable is exhausted. If *strict* is ``True`` and one of the
iterables is exhausted before the others, a :exc:`ValueError` is raised. For
cases where the function inputs are already arranged into argument tuples,
see :func:`itertools.starmap`.

.. versionchanged:: 3.14
Added the *strict* parameter.


.. function:: max(iterable, *, key=None)
Expand Down
4 changes: 4 additions & 0 deletions Doc/whatsnew/3.14.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ Improved error messages
Other language changes
======================

* The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
(Contributed by Wannes Boeykens in :gh:`119793`.)

* Incorrect usage of :keyword:`await` and asynchronous comprehensions
is now detected even if the code is optimized away by the :option:`-O`
command-line option. For example, ``python -O -c 'assert await 1'``
Expand Down
105 changes: 105 additions & 0 deletions Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def filter_char(arg):
def map_char(arg):
    """Return the character whose code point is one above that of *arg*."""
    code_point = ord(arg)
    return chr(code_point + 1)

def pack(*args):
    """Return the positional arguments as a tuple.

    Handy as a mapped function: each result records exactly which items
    were pulled from every input iterable.
    """
    return tuple(args)

class BuiltinTest(unittest.TestCase):
# Helper to check picklability
def check_iter_pickle(self, it, seq, proto):
Expand Down Expand Up @@ -1269,6 +1272,108 @@ def test_map_pickle(self):
m2 = map(map_char, "Is this the real life?")
self.check_iter_pickle(m1, list(m2), proto)

# strict map tests based on strict zip tests

def test_map_pickle_strict(self):
    # Equal-length inputs: run the shared iterator-pickling check on a
    # strict map for every available pickle protocol.
    first, second = (1, 2, 3), (4, 5, 6)
    expected = [(1, 4), (2, 5), (3, 6)]
    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        # A fresh map per protocol, since iterators are single-use.
        mapped = map(pack, first, second, strict=True)
        self.check_iter_pickle(mapped, expected, proto)

def test_map_pickle_strict_fail(self):
    # Unequal-length inputs: both the original strict map and its pickle
    # round-tripped copy must end in ValueError, and iter_error() must
    # return the same list for each.
    shorter, longer = (1, 2, 3), (4, 5, 6, 7)
    expected = [(1, 4), (2, 5), (3, 6)]
    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        original = map(pack, shorter, longer, strict=True)
        # Pickle before consuming anything so the copy starts from the
        # same position as the original.
        restored = pickle.loads(pickle.dumps(original, proto))
        for candidate in (original, restored):
            self.assertEqual(self.iter_error(candidate, ValueError),
                             expected)

def test_map_strict(self):
    # Equal lengths: strict mode behaves like the ordinary map().
    same_length = map(pack, (1, 2, 3), 'abc', strict=True)
    self.assertEqual(tuple(same_length),
                     ((1, 'a'), (2, 'b'), (3, 'c')))

    # Any length mismatch must raise ValueError while being consumed:
    # first argument longer, first argument shorter, and a mismatch that
    # only shows up in the third argument.
    mismatched = (
        ((1, 2, 3, 4), 'abc'),
        ((1, 2), 'abc'),
        ((1, 2), (1, 2), 'abc'),
    )
    for iterables in mismatched:
        # Build the map outside the context manager so that only the
        # consumption step is allowed to raise.
        lazy = map(pack, *iterables, strict=True)
        with self.assertRaises(ValueError):
            tuple(lazy)

def test_map_strict_iterators(self):
    # Checks how far each input iterator has been advanced once the
    # length mismatch is reported.
    x = iter(range(5))
    y = [0]
    z = iter(range(5))
    mapped = map(pack, x, y, z, strict=True)
    with self.assertRaises(ValueError):
        list(mapped)
    # The second round pulled one more item from x, then found y
    # exhausted before z was touched: x has given up 0 and 1, z only 0.
    self.assertEqual(next(x), 2)
    self.assertEqual(next(z), 1)

def test_map_strict_error_handling(self):
    # An exception other than StopIteration raised by an input iterator
    # must propagate unchanged; a genuine length mismatch found first
    # must still be reported as ValueError.

    class Error(Exception):
        pass

    class Iter:
        # Counts down from size-1 to 0, then raises Error instead of
        # StopIteration.
        def __init__(self, size):
            self.size = size

        def __iter__(self):
            return self

        def __next__(self):
            self.size -= 1
            if self.size >= 0:
                return self.size
            raise Error

    # (input iterables, exception that ends iteration, expected result
    # of iter_error())
    cases = [
        (("AB", Iter(1)), Error, [("A", 0)]),
        (("AB", Iter(2), "A"), ValueError, [("A", 1, "A")]),
        (("AB", Iter(2), "ABC"), Error, [("A", 1, "A"), ("B", 0, "B")]),
        (("AB", Iter(3)), ValueError, [("A", 2), ("B", 1)]),
        ((Iter(1), "AB"), Error, [(0, "A")]),
        ((Iter(2), "A"), ValueError, [(1, "A")]),
        ((Iter(2), "ABC"), Error, [(1, "A"), (0, "B")]),
        ((Iter(3), "AB"), ValueError, [(2, "A"), (1, "B")]),
    ]
    for iterables, exc, expected in cases:
        outcome = self.iter_error(map(pack, *iterables, strict=True), exc)
        self.assertEqual(outcome, expected)

def test_map_strict_error_handling_stopiteration(self):
    # Same scenarios as the custom-exception variant, but the helper
    # iterator ends with StopIteration, so every case is a plain length
    # mismatch and must surface as ValueError.

    class Iter:
        # Counts down from size-1 to 0, then raises StopIteration.
        def __init__(self, size):
            self.size = size

        def __iter__(self):
            return self

        def __next__(self):
            self.size -= 1
            if self.size >= 0:
                return self.size
            raise StopIteration

    # (input iterables, expected result of iter_error())
    cases = [
        (("AB", Iter(1)), [("A", 0)]),
        (("AB", Iter(2), "A"), [("A", 1, "A")]),
        (("AB", Iter(2), "ABC"), [("A", 1, "A"), ("B", 0, "B")]),
        (("AB", Iter(3)), [("A", 2), ("B", 1)]),
        ((Iter(1), "AB"), [(0, "A")]),
        ((Iter(2), "A"), [(1, "A")]),
        ((Iter(2), "ABC"), [(1, "A"), (0, "B")]),
        ((Iter(3), "AB"), [(2, "A"), (1, "B")]),
    ]
    for iterables, expected in cases:
        outcome = self.iter_error(map(pack, *iterables, strict=True),
                                  ValueError)
        self.assertEqual(outcome, expected)

def test_max(self):
self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3)
Expand Down
4 changes: 2 additions & 2 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,10 +2433,10 @@ class subclass(cls):
subclass(*args, newarg=3)

for cls, args, result in testcases:
# Constructors of repeat, zip, compress accept keyword arguments.
# Constructors of repeat, zip, map, compress accept keyword arguments.
# Their subclasses need overriding __new__ to support new
# keyword arguments.
if cls in [repeat, zip, compress]:
if cls in [repeat, zip, map, compress]:
continue
with self.subTest(cls):
class subclass_with_init(cls):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
Patch by Wannes Boeykens.
100 changes: 88 additions & 12 deletions Python/bltinmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,7 @@ typedef struct {
PyObject_HEAD
PyObject *iters;
PyObject *func;
int strict;
} mapobject;

static PyObject *
Expand All @@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *it, *iters, *func;
mapobject *lz;
Py_ssize_t numargs, i;
int strict = 0;

if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
!_PyArg_NoKeywords("map", kwds))
return NULL;
if (kwds) {
PyObject *empty = PyTuple_New(0);
if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:map", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}

numargs = PyTuple_Size(args);
if (numargs < 2) {
Expand Down Expand Up @@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->iters = iters;
func = PyTuple_GET_ITEM(args, 0);
lz->func = Py_NewRef(func);
lz->strict = strict;

return (PyObject *)lz;
}
Expand All @@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
size_t nargsf, PyObject *kwnames)
{
PyTypeObject *tp = _PyType_CAST(type);
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
return NULL;
}

Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
// Fallback to map_new()
PyThreadState *tstate = _PyThreadState_GET();
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
}

if (nargs < 2) {
PyErr_SetString(PyExc_TypeError,
"map() must have at least two arguments.");
Expand Down Expand Up @@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
}
lz->iters = iters;
lz->func = Py_NewRef(args[0]);
lz->strict = 0;

return (PyObject *)lz;
}
Expand All @@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
static PyObject *
map_next(mapobject *lz)
{
Py_ssize_t i;
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
PyObject **stack;
PyObject *result = NULL;
Expand All @@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
}

Py_ssize_t nargs = 0;
for (Py_ssize_t i=0; i < niters; i++) {
for (i=0; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = Py_TYPE(it)->tp_iternext(it);
if (val == NULL) {
if (lz->strict) {
goto check;
}
goto exit;
}
stack[i] = val;
Expand All @@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);

exit:
for (Py_ssize_t i=0; i < nargs; i++) {
for (i=0; i < nargs; i++) {
Py_DECREF(stack[i]);
}
if (stack != small_stack) {
PyMem_Free(stack);
}
return result;
check:
    /* Strict mode only: the iterator at index i returned NULL before a
       full set of arguments was collected.  result is still NULL here and
       stack[0..nargs-1] hold strong references, so every path below must
       leave through the exit label, which releases those references and
       frees a heap-allocated stack.  Returning directly would leak both. */
    if (PyErr_Occurred()) {
        if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
            // next() on argument i raised an exception (not StopIteration)
            goto exit;
        }
        PyErr_Clear();
    }
    if (i) {
        // ValueError: map() argument 2 is shorter than argument 1
        // ValueError: map() argument 3 is shorter than arguments 1-2
        const char* plural = i == 1 ? " " : "s 1-";
        // i is a Py_ssize_t, so it must be formatted with %zd, not %d.
        PyErr_Format(PyExc_ValueError,
                     "map() argument %zd is shorter than argument%s%zd",
                     i + 1, plural, i);
        goto exit;
    }
    // The first iterator is exhausted: every remaining one must be too.
    for (i = 1; i < niters; i++) {
        PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
        PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
        if (val) {
            Py_DECREF(val);
            const char* plural = i == 1 ? " " : "s 1-";
            PyErr_Format(PyExc_ValueError,
                         "map() argument %zd is longer than argument%s%zd",
                         i + 1, plural, i);
            goto exit;
        }
        if (PyErr_Occurred()) {
            if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
                // next() on argument i raised an exception (not StopIteration)
                goto exit;
            }
            PyErr_Clear();
        }
        // Argument i is exhausted. So far so good...
    }
    // All arguments are exhausted. Success!
    // (result is NULL with no exception set, which ends the iteration.)
    goto exit;
}

static PyObject *
Expand All @@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
}

if (lz->strict) {
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
}
return Py_BuildValue("ON", Py_TYPE(lz), args);
}

PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");

/* __setstate__ implementation: store the truth value of *state* as the
   map object's strict flag.  Returns None on success, or NULL with an
   exception set when the truth value of *state* cannot be determined. */
static PyObject *
map_setstate(mapobject *lz, PyObject *state)
{
    const int truth = PyObject_IsTrue(state);
    if (truth >= 0) {
        lz->strict = truth;
        Py_RETURN_NONE;
    }
    /* A negative result means evaluating the truth value failed. */
    return NULL;
}

/* Methods exposed on map objects: __reduce__ (takes no arguments) and
   __setstate__ (takes exactly one argument).  The table is terminated by
   the all-NULL sentinel entry. */
static PyMethodDef map_methods[] = {
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
{NULL, NULL} /* sentinel */
};


PyDoc_STRVAR(map_doc,
"map(function, iterable, /, *iterables)\n\
"map(function, iterable, /, *iterables, strict=False)\n\
--\n\
\n\
Make an iterator that computes the function using arguments from\n\
each of the iterables. Stops when the shortest iterable is exhausted.");
each of the iterables. Stops when the shortest iterable is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");

PyTypeObject PyMap_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
Expand Down Expand Up @@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
}

PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");

static PyObject *
zip_setstate(zipobject *lz, PyObject *state)
{
Expand Down

0 comments on commit 3032fcd

Please sign in to comment.