From 791c6e12bd13bd37a29095724dc5af65ae96f218 Mon Sep 17 00:00:00 2001 From: Vova Vv Date: Fri, 12 Mar 2021 16:59:24 +0100 Subject: [PATCH 1/4] fix ted training e2e entities when none are given --- rasa/core/policies/ted_policy.py | 15 ++++++++++++++- tests/core/policies/test_ted_policy.py | 3 +-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index 213ded0aa9b9..f0a20bd9b443 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -393,6 +393,19 @@ def _create_label_data( return label_data, encoded_all_labels + @staticmethod + def _should_extract_entities( + entity_tags: List[List[Dict[Text, List["Features"]]]] + ) -> bool: + for turns_tags in entity_tags: + for turn_tags in turns_tags: + if turn_tags: + # if all indices are `0` + # it means that all the inputs only contain NO_ENTITY_TAG + if np.any(turn_tags[ENTITY_TAGS][0].features): + return True + return False + def _create_data_for_entities( self, entity_tags: Optional[List[List[Dict[Text, List["Features"]]]]] ) -> Optional[Data]: @@ -400,7 +413,7 @@ def _create_data_for_entities( return # check that there are real entity tags - if entity_tags and any([any(turn_tags) for turn_tags in entity_tags]): + if entity_tags and self._should_extract_entities(entity_tags): entity_tags_data, _ = convert_to_data_format(entity_tags) return entity_tags_data diff --git a/tests/core/policies/test_ted_policy.py b/tests/core/policies/test_ted_policy.py index 62d5b72f75ed..d84772dd4166 100644 --- a/tests/core/policies/test_ted_policy.py +++ b/tests/core/policies/test_ted_policy.py @@ -35,7 +35,6 @@ VALUE_RELATIVE_ATTENTION, MODEL_CONFIDENCE, COSINE, - INNER, AUTO, LINEAR_NORM, ) @@ -93,7 +92,7 @@ def test_train_model_checkpointing(self, tmp_path: Path): def create_policy( self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: + ) -> TEDPolicy: return TEDPolicy(featurizer=featurizer, priority=priority) def test_similarity_type(self, trained_policy: TEDPolicy): From 51ac29e5a99edff8c86ed8bdf2092c57dd4f39c1 Mon Sep 17 00:00:00 2001 From: Vova Vv Date: Fri, 12 Mar 2021 17:02:18 +0100 Subject: [PATCH 2/4] add chnagelog --- changelog/8194.bugfix.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelog/8194.bugfix.md diff --git a/changelog/8194.bugfix.md b/changelog/8194.bugfix.md new file mode 100644 index 000000000000..3c1adf7cdcc2 --- /dev/null +++ b/changelog/8194.bugfix.md @@ -0,0 +1,2 @@ +Fix `TEDPolicy` training e2e entities when no entities are present in the stories +but there are entities in the domain. From da88458a47c5614ca759a47d37623ded233abdef Mon Sep 17 00:00:00 2001 From: Vova Vv Date: Fri, 12 Mar 2021 17:06:49 +0100 Subject: [PATCH 3/4] simplify if --- rasa/core/policies/ted_policy.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index f0a20bd9b443..a3e3cd5b27d6 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -399,11 +399,10 @@ def _should_extract_entities( ) -> bool: for turns_tags in entity_tags: for turn_tags in turns_tags: - if turn_tags: + if turn_tags and np.any(turn_tags[ENTITY_TAGS][0].features): # if all indices are `0` # it means that all the inputs only contain NO_ENTITY_TAG - if np.any(turn_tags[ENTITY_TAGS][0].features): - return True + return True return False def _create_data_for_entities( From 8c6b52ac4f0edacdde9a1b6dcc7e4eb94fabebea Mon Sep 17 00:00:00 2001 From: Vova Vv Date: Fri, 12 Mar 2021 17:08:09 +0100 Subject: [PATCH 4/4] update comment --- rasa/core/policies/ted_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index a3e3cd5b27d6..33ed312ee78a 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -399,9 +399,9 @@ def _should_extract_entities( ) -> bool: for turns_tags in entity_tags: for turn_tags in turns_tags: + # if turn_tags are empty or all entity tag indices are `0` + # it means that all the inputs only contain NO_ENTITY_TAG if turn_tags and np.any(turn_tags[ENTITY_TAGS][0].features): - # if all indices are `0` - # it means that all the inputs only contain NO_ENTITY_TAG return True return False