From ac100bcf8dc7ec634b319c90c375cd16a7f96284 Mon Sep 17 00:00:00 2001 From: Dimitri Prybysh Date: Mon, 6 Sep 2021 18:47:55 +0200 Subject: [PATCH] Recognize nested classes in classes inherited from NamedTuple Fixes PyCQA/pylint#4370 --- astroid/brain/brain_namedtuple_enum.py | 14 +++++++------- tests/unittest_brain.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index f79e44cc10..4ed105ca14 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -484,13 +484,13 @@ def infer_typing_namedtuple_class(class_node, context=None): for method in class_node.mymethods(): generated_class_node.locals[method.name] = [method] - for assign in class_node.body: - if not isinstance(assign, nodes.Assign): - continue - - for target in assign.targets: - attr = target.name - generated_class_node.locals[attr] = class_node.locals[attr] + for body_node in class_node.body: + if isinstance(body_node, nodes.Assign): + for target in body_node.targets: + attr = target.name + generated_class_node.locals[attr] = class_node.locals[attr] + elif isinstance(body_node, nodes.ClassDef): + generated_class_node.locals[body_node.name] = [body_node] return iter((generated_class_node,)) diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 34a4b43a76..a45b8710ea 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1629,6 +1629,27 @@ def test_typing_types(self): inferred = next(node.infer()) self.assertIsInstance(inferred, nodes.ClassDef, node.as_string()) + def test_namedtuple_nested_class(self): + result = builder.extract_node( + """ + from typing import NamedTuple + + class Example(NamedTuple): + class Foo: + bar = "bar" + + Example + """ + ) + inferred = next(result.infer()) + self.assertIsInstance(inferred, astroid.ClassDef) + + class_def_attr = inferred.getattr("Foo")[0] + self.assertIsInstance(class_def_attr, astroid.ClassDef) + attr_def = class_def_attr.getattr("bar")[0] + attr = next(attr_def.infer()) + self.assertEqual(attr.value, "bar") + @test_utils.require_version(minver="3.7") def test_tuple_type(self): node = builder.extract_node(