diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py index d13d8623f8093b..6c2411f9a57b62 100644 --- a/Lib/test/test_clinic.py +++ b/Lib/test/test_clinic.py @@ -45,6 +45,7 @@ def _expect_failure(tc, parser, code, errmsg, *, filename=None, lineno=None): tc.assertEqual(cm.exception.filename, filename) if lineno is not None: tc.assertEqual(cm.exception.lineno, lineno) + return cm.exception class ClinicWholeFileTest(TestCase): @@ -222,6 +223,15 @@ def test_directive_output_print(self): last_line.startswith("/*[clinic end generated code: output=") ) + def test_directive_wrong_arg_number(self): + raw = dedent(""" + /*[clinic input] + preserve foo bar baz eggs spam ham mushrooms + [clinic start generated code]*/ + """) + err = "takes 1 positional argument but 8 were given" + self.expect_failure(raw, err) + def test_unknown_destination_command(self): raw = """ /*[clinic input] @@ -600,6 +610,31 @@ def test_directive_output_invalid_command(self): self.expect_failure(block, err, lineno=2) +class ParseFileUnitTest(TestCase): + def expect_parsing_failure( + self, *, filename, expected_error, verify=True, output=None + ): + errmsg = re.escape(dedent(expected_error).strip()) + with self.assertRaisesRegex(clinic.ClinicError, errmsg): + clinic.parse_file(filename) + + def test_parse_file_no_extension(self) -> None: + self.expect_parsing_failure( + filename="foo", + expected_error="Can't extract file type for file 'foo'" + ) + + def test_parse_file_strange_extension(self) -> None: + filenames_to_errors = { + "foo.rs": "Can't identify file type for file 'foo.rs'", + "foo.hs": "Can't identify file type for file 'foo.hs'", + "foo.js": "Can't identify file type for file 'foo.js'", + } + for filename, errmsg in filenames_to_errors.items(): + with self.subTest(filename=filename): + self.expect_parsing_failure(filename=filename, expected_error=errmsg) + + class ClinicGroupPermuterTest(TestCase): def _test(self, l, m, r, output): computed = clinic.permute_optional_groups(l, m, r) @@ -794,8 +829,8 @@ def parse_function(self, text, signatures_in_block=2, function_index=1): return s[function_index] def expect_failure(self, block, err, *, filename=None, lineno=None): - _expect_failure(self, self.parse_function, block, err, - filename=filename, lineno=lineno) + return _expect_failure(self, self.parse_function, block, err, + filename=filename, lineno=lineno) def checkDocstring(self, fn, expected): self.assertTrue(hasattr(fn, "docstring")) @@ -877,6 +912,41 @@ def test_param_default_expr_named_constant(self): """ self.expect_failure(block, err, lineno=2) + def test_param_with_bizarre_default_fails_correctly(self): + template = """ + module os + os.access + follow_symlinks: int = {default} + """ + err = "Unsupported expression as default value" + for bad_default_value in ( + "{1, 2, 3}", + "3 if bool() else 4", + "[x for x in range(42)]" + ): + with self.subTest(bad_default=bad_default_value): + block = template.format(default=bad_default_value) + self.expect_failure(block, err, lineno=2) + + def test_unspecified_not_allowed_as_default_value(self): + block = """ + module os + os.access + follow_symlinks: int(c_default='MAXSIZE') = unspecified + """ + err = "'unspecified' is not a legal default value!" + exc = self.expect_failure(block, err, lineno=2) + self.assertNotIn('Malformed expression given as default value', str(exc)) + + def test_malformed_expression_as_default_value(self): + block = """ + module os + os.access + follow_symlinks: int(c_default='MAXSIZE') = 1/0 + """ + err = "Malformed expression given as default value" + self.expect_failure(block, err, lineno=2) + def test_param_default_expr_binop(self): err = ( "When you specify an expression ('a + b') as your default value, " @@ -1041,6 +1111,28 @@ def test_c_name(self): """) self.assertEqual("os_stat_fn", function.c_basename) + def test_base_invalid_syntax(self): + block = """ + module os + os.stat + invalid syntax: int = 42 + """ + err = dedent(r""" + Function 'stat' has an invalid parameter declaration: + \s+'invalid syntax: int = 42' + """).strip() + with self.assertRaisesRegex(clinic.ClinicError, err): + self.parse_function(block) + + def test_param_default_invalid_syntax(self): + block = """ + module os + os.stat + x: int = invalid syntax + """ + err = r"Syntax error: 'x = invalid syntax\n'" + self.expect_failure(block, err, lineno=2) + def test_cloning_nonexistent_function_correctly_fails(self): block = """ cloned = fooooooooooooooooo @@ -1414,18 +1506,6 @@ def test_parameters_required_after_star(self): with self.subTest(block=block): self.expect_failure(block, err) - def test_parameters_required_after_depr_star(self): - dataset = ( - "module foo\nfoo.bar\n * [from 3.14]", - "module foo\nfoo.bar\n * [from 3.14]\nDocstring here.", - "module foo\nfoo.bar\n this: int\n * [from 3.14]", - "module foo\nfoo.bar\n this: int\n * [from 3.14]\nDocstring.", - ) - err = "Function 'foo.bar' specifies '* [from 3.14]' without any parameters afterwards." - for block in dataset: - with self.subTest(block=block): - self.expect_failure(block, err) - def test_depr_star_invalid_format_1(self): block = """ module foo diff --git a/Tools/clinic/clinic.py b/Tools/clinic/clinic.py index c6cf43ab40fb12..0b336d9ac5a60f 100755 --- a/Tools/clinic/clinic.py +++ b/Tools/clinic/clinic.py @@ -5207,13 +5207,14 @@ def bad_node(self, node: ast.AST) -> None: # but at least make an attempt at ensuring it's a valid expression. try: value = eval(default) - if value is unspecified: - fail("'unspecified' is not a legal default value!") except NameError: pass # probably a named constant except Exception as e: fail("Malformed expression given as default value " f"{default!r} caused {e!r}") + else: + if value is unspecified: + fail("'unspecified' is not a legal default value!") if bad: fail(f"Unsupported expression as default value: {default!r}")