From d7ecc2964696d98aa07ce7e66356445e27eec642 Mon Sep 17 00:00:00 2001 From: Ankith Date: Fri, 19 Jan 2024 10:24:55 +0530 Subject: [PATCH] Merge pull request #2667 from Matiiss/matiiss-fix-fblits-segfault Fix segfault in `surface.fblits` --- src_c/surface.c | 10 +++++++--- test/blit_test.py | 3 +++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src_c/surface.c b/src_c/surface.c index dc5f6e91a8..a123e3d5c8 100644 --- a/src_c/surface.c +++ b/src_c/surface.c @@ -2145,6 +2145,7 @@ surf_fblits(pgSurfaceObject *self, PyObject *const *args, Py_ssize_t nargs) PyObject *blit_sequence, *item, *src_surf, *blit_pos; int blend_flags = 0; /* Default flag is 0, opaque */ int error = 0; + int is_generator = 0; if (nargs == 0 || nargs > 2) { error = FBLITS_ERR_INCORRECT_ARGS_NUM; @@ -2216,11 +2217,11 @@ surf_fblits(pgSurfaceObject *self, PyObject *const *args, Py_ssize_t nargs) } /* Generator path */ else if (PyIter_Check(blit_sequence)) { + is_generator = 1; while ((item = PyIter_Next(blit_sequence))) { /* Check that the item is a tuple of length 2 */ if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { error = FBLITS_ERR_TUPLE_REQUIRED; - Py_DECREF(item); goto on_error; } @@ -2229,8 +2230,6 @@ surf_fblits(pgSurfaceObject *self, PyObject *const *args, Py_ssize_t nargs) src_surf = PyTuple_GET_ITEM(item, 0); blit_pos = PyTuple_GET_ITEM(item, 1); - Py_DECREF(item); - /* Check that the source is a Surface */ if (!pgSurface_Check(src_surf)) { error = BLITS_ERR_SOURCE_NOT_SURFACE; @@ -2262,6 +2261,8 @@ surf_fblits(pgSurfaceObject *self, PyObject *const *args, Py_ssize_t nargs) error = BLITS_ERR_BLIT_FAIL; goto on_error; } + + Py_DECREF(item); } } else { @@ -2272,6 +2273,9 @@ surf_fblits(pgSurfaceObject *self, PyObject *const *args, Py_ssize_t nargs) Py_RETURN_NONE; on_error: + if (is_generator) { + Py_XDECREF(item); + } switch (error) { case BLITS_ERR_SEQUENCE_REQUIRED: return RAISE( diff --git a/test/blit_test.py b/test/blit_test.py index 2ff96a9bfe..7d02093ba5 100644 --- a/test/blit_test.py +++ b/test/blit_test.py @@ -163,6 +163,9 @@ def blits(blit_list): self.assertEqual(dst.fblits(blit_list, 0), None) self.assertEqual(dst.fblits(blit_list, 1), dst.blits(blit_list, doreturn=0)) + # make sure this doesn't segfault + dst.fblits((dst, dst.get_rect().topleft) for _ in range(1)) + t0 = time() results = blits(blit_list) t1 = time()