From 8491f55f15d1023aa30740f7d0f95dbf91bc7dfe Mon Sep 17 00:00:00 2001 From: Lukas Osterloh Date: Wed, 13 Jul 2022 15:22:05 +0200 Subject: [PATCH] ATO-218 Always set slot values This fixes a bug where slot values vanish when using the augmented memoization policy. If in the past, the slot is set to a specific value, and then later the slot being set to the same value, no new SlotSet event is emitted. However, this leads to the augmented memoiziation policy wrongly assuming this slot to be unset, as the older, original SlotSet event was pruned already. --- rasa/core/actions/action.py | 2 +- tests/core/test_actions.py | 53 ++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 7ad861720605..3b1116c92df3 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -1234,7 +1234,7 @@ async def run( if not isinstance(slot, ListSlot): value = value[-1] - if tracker.get_slot(slot.name) != value: + if value is not None or slot.value is not None: slot_events.append(SlotSet(slot.name, value)) should_fill_custom_slot = mapping_type == SlotMappingType.CUSTOM diff --git a/tests/core/test_actions.py b/tests/core/test_actions.py index 904a17708990..0f92f746c691 100644 --- a/tests/core/test_actions.py +++ b/tests/core/test_actions.py @@ -2,6 +2,7 @@ import textwrap from datetime import datetime from typing import List, Text, Any, Dict, Optional +from unittest.mock import Mock import pytest from _pytest.logging import LogCaptureFixture @@ -159,7 +160,6 @@ async def test_remote_action_runs( default_tracker: DialogueStateTracker, domain: Domain, ): - endpoint = EndpointConfig("https://example.com/webhooks/actions") remote_action = action.RemoteAction("my_action", endpoint) @@ -467,7 +467,6 @@ async def test_remote_action_invalid_entities_payload( domain: Domain, event: Event, ): - endpoint = EndpointConfig("https://example.com/webhooks/actions") remote_action = action.RemoteAction("my_action", endpoint) response = {"events": [event], "responses": []} @@ -1216,7 +1215,7 @@ async def test_action_extract_slots_predefined_mappings( tracker, domain, ) - assert not new_events + assert new_events == [SlotSet(slot_name, slot_value)] new_events.extend([BotUttered(), ActionExecuted("action_listen"), new_user]) tracker.update_with_events(new_events, domain) @@ -2665,3 +2664,51 @@ async def test_action_extract_slots_non_required_form_slot_with_from_entity_mapp domain, ) assert events == [SlotSet("form1_info1", "info1"), SlotSet("form1_slot1", "Filled")] + + +async def test_action_extract_slots_returns_slot_set_even_if_slot_value_is_unchanged(): + event_with_slot_entity = UserUttered( + text="I am a text", + intent={"name": "intent_with_entity"}, + entities=[{"entity": "entity", "value": "value"}], + ) + + domain = textwrap.dedent( + """ + intents: + - intent_with_entity + entities: + - entity + slots: + entity: + type: text + mappings: + - type: from_entity + entity: entity + """ + ) + + domain = Domain.from_yaml(domain) + + tracker = DialogueStateTracker.from_events( + "some-sender", + evts=[ + event_with_slot_entity, + ], + ) + + tracker._set_slot("entity", "value") + + action = ActionExtractSlots(None) + + events = await action.run( + output_channel=CollectingOutputChannel(), + nlg=Mock(), + tracker=tracker, + domain=domain, + ) + + assert len(events) == 1 + assert type(events[0]) == SlotSet + assert events[0].key == "entity" + assert events[0].value == "value"