From 573cad4c815306b31e11b2a1ef1c686990942544 Mon Sep 17 00:00:00 2001 From: Sebastian Berg <sebastianb@nvidia.com> Date: Thu, 27 Jun 2024 13:03:35 +0200 Subject: [PATCH] BUG: Fix multiline code in generate/update diff The code may have more than one line, so need to account for that. Closes gh-252 --- pytest_doctestplus/plugin.py | 4 +++- tests/test_doctestplus.py | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pytest_doctestplus/plugin.py b/pytest_doctestplus/plugin.py index c365a44..f9c3a0c 100644 --- a/pytest_doctestplus/plugin.py +++ b/pytest_doctestplus/plugin.py @@ -896,7 +896,9 @@ def write_modified_file(fname, new_fname, changes): if change["test_lineno"] is None: bad_tests.append(change["name"]) continue - lineno = change["test_lineno"] + change["example_lineno"] + 1 + # Find the first line of the output: + lineno = change["test_lineno"] + change["example_lineno"] + lineno += change["source"].count("\n") indentation = " " * change["nindent"] want = indent(change["want"], indentation, lambda x: True) diff --git a/tests/test_doctestplus.py b/tests/test_doctestplus.py index 5fa6d33..414d24a 100644 --- a/tests/test_doctestplus.py +++ b/tests/test_doctestplus.py @@ -1431,3 +1431,45 @@ def f(): result = f.read() assert result == original.replace("4", "2").replace("5", "3") + + +def test_generate_diff_multiline(testdir, capsys): + p = testdir.makepyfile(""" + def f(): + ''' + >>> print(2) + 2 + >>> for i in range(4): + ... print(i) + 1 + 2 + ''' + pass + """) + with open(p) as f: + original = f.read() + + testdir.inline_run(p, "--doctest-plus-generate-diff") + diff = dedent(""" + >>> for i in range(4): + ... print(i) + + 0 + 1 + 2 + + 3 + """) + captured = capsys.readouterr() + print(captured.out) + print("====") + print(diff) + assert diff in captured.out + + testdir.inline_run(p, "--doctest-plus-generate-diff=overwrite") + captured = capsys.readouterr() + assert "Applied fix to the following files" in captured.out + + with open(p) as f: + result = f.read() + + original_fixed = original.replace("1\n 2", "\n ".join(["0", "1", "2", "3"])) + assert result == original_fixed