Skip to content

Commit

Permalink
Merge pull request #5646 from RasaHQ/test-entity-extraction
Browse files Browse the repository at this point in the history
Exclude DIETClassifier from extractors in tests if no entities were trained
  • Loading branch information
tabergma authored Apr 16, 2020
2 parents d55e868 + a013d89 commit ee88904
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 12 deletions.
File renamed without changes.
2 changes: 2 additions & 0 deletions changelog/5646.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
``DIETClassifier`` only counts as extractor in ``rasa test`` if it was actually trained for entity recognition.

15 changes: 11 additions & 4 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from rasa.nlu.model import Interpreter, Trainer, TrainingData
from rasa.nlu.components import Component
from rasa.nlu.tokenizers.tokenizer import Token
from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1022,12 +1023,18 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]:
Processors are removed since they do not detect the boundaries themselves.
"""

from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.nlu.classifiers.diet_classifier import DIETClassifier

extractors = set()
for c in interpreter.pipeline:
if isinstance(c, EntityExtractor):
if isinstance(c, DIETClassifier):
if c.component_config[ENTITY_RECOGNITION]:
extractors.add(c.name)
else:
extractors.add(c.name)

extractors = {
c.name for c in interpreter.pipeline if isinstance(c, EntityExtractor)
}
return extractors - ENTITY_PROCESSORS


Expand Down
30 changes: 22 additions & 8 deletions tests/nlu/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from _pytest.tmpdir import TempdirFactory

import rasa.utils.io
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
from rasa.test import compare_nlu_models
from rasa.nlu.extractors.extractor import EntityExtractor
Expand Down Expand Up @@ -50,7 +51,7 @@
from tests.nlu.conftest import DEFAULT_DATA_PATH
from rasa.nlu.selectors.response_selector import ResponseSelector
from rasa.nlu.test import is_response_selector_present
from rasa.utils.tensorflow.constants import EPOCHS
from rasa.utils.tensorflow.constants import EPOCHS, ENTITY_RECOGNITION


# https://github.com/pytest-dev/pytest-asyncio/issues/68
Expand Down Expand Up @@ -510,6 +511,26 @@ def test_response_evaluation_report(tmpdir_factory):
assert result["predictions"][1] == prediction


@pytest.mark.parametrize(
"components, expected_extractors",
[
([DIETClassifier({ENTITY_RECOGNITION: False})], set()),
([DIETClassifier({ENTITY_RECOGNITION: True})], {"DIETClassifier"}),
([CRFEntityExtractor()], {"CRFEntityExtractor"}),
(
[SpacyEntityExtractor(), CRFEntityExtractor()],
{"SpacyEntityExtractor", "CRFEntityExtractor"},
),
([ResponseSelector()], set()),
],
)
def test_get_entity_extractors(components, expected_extractors):
mock_interpreter = Interpreter(components, None)
extractors = get_entity_extractors(mock_interpreter)

assert extractors == expected_extractors


def test_entity_evaluation_report(tmpdir_factory):
class EntityExtractorA(EntityExtractor):

Expand Down Expand Up @@ -653,13 +674,6 @@ def test_evaluate_entities_cv():
}, "Wrong entity prediction alignment"


def test_get_entity_extractors(pretrained_interpreter):
assert get_entity_extractors(pretrained_interpreter) == {
"SpacyEntityExtractor",
"DucklingHTTPExtractor",
}


def test_remove_pretrained_extractors(pretrained_interpreter):
target_components_names = ["SpacyNLP"]
filtered_pipeline = remove_pretrained_extractors(pretrained_interpreter.pipeline)
Expand Down

0 comments on commit ee88904

Please sign in to comment.