Skip to content

Commit

Permalink
Pass ssl_context from the web_client to the websocket (#1177)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlch authored Feb 16, 2022
1 parent 508caa2 commit 4d3f67d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
1 change: 1 addition & 0 deletions slack_sdk/socket_mode/builtin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def connect(self) -> None:
on_message_listener=self._on_message,
on_error_listener=self._on_error,
on_close_listener=self._on_close,
ssl_context=self.web_client.ssl,
)
current_session.connect()

Expand Down
4 changes: 4 additions & 0 deletions slack_sdk/socket_mode/builtin/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
on_error_listener: Optional[Callable[[Exception], None]] = None,
on_close_listener: Optional[Callable[[int, Optional[str]], None]] = None,
connection_type_name: str = "Socket Mode",
ssl_context: Optional[ssl.SSLContext] = None,
):
self.url = url
self.logger = logger
Expand Down Expand Up @@ -94,6 +95,8 @@ def __init__(
self.on_close_listener = on_close_listener
self.connection_type_name = connection_type_name

self.ssl_context = ssl_context

def connect(self) -> None:
try:
parsed_url = urlparse(self.url.strip())
Expand All @@ -114,6 +117,7 @@ def connect(self) -> None:
proxy=self.proxy,
proxy_headers=self.proxy_headers,
trace_enabled=self.trace_enabled,
ssl_context=self.ssl_context,
)

# WebSocket handshake
Expand Down
12 changes: 10 additions & 2 deletions slack_sdk/socket_mode/builtin/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def _parse_connect_response(sock: Socket) -> Tuple[Optional[int], str]:
return status, "\n".join(lines)


def _use_or_create_ssl_context(ssl_context: Optional[ssl.SSLContext] = None):
return ssl_context if ssl_context is not None else ssl.create_default_context()


def _establish_new_socket_connection(
session_id: str,
server_hostname: str,
Expand All @@ -47,7 +51,11 @@ def _establish_new_socket_connection(
proxy: Optional[str],
proxy_headers: Optional[Dict[str, str]],
trace_enabled: bool,
ssl_context: Optional[ssl.SSLContext] = None,
) -> Union[ssl.SSLSocket, Socket]:

ssl_context = _use_or_create_ssl_context(ssl_context)

if proxy is not None:
parsed_proxy = urlparse(proxy)
proxy_host, proxy_port = parsed_proxy.hostname, parsed_proxy.port or 80
Expand Down Expand Up @@ -83,7 +91,7 @@ def _establish_new_socket_connection(
f"Failed to connect to the proxy (proxy: {proxy}, connect status code: {status})"
)

sock = ssl.create_default_context().wrap_socket(
sock = ssl_context.wrap_socket(
sock,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
Expand All @@ -100,7 +108,7 @@ def _establish_new_socket_connection(
return sock

sock = socket.create_connection((server_hostname, server_port), receive_timeout)
sock = ssl.create_default_context().wrap_socket(
sock = ssl_context.wrap_socket(
sock,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
Expand Down
19 changes: 19 additions & 0 deletions tests/slack_sdk/socket_mode/test_builtin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import socket
import ssl
import time
import unittest
from unittest.mock import sentinel
from threading import Thread

from slack_sdk import WebClient
Expand All @@ -13,6 +15,7 @@
_to_readable_opcode,
_build_data_frame_for_sending,
_parse_connect_response,
_use_or_create_ssl_context,
)
from slack_sdk.web.legacy_client import LegacyWebClient
from .mock_web_api_server import (
Expand Down Expand Up @@ -85,6 +88,14 @@ def test_enqueue_message(self):
)
client.process_message()

def test_client_with_ssl(self):
self.web_client.ssl = sentinel.ssl_context
client = SocketModeClient(
app_token="xapp-A111-222-xyz",
web_client=self.web_client,
)
self.assertEqual(client.web_client.ssl, sentinel.ssl_context)

# ----------------------------------
# Connection

Expand Down Expand Up @@ -136,3 +147,11 @@ def test_parse_connect_response(self):
self.assertEqual(text, "HTTP/1.1 200 Connection established")
finally:
sock.close()

def test_creating_ssl_context(self):
ssl_context = _use_or_create_ssl_context(None)
self.assertTrue(isinstance(ssl_context, ssl.SSLContext))

def test_using_supplied_ssl_context(self):
ssl_context = _use_or_create_ssl_context(sentinel.ssl_context)
self.assertEqual(ssl_context, sentinel.ssl_context)

0 comments on commit 4d3f67d

Please sign in to comment.