From ca6e0d44dc891692fd2342ed48637597841eb9d7 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Thu, 16 Apr 2020 14:17:25 +0200 Subject: [PATCH 1/5] exclude DIET from extractors if no entities trained --- rasa/nlu/test.py | 14 ++++++++++---- tests/nlu/test_evaluation.py | 30 ++++++++++++++++++++++-------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index a61218f6938f..adb3c9c30520 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -37,6 +37,7 @@ from rasa.nlu.model import Interpreter, Trainer, TrainingData from rasa.nlu.components import Component from rasa.nlu.tokenizers.tokenizer import Token +from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION logger = logging.getLogger(__name__) @@ -1022,12 +1023,17 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]: Processors are removed since they do not detect the boundaries themselves. """ - from rasa.nlu.extractors.extractor import EntityExtractor - extractors = { - c.name for c in interpreter.pipeline if isinstance(c, EntityExtractor) - } + extractors = set() + for c in interpreter.pipeline: + if isinstance(c, EntityExtractor): + if c.name == "DIETClassifier": + if c.component_config[ENTITY_RECOGNITION]: + extractors.add(c.name) + else: + extractors.add(c.name) + return extractors - ENTITY_PROCESSORS diff --git a/tests/nlu/test_evaluation.py b/tests/nlu/test_evaluation.py index a40ddb42d021..032747b58d30 100644 --- a/tests/nlu/test_evaluation.py +++ b/tests/nlu/test_evaluation.py @@ -7,6 +7,7 @@ from _pytest.tmpdir import TempdirFactory import rasa.utils.io +from rasa.nlu.classifiers.diet_classifier import DIETClassifier from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor from rasa.test import compare_nlu_models from rasa.nlu.extractors.extractor import EntityExtractor @@ -50,7 +51,7 @@ from tests.nlu.conftest import DEFAULT_DATA_PATH from rasa.nlu.selectors.response_selector import ResponseSelector from rasa.nlu.test import is_response_selector_present -from rasa.utils.tensorflow.constants import EPOCHS +from rasa.utils.tensorflow.constants import EPOCHS, ENTITY_RECOGNITION # https://github.com/pytest-dev/pytest-asyncio/issues/68 @@ -510,6 +511,26 @@ def test_response_evaluation_report(tmpdir_factory): assert result["predictions"][1] == prediction +@pytest.mark.parametrize( + "components, expected_extractors", + [ + ([DIETClassifier({ENTITY_RECOGNITION: False})], set()), + ([DIETClassifier({ENTITY_RECOGNITION: True})], {"DIETClassifier"}), + ([CRFEntityExtractor()], {"CRFEntityExtractor"}), + ( + [SpacyEntityExtractor(), CRFEntityExtractor()], + {"SpacyEntityExtractor", "CRFEntityExtractor"}, + ), + ([ResponseSelector()], set()), + ], +) +def test_get_entity_extractors(components, expected_extractors): + mock_interpreter = Interpreter(components, None) + extractors = get_entity_extractors(mock_interpreter) + + assert extractors == expected_extractors + + def test_entity_evaluation_report(tmpdir_factory): class EntityExtractorA(EntityExtractor): @@ -653,13 +674,6 @@ def test_evaluate_entities_cv(): }, "Wrong entity prediction alignment" -def test_get_entity_extractors(pretrained_interpreter): - assert get_entity_extractors(pretrained_interpreter) == { - "SpacyEntityExtractor", - "DucklingHTTPExtractor", - } - - def test_remove_pretrained_extractors(pretrained_interpreter): target_components_names = ["SpacyNLP"] filtered_pipeline = remove_pretrained_extractors(pretrained_interpreter.pipeline) From 553fc5aa09ba4a6eee3d87e2be5178fd8d06c098 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Thu, 16 Apr 2020 14:21:09 +0200 Subject: [PATCH 2/5] add changelog --- changelog/{5544.improvment.rst => 5544.improvement.rst} | 0 changelog/5646.improvement.rst | 2 ++ 2 files changed, 2 insertions(+) rename changelog/{5544.improvment.rst => 5544.improvement.rst} (100%) create mode 100644 changelog/5646.improvement.rst diff --git a/changelog/5544.improvment.rst b/changelog/5544.improvement.rst similarity index 100% rename from changelog/5544.improvment.rst rename to changelog/5544.improvement.rst diff --git a/changelog/5646.improvement.rst b/changelog/5646.improvement.rst new file mode 100644 index 000000000000..d1da0a62e4d3 --- /dev/null +++ b/changelog/5646.improvement.rst @@ -0,0 +1,2 @@ +``DIETClassifier`` only counts as extractor in ``rasa test`` if it was actually trained for entity recognition. + From 7b45e0f462ab40330dcda3e764e533a95713a432 Mon Sep 17 00:00:00 2001 From: Tanja Date: Thu, 16 Apr 2020 15:43:03 +0200 Subject: [PATCH 3/5] Update rasa/nlu/test.py Co-Authored-By: Vladimir Vlasov --- rasa/nlu/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index adb3c9c30520..74bf15a327e3 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -1028,7 +1028,7 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]: extractors = set() for c in interpreter.pipeline: if isinstance(c, EntityExtractor): - if c.name == "DIETClassifier": + if isinstance(c, DIETClassifier): if c.component_config[ENTITY_RECOGNITION]: extractors.add(c.name) else: From 4e1b6a66838ed3a4b3501115d0e46fdfcdba1152 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Thu, 16 Apr 2020 15:50:34 +0200 Subject: [PATCH 4/5] add missing import --- rasa/nlu/test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index 74bf15a327e3..68acb036ae23 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -37,6 +37,7 @@ from rasa.nlu.model import Interpreter, Trainer, TrainingData from rasa.nlu.components import Component from rasa.nlu.tokenizers.tokenizer import Token +from rasa.nlu.classifiers.diet_classifier import DIETClassifier from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION logger = logging.getLogger(__name__) From a013d89c4ce883071e067334ca9c9cab2152f84f Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Thu, 16 Apr 2020 16:13:25 +0200 Subject: [PATCH 5/5] use local import --- rasa/nlu/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index 68acb036ae23..f94499282c6c 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -37,7 +37,6 @@ from rasa.nlu.model import Interpreter, Trainer, TrainingData from rasa.nlu.components import Component from rasa.nlu.tokenizers.tokenizer import Token -from rasa.nlu.classifiers.diet_classifier import DIETClassifier from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION logger = logging.getLogger(__name__) @@ -1025,6 +1024,7 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]: Processors are removed since they do not detect the boundaries themselves. """ from rasa.nlu.extractors.extractor import EntityExtractor + from rasa.nlu.classifiers.diet_classifier import DIETClassifier extractors = set() for c in interpreter.pipeline: