Skip to content

Commit

Permalink
ATO-218 Always set slot values
Browse files Browse the repository at this point in the history
This fixes a bug where slot values vanish when using the augmented
memoization policy. If in the past, the slot is set to a specific value,
and then later the slot being set to the same value, no new SlotSet
event is emitted. However, this leads to the augmented memoiziation
policy wrongly assuming this slot to be unset, as the older, original
SlotSet event was pruned already.
  • Loading branch information
losterloh committed Jul 14, 2022
1 parent 3f69f12 commit 8491f55
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
2 changes: 1 addition & 1 deletion rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ async def run(
if not isinstance(slot, ListSlot):
value = value[-1]

if tracker.get_slot(slot.name) != value:
if value is not None or slot.value is not None:
slot_events.append(SlotSet(slot.name, value))

should_fill_custom_slot = mapping_type == SlotMappingType.CUSTOM
Expand Down
53 changes: 50 additions & 3 deletions tests/core/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import textwrap
from datetime import datetime
from typing import List, Text, Any, Dict, Optional
from unittest.mock import Mock

import pytest
from _pytest.logging import LogCaptureFixture
Expand Down Expand Up @@ -159,7 +160,6 @@ async def test_remote_action_runs(
default_tracker: DialogueStateTracker,
domain: Domain,
):

endpoint = EndpointConfig("https://example.com/webhooks/actions")
remote_action = action.RemoteAction("my_action", endpoint)

Expand Down Expand Up @@ -467,7 +467,6 @@ async def test_remote_action_invalid_entities_payload(
domain: Domain,
event: Event,
):

endpoint = EndpointConfig("https://example.com/webhooks/actions")
remote_action = action.RemoteAction("my_action", endpoint)
response = {"events": [event], "responses": []}
Expand Down Expand Up @@ -1216,7 +1215,7 @@ async def test_action_extract_slots_predefined_mappings(
tracker,
domain,
)
assert not new_events
assert new_events == [SlotSet(slot_name, slot_value)]

new_events.extend([BotUttered(), ActionExecuted("action_listen"), new_user])
tracker.update_with_events(new_events, domain)
Expand Down Expand Up @@ -2665,3 +2664,51 @@ async def test_action_extract_slots_non_required_form_slot_with_from_entity_mapp
domain,
)
assert events == [SlotSet("form1_info1", "info1"), SlotSet("form1_slot1", "Filled")]


async def test_action_extract_slots_returns_slot_set_even_if_slot_value_is_unchanged():
event_with_slot_entity = UserUttered(
text="I am a text",
intent={"name": "intent_with_entity"},
entities=[{"entity": "entity", "value": "value"}],
)

domain = textwrap.dedent(
"""
intents:
- intent_with_entity
entities:
- entity
slots:
entity:
type: text
mappings:
- type: from_entity
entity: entity
"""
)

domain = Domain.from_yaml(domain)

tracker = DialogueStateTracker.from_events(
"some-sender",
evts=[
event_with_slot_entity,
],
)

tracker._set_slot("entity", "value")

action = ActionExtractSlots(None)

events = await action.run(
output_channel=CollectingOutputChannel(),
nlg=Mock(),
tracker=tracker,
domain=domain,
)

assert len(events) == 1
assert type(events[0]) == SlotSet
assert events[0].key == "entity"
assert events[0].value == "value"

0 comments on commit 8491f55

Please sign in to comment.