Skip to content

Commit

Permalink
Merge pull request #5578 from alfredfrancis/feature/get_output_channe…
Browse files Browse the repository at this point in the history
…l_for_socketio

Implemented get_output_channel() fn for SocketIO Channel
  • Loading branch information
wochinge authored Apr 28, 2020
2 parents 89a7699 + 1ca0823 commit 84a262e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
1 change: 1 addition & 0 deletions changelog/5578.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``socketio`` to the compatible channels for :ref:`reminders-and-external-events`.
1 change: 1 addition & 0 deletions docs/_static/spec/rasa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ components:
- telegram
- twilio
- webexteams
- socketio

responses:

Expand Down
35 changes: 26 additions & 9 deletions rasa/core/channels/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ class SocketIOOutput(OutputChannel):
def name(cls) -> Text:
return "socketio"

def __init__(self, sio, sid, bot_message_evt) -> None:
def __init__(self, sio: AsyncServer, bot_message_evt: Text) -> None:
self.sio = sio
self.sid = sid
self.bot_message_evt = bot_message_evt

async def _send_message(self, socket_id: Text, response: Any) -> None:
Expand All @@ -44,15 +43,15 @@ async def send_text_message(
"""Send a message through this channel."""

for message_part in text.strip().split("\n\n"):
await self._send_message(self.sid, {"text": message_part})
await self._send_message(recipient_id, {"text": message_part})

async def send_image_url(
self, recipient_id: Text, image: Text, **kwargs: Any
) -> None:
"""Sends an image to the output"""

message = {"attachment": {"type": "image", "payload": {"src": image}}}
await self._send_message(self.sid, message)
await self._send_message(recipient_id, message)

async def send_text_with_buttons(
self,
Expand Down Expand Up @@ -80,7 +79,7 @@ async def send_text_with_buttons(
)

for message in messages:
await self._send_message(self.sid, message)
await self._send_message(recipient_id, message)

async def send_elements(
self, recipient_id: Text, elements: Iterable[Dict[Text, Any]], **kwargs: Any
Expand All @@ -95,22 +94,22 @@ async def send_elements(
}
}

await self._send_message(self.sid, message)
await self._send_message(recipient_id, message)

async def send_custom_json(
self, recipient_id: Text, json_message: Dict[Text, Any], **kwargs: Any
) -> None:
"""Sends custom json to the output"""

json_message.setdefault("room", self.sid)
json_message.setdefault("room", recipient_id)

await self.sio.emit(self.bot_message_evt, **json_message)

async def send_attachment(
self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
) -> None:
"""Sends an attachment to the user."""
await self._send_message(self.sid, {"attachment": attachment})
await self._send_message(recipient_id, {"attachment": attachment})


class SocketIOInput(InputChannel):
Expand Down Expand Up @@ -144,6 +143,19 @@ def __init__(
self.user_message_evt = user_message_evt
self.namespace = namespace
self.socketio_path = socketio_path
self.sio = None

def get_output_channel(self) -> Optional["OutputChannel"]:
if self.sio is None:
raise_warning(
"SocketIO output channel cannot be recreated. "
"This is expected behavior when using multiple Sanic "
"workers or multiple Rasa Open Source instances. "
"Please use a different channel for external events in these "
"scenarios."
)
return
return SocketIOOutput(self.sio, self.bot_message_evt)

def blueprint(
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
Expand All @@ -155,6 +167,9 @@ def blueprint(
sio, self.socketio_path, "socketio_webhook", __name__
)

# make sio object static to use in get_output_channel
self.sio = sio

@socketio_webhook.route("/", methods=["GET"])
async def health(_: Request) -> HTTPResponse:
return response.json({"status": "ok"})
Expand All @@ -173,12 +188,14 @@ async def session_request(sid: Text, data: Optional[Dict]):
data = {}
if "session_id" not in data or data["session_id"] is None:
data["session_id"] = uuid.uuid4().hex
if self.session_persistence:
sio.enter_room(sid, data["session_id"])
await sio.emit("session_confirm", data["session_id"], room=sid)
logger.debug(f"User {sid} connected to socketIO endpoint.")

@sio.on(self.user_message_evt, namespace=self.namespace)
async def handle_message(sid: Text, data: Dict) -> Any:
output_channel = SocketIOOutput(sio, sid, self.bot_message_evt)
output_channel = SocketIOOutput(sio, self.bot_message_evt)

if self.session_persistence:
if not data.get("session_id"):
Expand Down

0 comments on commit 84a262e

Please sign in to comment.